1use std::collections::HashMap;
37use std::future::Future;
38use std::time::Duration;
39
40use bytes::{BufMut, Bytes, BytesMut};
41use tokio::sync::mpsc;
42use tokio::time::Instant;
43use tokio_util::sync::CancellationToken;
44
45use ndn_packet::Name;
46use ndn_packet::encode::InterestBuilder;
47
48use crate::protocol::{SyncHandle, SyncUpdate};
49use crate::svs::SvsNode;
50
/// TLV type of the outer StateVector element (201 = 0xC9).
const TLV_STATE_VECTOR: u64 = 201;
/// TLV type of one entry nested inside a StateVector.
const TLV_SV_ENTRY: u64 = 202;
/// TLV type of the sequence number carried by state-vector and mapping entries.
const TLV_SV_SEQ_NO: u64 = 204;
/// TLV type of the outer MappingData element (205 = 0xCD).
const TLV_MAPPING_DATA: u64 = 205;
/// TLV type of one entry nested inside MappingData.
const TLV_MAPPING_ENTRY: u64 = 206;
/// TLV type of an NDN Name.
const TLV_NDN_NAME: u64 = 7;

/// Retry schedule consumed by [`fetch_with_retry`].
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Number of retries allowed after the first attempt, so the operation
    /// runs at most `max_retries + 1` times.
    pub max_retries: u32,
    /// Pause inserted after the first failed attempt.
    pub base_delay: Duration,
    /// Multiplier applied to the pause after every failed attempt; the
    /// multiplied pause is capped at 60 seconds.
    pub backoff_factor: f64,
}
85
86impl Default for RetryPolicy {
87 fn default() -> Self {
88 Self {
89 max_retries: 4,
90 base_delay: Duration::from_secs(1),
91 backoff_factor: 2.0,
92 }
93 }
94}
95
96pub async fn fetch_with_retry<F, Fut, T, E>(policy: RetryPolicy, mut fetch: F) -> Result<T, E>
104where
105 F: FnMut(u32) -> Fut,
106 Fut: Future<Output = Result<T, E>>,
107{
108 let mut delay = policy.base_delay;
109 for attempt in 0..=policy.max_retries {
110 match fetch(attempt).await {
111 Ok(v) => return Ok(v),
112 Err(e) => {
113 if attempt == policy.max_retries {
114 return Err(e);
115 }
116 tokio::time::sleep(delay).await;
117 delay = Duration::from_secs_f64(
118 (delay.as_secs_f64() * policy.backoff_factor).min(60.0),
119 );
120 }
121 }
122 }
123 unreachable!()
124}
125
126#[derive(Clone, Debug)]
130pub struct SvsConfig {
131 pub sync_interval: Duration,
133 pub jitter_ms: u64,
135 pub channel_capacity: usize,
137 pub retry_policy: RetryPolicy,
142}
143
144impl Default for SvsConfig {
145 fn default() -> Self {
146 Self {
147 sync_interval: Duration::from_secs(30),
148 jitter_ms: 3000,
149 channel_capacity: 256,
150 retry_policy: RetryPolicy::default(),
151 }
152 }
153}
154
155pub fn join_svs_group(
176 group: Name,
177 local_name: Name,
178 send: mpsc::Sender<Bytes>,
179 recv: mpsc::Receiver<Bytes>,
180 config: SvsConfig,
181) -> SyncHandle {
182 let cancel = CancellationToken::new();
183 let (update_tx, update_rx) = mpsc::channel(config.channel_capacity);
184 let (publish_tx, publish_rx) = mpsc::channel(64);
185
186 let task_cancel = cancel.clone();
187 tokio::spawn(async move {
188 svs_task(
189 group,
190 local_name,
191 send,
192 recv,
193 publish_rx,
194 update_tx,
195 config,
196 task_cancel,
197 )
198 .await;
199 });
200
201 SyncHandle::new(update_rx, publish_tx, cancel)
202}
203
204#[allow(clippy::too_many_arguments)]
207async fn svs_task(
208 group: Name,
209 local_name: Name,
210 send: mpsc::Sender<Bytes>,
211 mut recv: mpsc::Receiver<Bytes>,
212 mut publish_rx: mpsc::Receiver<(Name, Option<Bytes>)>,
213 update_tx: mpsc::Sender<SyncUpdate>,
214 config: SvsConfig,
215 cancel: CancellationToken,
216) {
217 let node = SvsNode::new(&local_name);
218 let local_key = node.local_key().to_string();
219
220 let mut current_mapping: Option<Bytes> = None;
223
224 let mut next_send = Instant::now() + jitter_interval(&config);
226
227 loop {
228 tokio::select! {
229 _ = cancel.cancelled() => break,
230
231 _ = tokio::time::sleep_until(next_send) => {
232 send_sync_interest(&group, &node, &send, current_mapping.clone()).await;
233 next_send = Instant::now() + jitter_interval(&config);
234 }
235
236 Some(raw) = recv.recv() => {
237 if let Some((remote_sv, peer_mappings)) = parse_sync_interest(&group, &raw) {
238 let snapshot = node.snapshot().await;
240 let covers_local = remote_covers_local(&snapshot, &remote_sv);
241
242 let gaps = node.merge(&remote_sv).await;
243 for (peer_key, low, high) in gaps {
244 if peer_key == local_key { continue; }
245 let mapping = peer_mappings.get(&peer_key).cloned();
247 let update = SyncUpdate {
248 publisher: peer_key.clone(),
249 name: group.clone().append(&peer_key),
250 low_seq: low,
251 high_seq: high,
252 mapping,
253 };
254 let _ = update_tx.send(update).await;
255 }
256
257 if covers_local {
258 next_send = Instant::now() + jitter_interval(&config);
260 }
261 }
262 }
263
264 Some((pub_name, mapping)) = publish_rx.recv() => {
265 current_mapping = mapping;
267 node.advance().await;
268 let _ = pub_name; send_sync_interest(&group, &node, &send, current_mapping.clone()).await;
270 next_send = Instant::now() + jitter_interval(&config);
271 }
272 }
273 }
274}
275
276async fn send_sync_interest(
280 group: &Name,
281 node: &SvsNode,
282 send: &mpsc::Sender<Bytes>,
283 mapping: Option<Bytes>,
284) {
285 let snapshot = node.snapshot().await;
286 let mut app_params = encode_state_vector(&snapshot);
287
288 if let Some(mapping_bytes) = mapping {
289 let local_key = node.local_key();
290 let local_name: Name = local_key.parse().unwrap_or_else(|_| Name::root());
291 let seq = node.local_seq().await;
292 let mapping_tlv = encode_mapping_data(&local_name, seq, &mapping_bytes);
293 app_params.extend_from_slice(&mapping_tlv);
294 }
295
296 let sync_name = group.clone().append("svs");
297 let wire = InterestBuilder::new(sync_name)
298 .lifetime(Duration::from_millis(1000))
299 .app_parameters(app_params)
300 .build();
301 let _ = send.send(wire).await;
302}
303
304fn jitter_interval(config: &SvsConfig) -> Duration {
306 let jitter = Duration::from_millis(fastrand::u64(0..=config.jitter_ms));
307 config.sync_interval + jitter
308}
309
310fn remote_covers_local(
312 local_snapshot: &[crate::svs::StateVectorEntry],
313 remote_sv: &[(String, u64)],
314) -> bool {
315 let remote_map: HashMap<&str, u64> = remote_sv.iter().map(|(k, v)| (k.as_str(), *v)).collect();
316 local_snapshot
317 .iter()
318 .all(|e| remote_map.get(e.node.as_str()).copied().unwrap_or(0) >= e.seq)
319}
320
321fn write_varnumber(buf: &mut BytesMut, n: u64) {
324 if n < 0xFD {
325 buf.put_u8(n as u8);
326 } else if n <= 0xFFFF {
327 buf.put_u8(0xFD);
328 buf.put_u16(n as u16);
329 } else if n <= 0xFFFF_FFFF {
330 buf.put_u8(0xFE);
331 buf.put_u32(n as u32);
332 } else {
333 buf.put_u8(0xFF);
334 buf.put_u64(n);
335 }
336}
337
338fn write_tlv(buf: &mut BytesMut, typ: u64, value: &[u8]) {
339 write_varnumber(buf, typ);
340 write_varnumber(buf, value.len() as u64);
341 buf.put_slice(value);
342}
343
/// Encodes `v` as a big-endian integer using the shortest of 1, 2, 4 or 8
/// bytes that can hold it.
fn encode_nni(v: u64) -> Vec<u8> {
    match v {
        0..=0xFF => vec![v as u8],
        0x100..=0xFFFF => (v as u16).to_be_bytes().to_vec(),
        0x1_0000..=0xFFFF_FFFF => (v as u32).to_be_bytes().to_vec(),
        _ => v.to_be_bytes().to_vec(),
    }
}
356
357fn encode_name_tlv(name: &Name) -> Vec<u8> {
359 let mut inner = BytesMut::new();
360 for comp in name.components() {
361 write_tlv(&mut inner, comp.typ, &comp.value);
362 }
363 let mut outer = BytesMut::new();
364 write_tlv(&mut outer, TLV_NDN_NAME, &inner);
365 outer.to_vec()
366}
367
368use crate::svs::StateVectorEntry;
371
372fn encode_state_vector(entries: &[StateVectorEntry]) -> Vec<u8> {
374 let mut sv_inner = BytesMut::new();
375 for e in entries {
376 let name: Name = e.node.parse().unwrap_or_else(|_| Name::root());
377 let name_bytes = encode_name_tlv(&name);
378 let seq_bytes = encode_nni(e.seq);
379
380 let mut entry_inner = BytesMut::new();
381 entry_inner.put_slice(&name_bytes);
382 write_tlv(&mut entry_inner, TLV_SV_SEQ_NO, &seq_bytes);
383
384 write_tlv(&mut sv_inner, TLV_SV_ENTRY, &entry_inner);
385 }
386
387 let mut buf = BytesMut::new();
388 write_tlv(&mut buf, TLV_STATE_VECTOR, &sv_inner);
389 buf.to_vec()
390}
391
392fn encode_mapping_data(node_name: &Name, seq: u64, app_data: &[u8]) -> Vec<u8> {
402 let name_bytes = encode_name_tlv(node_name);
403 let seq_bytes = encode_nni(seq);
404
405 let mut entry_inner = BytesMut::new();
406 entry_inner.put_slice(&name_bytes);
407 write_tlv(&mut entry_inner, TLV_SV_SEQ_NO, &seq_bytes);
408 entry_inner.put_slice(app_data); let mut mapping_inner = BytesMut::new();
411 write_tlv(&mut mapping_inner, TLV_MAPPING_ENTRY, &entry_inner);
412
413 let mut buf = BytesMut::new();
414 write_tlv(&mut buf, TLV_MAPPING_DATA, &mapping_inner);
415 buf.to_vec()
416}
417
418fn read_tlv(cursor: &[u8]) -> Option<(u64, &[u8], &[u8])> {
421 let (typ, rest) = read_varnumber(cursor)?;
422 let (len, rest) = read_varnumber(rest)?;
423 let len = len as usize;
424 if rest.len() < len {
425 return None;
426 }
427 Some((typ, &rest[..len], &rest[len..]))
428}
429
/// Reads one variable-size number from the front of `cursor`.
///
/// A first byte below 0xFD is the value itself; 0xFD, 0xFE and 0xFF announce
/// a big-endian integer of 2, 4 or 8 bytes. Returns the value together with
/// the unread remainder, or `None` when the input is too short.
fn read_varnumber(cursor: &[u8]) -> Option<(u64, &[u8])> {
    let (&marker, tail) = cursor.split_first()?;

    let width = match marker {
        0xFD => 2,
        0xFE => 4,
        0xFF => 8,
        small => return Some((u64::from(small), tail)),
    };
    if tail.len() < width {
        return None;
    }

    let (digits, remainder) = tail.split_at(width);
    let value = digits
        .iter()
        .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));
    Some((value, remainder))
}
457
/// Interprets `bytes` as a big-endian unsigned integer.
///
/// Any length is accepted: an empty slice yields 0, and when more than eight
/// bytes are supplied only the first eight are used.
fn decode_nni(bytes: &[u8]) -> u64 {
    bytes
        .iter()
        .take(8)
        .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte))
}
474
475fn decode_name_key(name_tlv: &[u8]) -> Option<String> {
477 let (typ, value, _) = read_tlv(name_tlv)?;
478 if typ != TLV_NDN_NAME {
479 return None;
480 }
481 let name = Name::decode(Bytes::copy_from_slice(value)).ok()?;
482 Some(name.to_string())
483}
484
485fn decode_state_vector(sv_tlv: &[u8]) -> Option<Vec<(String, u64)>> {
487 let (typ, mut body, _) = read_tlv(sv_tlv)?;
488 if typ != TLV_STATE_VECTOR {
489 return None;
490 }
491
492 let mut entries = Vec::new();
493 while !body.is_empty() {
494 let (entry_typ, mut entry_body, rest) = read_tlv(body)?;
495 body = rest;
496 if entry_typ != TLV_SV_ENTRY {
497 continue;
498 }
499
500 let (name_typ, name_val, after_name) = read_tlv(entry_body)?;
502 if name_typ != TLV_NDN_NAME {
503 continue;
504 }
505 let mut name_bytes = BytesMut::new();
506 write_tlv(&mut name_bytes, name_typ, name_val);
507 let Some(node_key) = decode_name_key(&name_bytes) else {
508 continue;
509 };
510
511 entry_body = after_name;
512
513 let (seq_typ, seq_val, _) = read_tlv(entry_body)?;
515 if seq_typ != TLV_SV_SEQ_NO {
516 continue;
517 }
518 entries.push((node_key, decode_nni(seq_val)));
519 }
520
521 Some(entries)
522}
523
524fn decode_mapping_data(md_tlv: &[u8]) -> HashMap<String, Bytes> {
529 let mut result = HashMap::new();
530 let Some((typ, mut body, _)) = read_tlv(md_tlv) else {
531 return result;
532 };
533 if typ != TLV_MAPPING_DATA {
534 return result;
535 }
536
537 while !body.is_empty() {
538 let Some((entry_typ, mut entry_body, rest)) = read_tlv(body) else {
539 break;
540 };
541 body = rest;
542 if entry_typ != TLV_MAPPING_ENTRY {
543 continue;
544 }
545
546 let Some((name_typ, name_val, after_name)) = read_tlv(entry_body) else {
548 continue;
549 };
550 if name_typ != TLV_NDN_NAME {
551 continue;
552 }
553 let mut name_bytes = BytesMut::new();
554 write_tlv(&mut name_bytes, name_typ, name_val);
555 let Some(node_key) = decode_name_key(&name_bytes) else {
556 continue;
557 };
558
559 entry_body = after_name;
560
561 let Some((seq_typ, _, after_seq)) = read_tlv(entry_body) else {
563 continue;
564 };
565 if seq_typ != TLV_SV_SEQ_NO {
566 continue;
567 }
568
569 let app_data = Bytes::copy_from_slice(after_seq);
571 result.insert(node_key, app_data);
572 }
573
574 result
575}
576
577type ParsedSyncInterest = (Vec<(String, u64)>, HashMap<String, Bytes>);
584
585fn parse_sync_interest(group: &Name, raw: &[u8]) -> Option<ParsedSyncInterest> {
586 let interest = ndn_packet::Interest::decode(Bytes::copy_from_slice(raw)).ok()?;
587 let components = interest.name.components();
588
589 let group_len = group.components().len();
591 if components.len() < group_len + 1 {
592 return None;
593 }
594 if components[group_len].value.as_ref() != b"svs" {
595 return None;
596 }
597
598 let app_params = interest.app_parameters()?;
599
600 let mut sv: Option<Vec<(String, u64)>> = None;
602 let mut mappings: HashMap<String, Bytes> = HashMap::new();
603 let mut cursor: &[u8] = app_params;
604
605 while !cursor.is_empty() {
606 let Some((typ, _value, rest)) = read_tlv(cursor) else {
607 break;
608 };
609 let consumed = cursor.len() - rest.len();
611 let full_tlv = &cursor[..consumed];
612
613 match typ {
614 TLV_STATE_VECTOR => {
615 sv = decode_state_vector(full_tlv);
616 }
617 TLV_MAPPING_DATA => {
618 mappings = decode_mapping_data(full_tlv);
619 }
620 _ => {} }
622
623 cursor = rest;
624 }
625
626 sv.map(|v| (v, mappings))
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a state-vector entry.
    fn sv_entry(node: &str, seq: u64) -> StateVectorEntry {
        StateVectorEntry {
            node: node.to_string(),
            seq,
        }
    }

    // ── Wire codec ───────────────────────────────────────────────────────

    #[test]
    fn state_vector_roundtrip() {
        let wire = encode_state_vector(&[sv_entry("/alice", 5), sv_entry("/bob", 12)]);
        let decoded = decode_state_vector(&wire).expect("decode should succeed");
        assert_eq!(decoded.len(), 2);

        let seq_of = |key: &str| decoded.iter().find(|(k, _)| k == key).map(|(_, s)| *s);
        assert_eq!(seq_of("/alice"), Some(5));
        assert_eq!(seq_of("/bob"), Some(12));
    }

    #[test]
    fn decode_empty_state_vector() {
        let no_entries: Vec<StateVectorEntry> = Vec::new();
        let wire = encode_state_vector(&no_entries);
        let decoded = decode_state_vector(&wire).expect("decode empty sv");
        assert!(decoded.is_empty());
    }

    #[test]
    fn encode_uses_tlv_type_201() {
        let wire = encode_state_vector(&[sv_entry("/n", 1)]);
        assert_eq!(wire[0], 0xC9, "StateVector type must be 201 (0xC9)");
    }

    #[test]
    fn mapping_data_roundtrip() {
        let node: Name = "/alice".parse().unwrap();
        let payload = Bytes::from_static(b"hello-mapping");
        let wire = encode_mapping_data(&node, 42, &payload);

        assert_eq!(wire[0], 0xCD, "MappingData type must be 205 (0xCD)");

        let decoded = decode_mapping_data(&wire);
        let got = decoded.get("/alice").cloned().expect("entry for /alice");
        assert_eq!(got, payload);
    }

    #[test]
    fn mapping_data_multiple_entries_roundtrip() {
        // Each encoding carries one entry; the two are decoded independently.
        let name_a: Name = "/a".parse().unwrap();
        let name_b: Name = "/b".parse().unwrap();

        let decoded_a = decode_mapping_data(&encode_mapping_data(&name_a, 1, b"data-a"));
        let decoded_b = decode_mapping_data(&encode_mapping_data(&name_b, 2, b"data-b"));

        assert_eq!(decoded_a["/a"].as_ref(), b"data-a");
        assert_eq!(decoded_b["/b"].as_ref(), b"data-b");
    }

    // ── Suppression check ────────────────────────────────────────────────

    #[test]
    fn remote_covers_local_true() {
        let local = vec![sv_entry("/a", 3), sv_entry("/b", 1)];
        let remote = vec![("/a".to_string(), 3u64), ("/b".to_string(), 5)];
        assert!(remote_covers_local(&local, &remote));
    }

    #[test]
    fn remote_covers_local_false_when_behind() {
        let local = vec![sv_entry("/a", 5)];
        let remote = vec![("/a".to_string(), 3u64)];
        assert!(!remote_covers_local(&local, &remote));
    }

    #[test]
    fn remote_covers_local_false_when_missing_node() {
        let local = vec![sv_entry("/a", 1)];
        let remote: Vec<(String, u64)> = Vec::new();
        assert!(!remote_covers_local(&local, &remote));
    }

    // ── Retry helper ─────────────────────────────────────────────────────

    #[tokio::test]
    async fn fetch_with_retry_succeeds_on_first_try() {
        let outcome =
            fetch_with_retry(RetryPolicy::default(), |_attempt| async { Ok::<_, &str>("ok") })
                .await;
        assert_eq!(outcome, Ok("ok"));
    }

    #[tokio::test]
    async fn fetch_with_retry_retries_on_failure() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU32, Ordering};

        let attempts = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&attempts);

        let policy = RetryPolicy {
            max_retries: 3,
            base_delay: Duration::from_millis(1),
            backoff_factor: 1.0,
        };

        // The first two calls fail, the third succeeds.
        let outcome: Result<(), &str> = fetch_with_retry(policy, move |_| {
            let counter = Arc::clone(&counter);
            async move {
                match counter.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err("fail"),
                    _ => Ok(()),
                }
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn fetch_with_retry_exhausts_retries() {
        let policy = RetryPolicy {
            max_retries: 2,
            base_delay: Duration::from_millis(1),
            backoff_factor: 1.0,
        };

        let outcome: Result<(), &str> =
            fetch_with_retry(policy, |_| async { Err("always fail") }).await;
        assert_eq!(outcome, Err("always fail"));
    }

    // ── Group membership ─────────────────────────────────────────────────

    #[tokio::test]
    async fn join_and_leave() {
        let (send_tx, _send_rx) = mpsc::channel(16);
        let (_recv_tx, recv_rx) = mpsc::channel(16);

        let group: Name = "/test/svs".parse().unwrap();
        let local: Name = "/test/svs/node-a".parse().unwrap();

        join_svs_group(group, local, send_tx, recv_rx, SvsConfig::default()).leave();
    }

    #[tokio::test]
    async fn sync_interest_carries_app_params() {
        let (send_tx, mut send_rx) = mpsc::channel(16);
        let (_recv_tx, recv_rx) = mpsc::channel(16);

        let group: Name = "/test/svs".parse().unwrap();
        let local: Name = "/node-a".parse().unwrap();

        // Short interval without jitter so the periodic send fires quickly.
        let config = SvsConfig {
            sync_interval: Duration::from_millis(10),
            jitter_ms: 0,
            ..Default::default()
        };

        let _handle = join_svs_group(group.clone(), local.clone(), send_tx, recv_rx, config);

        let raw = tokio::time::timeout(Duration::from_secs(2), send_rx.recv())
            .await
            .expect("timed out")
            .expect("channel closed");

        let interest = ndn_packet::Interest::decode(raw).expect("decode interest");
        let params = interest.app_parameters().expect("must have AppParameters");
        let sv = decode_state_vector(params).expect("must decode StateVector");
        assert!(!sv.is_empty(), "state vector should contain local node");
    }

    #[tokio::test]
    async fn sync_interest_carries_mapping_after_publish_with_mapping() {
        let (send_tx, mut send_rx) = mpsc::channel(16);
        let (_recv_tx, recv_rx) = mpsc::channel(16);

        let group: Name = "/test/svs".parse().unwrap();
        let local: Name = "/node-m".parse().unwrap();

        // Long interval: the only Interest within the timeout is the one
        // triggered by the publish.
        let config = SvsConfig {
            sync_interval: Duration::from_secs(60),
            jitter_ms: 0,
            ..Default::default()
        };

        let handle = join_svs_group(group.clone(), local.clone(), send_tx, recv_rx, config);

        handle
            .publish_with_mapping(local.clone(), Bytes::from_static(b"test-mapping"))
            .await
            .expect("publish_with_mapping");

        let raw = tokio::time::timeout(Duration::from_secs(2), send_rx.recv())
            .await
            .expect("timed out")
            .expect("channel closed");

        let interest = ndn_packet::Interest::decode(raw).expect("decode interest");
        let params = interest.app_parameters().expect("AppParameters present");

        // Walk the top-level elements looking for our MappingData entry.
        let own_key = local.to_string();
        let mut found_mapping = false;
        let mut cursor: &[u8] = params;
        while let Some((typ, _val, rest)) = read_tlv(cursor) {
            if typ == TLV_MAPPING_DATA {
                let element = &cursor[..cursor.len() - rest.len()];
                if let Some(data) = decode_mapping_data(element).get(&own_key) {
                    assert_eq!(data.as_ref(), b"test-mapping");
                    found_mapping = true;
                }
            }
            cursor = rest;
        }
        assert!(found_mapping, "MappingData TLV not found in AppParameters");
    }
}