1use std::collections::{HashMap, VecDeque};
7use std::time::Instant;
8
9use bytes::Bytes;
10
11use ndn_packet::fragment::FRAG_OVERHEAD;
12use ndn_packet::lp::{encode_lp_acks, encode_lp_reliable, extract_acks};
13
/// Maximum number of ack sequence numbers piggybacked onto one outgoing frame.
const MAX_PIGGYBACKED_ACKS: usize = 16;

/// Default number of retransmissions allowed before an unacked frame is dropped.
const DEFAULT_MAX_RETRIES: u8 = 1;

/// Default upper bound on retransmissions emitted per `check_retransmit` call.
const MAX_RETX_PER_TICK: usize = 8;

/// Default cap on the number of in-flight (unacked) frames tracked.
const MAX_UNACKED: usize = 256;
/// RFC 6298 initial RTO before any RTT sample (spec uses 1 second).
const RFC6298_INITIAL_RTO_US: u64 = 1_000_000;
/// RFC 6298 lower clamp on the computed RTO.
const RFC6298_MIN_RTO_US: u64 = 200_000;
/// RFC 6298 upper clamp on the computed RTO.
const RFC6298_MAX_RTO_US: u64 = 4_000_000;
/// RFC 6298 clock granularity `G` used as the variance floor.
const RFC6298_GRANULARITY_US: u64 = 100_000;
/// RFC 6298 SRTT gain (1/8).
const RFC6298_ALPHA: f64 = 0.125;
/// RFC 6298 RTTVAR gain (1/4).
const RFC6298_BETA: f64 = 0.25;
/// QUIC-style initial RTO (RFC 9002 recommends 333 ms).
const QUIC_INITIAL_RTO_US: u64 = 333_000;
/// QUIC-style lower clamp: much tighter than RFC 6298.
const QUIC_MIN_RTO_US: u64 = 1_000;
/// QUIC-style upper clamp.
const QUIC_MAX_RTO_US: u64 = 4_000_000;
/// QUIC-style timer granularity used as the variance floor.
const QUIC_GRANULARITY_US: u64 = 1_000;

/// How the retransmission timeout is derived from RTT measurements.
#[derive(Debug, Clone, Default)]
pub enum RtoStrategy {
    /// Classic RFC 6298 SRTT/RTTVAR estimator with conservative clamps.
    #[default]
    Rfc6298,
    /// Same EWMA estimator, but QUIC-style tight granularity and floor.
    Quic,
    /// RTO pinned to the minimum observed RTT plus a fixed margin.
    MinRtt {
        /// Safety margin (microseconds) added on top of the observed minimum RTT.
        margin_us: u64,
    },
    /// Constant RTO that never adapts to measurements.
    Fixed {
        /// The fixed timeout value in microseconds.
        rto_us: u64,
    },
}
76
/// Tunable knobs for the LP reliability layer.
#[derive(Debug, Clone)]
pub struct ReliabilityConfig {
    /// How the retransmission timeout is computed from RTT samples.
    pub rto_strategy: RtoStrategy,
    /// Retransmissions allowed per frame before it is dropped.
    pub max_retries: u8,
    /// Maximum number of in-flight frames tracked for retransmission.
    pub max_unacked: usize,
    /// Maximum retransmissions emitted per `check_retransmit` tick.
    pub max_retx_per_tick: usize,
}
103
104impl Default for ReliabilityConfig {
105 fn default() -> Self {
106 Self {
107 rto_strategy: RtoStrategy::Rfc6298,
108 max_retries: DEFAULT_MAX_RETRIES,
109 max_unacked: MAX_UNACKED,
110 max_retx_per_tick: MAX_RETX_PER_TICK,
111 }
112 }
113}
114
115impl ReliabilityConfig {
116 pub fn local() -> Self {
118 Self {
119 rto_strategy: RtoStrategy::Fixed { rto_us: 1_000 },
120 max_retries: 0,
121 max_unacked: 64,
122 max_retx_per_tick: 4,
123 }
124 }
125
126 pub fn ethernet() -> Self {
128 Self {
129 rto_strategy: RtoStrategy::Quic,
130 max_retries: 1,
131 max_unacked: 256,
132 max_retx_per_tick: 8,
133 }
134 }
135
136 pub fn wifi() -> Self {
138 Self {
139 rto_strategy: RtoStrategy::Rfc6298,
140 max_retries: 3,
141 max_unacked: 512,
142 max_retx_per_tick: 16,
143 }
144 }
145}
146
/// Bookkeeping for one in-flight frame awaiting acknowledgement.
struct UnackedEntry {
    /// Encoded wire bytes, retained for retransmission.
    wire: Bytes,
    /// When the frame was first transmitted (basis for RTT sampling).
    first_sent: Instant,
    /// When the frame was last (re)transmitted (basis for RTO expiry).
    last_sent: Instant,
    /// Number of retransmissions performed so far.
    retx_count: u8,
    /// True once retransmitted; such entries are excluded from RTT sampling
    /// because the ack is ambiguous (Karn's algorithm).
    is_retx: bool,
}
154
/// LP-layer reliability state: fragmentation, ack piggybacking,
/// RTT estimation and retransmission for one face/link.
pub struct LpReliability {
    /// Next sequence number to assign to an outgoing fragment.
    next_seq: u64,
    /// In-flight fragments keyed by sequence number.
    unacked: HashMap<u64, UnackedEntry>,
    /// Sequence numbers received from the peer, waiting to be acked back.
    pending_acks: VecDeque<u64>,
    /// Smoothed RTT estimate in microseconds (0.0 means no sample yet).
    srtt_us: f64,
    /// RTT variance estimate in microseconds.
    rttvar_us: f64,
    /// Current retransmission timeout in microseconds.
    rto_us: u64,
    /// Minimum observed RTT in microseconds (`u64::MAX` until first sample).
    min_rtt_us: u64,
    /// Link MTU used to size outgoing fragments.
    mtu: usize,
    /// Retransmissions allowed per frame before it is dropped.
    max_retries: u8,
    /// Cap on the number of in-flight frames tracked.
    max_unacked: usize,
    /// Cap on retransmissions emitted per `check_retransmit` tick.
    max_retx_per_tick: usize,
    /// Active RTO computation strategy.
    rto_strategy: RtoStrategy,
}
173
174fn initial_rto_for(strategy: &RtoStrategy) -> u64 {
175 match strategy {
176 RtoStrategy::Rfc6298 => RFC6298_INITIAL_RTO_US,
177 RtoStrategy::Quic => QUIC_INITIAL_RTO_US,
178 RtoStrategy::MinRtt { margin_us } => *margin_us,
179 RtoStrategy::Fixed { rto_us } => *rto_us,
180 }
181}
182
183impl LpReliability {
184 pub fn new(mtu: usize) -> Self {
186 Self::from_config(mtu, ReliabilityConfig::default())
187 }
188
189 pub fn from_config(mtu: usize, config: ReliabilityConfig) -> Self {
191 let initial_rto = initial_rto_for(&config.rto_strategy);
192 Self {
193 next_seq: 0,
194 unacked: HashMap::new(),
195 pending_acks: VecDeque::new(),
196 srtt_us: 0.0,
197 rttvar_us: 0.0,
198 rto_us: initial_rto,
199 min_rtt_us: u64::MAX,
200 mtu,
201 max_retries: config.max_retries,
202 max_unacked: config.max_unacked,
203 max_retx_per_tick: config.max_retx_per_tick,
204 rto_strategy: config.rto_strategy,
205 }
206 }
207
208 pub fn apply_config(&mut self, config: ReliabilityConfig) {
210 self.rto_us = initial_rto_for(&config.rto_strategy);
211 self.srtt_us = 0.0;
212 self.rttvar_us = 0.0;
213 self.min_rtt_us = u64::MAX;
214 self.max_retries = config.max_retries;
215 self.max_unacked = config.max_unacked;
216 self.max_retx_per_tick = config.max_retx_per_tick;
217 self.rto_strategy = config.rto_strategy;
218 }
219
220 pub fn config(&self) -> ReliabilityConfig {
222 ReliabilityConfig {
223 rto_strategy: self.rto_strategy.clone(),
224 max_retries: self.max_retries,
225 max_unacked: self.max_unacked,
226 max_retx_per_tick: self.max_retx_per_tick,
227 }
228 }
229
230 pub fn on_send(&mut self, pkt: &[u8]) -> Vec<Bytes> {
235 let now = Instant::now();
236
237 let acks: Vec<u64> = self
239 .pending_acks
240 .drain(..self.pending_acks.len().min(MAX_PIGGYBACKED_ACKS))
241 .collect();
242
243 let ack_overhead = acks.len() * 10;
247 let payload_cap = self
248 .mtu
249 .saturating_sub(FRAG_OVERHEAD)
250 .saturating_sub(ack_overhead);
251
252 if payload_cap == 0 {
253 return vec![];
254 }
255
256 let frag_count = pkt.len().div_ceil(payload_cap);
257 let base_seq = self.next_seq;
258 self.next_seq += frag_count as u64;
259
260 let mut wires = Vec::with_capacity(frag_count);
261 for i in 0..frag_count {
262 let start = i * payload_cap;
263 let end = (start + payload_cap).min(pkt.len());
264 let chunk = &pkt[start..end];
265 let seq = base_seq + i as u64;
266
267 let frag_info = if frag_count > 1 {
268 Some((i as u64, frag_count as u64))
269 } else {
270 None
271 };
272
273 let frag_acks = if i == 0 { &acks[..] } else { &[] };
275 let wire = encode_lp_reliable(chunk, seq, frag_info, frag_acks);
276
277 while self.unacked.len() >= self.max_unacked {
282 if let Some(&oldest_seq) = self.unacked.keys().min() {
283 self.unacked.remove(&oldest_seq);
284 } else {
285 break;
286 }
287 }
288
289 self.unacked.insert(
290 seq,
291 UnackedEntry {
292 wire: wire.clone(),
293 first_sent: now,
294 last_sent: now,
295 retx_count: 0,
296 is_retx: false,
297 },
298 );
299
300 wires.push(wire);
301 }
302
303 wires
304 }
305
306 pub fn on_receive(&mut self, raw: &[u8]) {
309 let (tx_seq, acks) = extract_acks(raw);
310
311 if let Some(seq) = tx_seq {
313 self.pending_acks.push_back(seq);
314 }
315
316 let now = Instant::now();
318 for ack_seq in acks {
319 if let Some(entry) = self.unacked.remove(&ack_seq) {
320 if !entry.is_retx {
322 let rtt_us = now.duration_since(entry.first_sent).as_micros() as f64;
323 self.update_rto(rtt_us);
324 }
325 }
326 }
327 }
328
329 pub fn check_retransmit(&mut self) -> Vec<Bytes> {
335 let now = Instant::now();
336 let rto = std::time::Duration::from_micros(self.rto_us);
337 let mut retx = Vec::new();
338 let mut expired = Vec::new();
339
340 for (&seq, entry) in &self.unacked {
341 if now.duration_since(entry.last_sent) >= rto {
342 if entry.retx_count >= self.max_retries {
343 expired.push(seq);
344 } else {
345 retx.push(seq);
346 }
347 }
348 }
349
350 for seq in expired {
352 self.unacked.remove(&seq);
353 }
354
355 let mut wires = Vec::with_capacity(retx.len().min(self.max_retx_per_tick));
357 for seq in retx.into_iter().take(self.max_retx_per_tick) {
358 if let Some(entry) = self.unacked.get_mut(&seq) {
359 entry.last_sent = now;
360 entry.retx_count += 1;
361 entry.is_retx = true;
362 wires.push(entry.wire.clone());
363 }
364 }
365
366 wires
367 }
368
369 pub fn flush_acks(&mut self) -> Option<Bytes> {
372 if self.pending_acks.is_empty() {
373 return None;
374 }
375 let acks: Vec<u64> = self.pending_acks.drain(..).collect();
376 Some(encode_lp_acks(&acks))
377 }
378
379 pub fn unacked_count(&self) -> usize {
381 self.unacked.len()
382 }
383
384 pub fn rto_us(&self) -> u64 {
386 self.rto_us
387 }
388
389 fn update_rto(&mut self, rtt_us: f64) {
391 let rtt_int = rtt_us as u64;
393 if rtt_int < self.min_rtt_us {
394 self.min_rtt_us = rtt_int;
395 }
396
397 match &self.rto_strategy {
398 RtoStrategy::Fixed { .. } => {
399 }
401 RtoStrategy::MinRtt { margin_us } => {
402 self.rto_us = self.min_rtt_us.saturating_add(*margin_us);
403 }
404 RtoStrategy::Rfc6298 => {
405 self.update_ewma(rtt_us, RFC6298_ALPHA, RFC6298_BETA);
406 let rto = self.srtt_us + (4.0 * self.rttvar_us).max(RFC6298_GRANULARITY_US as f64);
407 self.rto_us = (rto as u64).clamp(RFC6298_MIN_RTO_US, RFC6298_MAX_RTO_US);
408 }
409 RtoStrategy::Quic => {
410 self.update_ewma(rtt_us, RFC6298_ALPHA, RFC6298_BETA);
411 let rto = self.srtt_us + (4.0 * self.rttvar_us).max(QUIC_GRANULARITY_US as f64);
412 self.rto_us = (rto as u64).clamp(QUIC_MIN_RTO_US, QUIC_MAX_RTO_US);
413 }
414 }
415 }
416
417 fn update_ewma(&mut self, rtt_us: f64, alpha: f64, beta: f64) {
419 if self.srtt_us == 0.0 {
420 self.srtt_us = rtt_us;
421 self.rttvar_us = rtt_us / 2.0;
422 } else {
423 self.rttvar_us = (1.0 - beta) * self.rttvar_us + beta * (self.srtt_us - rtt_us).abs();
424 self.srtt_us = (1.0 - alpha) * self.srtt_us + alpha * rtt_us;
425 }
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// A tiny 5-byte packet that always fits in a single fragment.
    fn small_packet() -> Vec<u8> {
        vec![0x05, 0x03, 0xAA, 0xBB, 0xCC]
    }

    // A packet well under the MTU must produce exactly one tracked fragment.
    #[test]
    fn on_send_returns_one_fragment_for_small_packet() {
        let mut rel = LpReliability::new(1400);
        let wires = rel.on_send(&small_packet());
        assert_eq!(wires.len(), 1);
        assert_eq!(rel.unacked_count(), 1);
    }

    // A packet larger than the payload cap must be split, with every
    // fragment tracked as unacked.
    #[test]
    fn on_send_fragments_large_packet() {
        let mut rel = LpReliability::new(200);
        let data: Vec<u8> = (0..3000).map(|i| (i % 256) as u8).collect();
        let wires = rel.on_send(&data);
        assert!(wires.len() > 1);
        assert_eq!(rel.unacked_count(), wires.len());
    }

    // Sequence numbers must increase by one across consecutive sends.
    #[test]
    fn on_send_assigns_consecutive_sequences() {
        let mut rel = LpReliability::new(1400);
        let w1 = rel.on_send(&small_packet());
        let w2 = rel.on_send(&small_packet());
        let (seq1, _) = extract_acks(&w1[0]);
        let (seq2, _) = extract_acks(&w2[0]);
        assert_eq!(seq1, Some(0));
        assert_eq!(seq2, Some(1));
    }

    // Receiving a sequenced frame must queue an ack for later flushing.
    #[test]
    fn on_receive_queues_ack() {
        let mut sender = LpReliability::new(1400);
        let mut receiver = LpReliability::new(1400);

        let wires = sender.on_send(&small_packet());
        receiver.on_receive(&wires[0]);

        let ack_pkt = receiver.flush_acks();
        assert!(ack_pkt.is_some());
    }

    // A piggybacked ack arriving back at the sender must clear the entry.
    #[test]
    fn ack_clears_unacked() {
        let mut sender = LpReliability::new(1400);
        let mut receiver = LpReliability::new(1400);

        let wires = sender.on_send(&small_packet());
        assert_eq!(sender.unacked_count(), 1);

        receiver.on_receive(&wires[0]);
        let reply = receiver.on_send(&small_packet());

        sender.on_receive(&reply[0]);
        assert_eq!(sender.unacked_count(), 0);
    }

    /// Config with a 1 ms fixed RTO so retransmission tests run quickly.
    fn fast_rto_config() -> ReliabilityConfig {
        ReliabilityConfig {
            rto_strategy: RtoStrategy::Fixed { rto_us: 1_000 },
            ..Default::default()
        }
    }

    // After the RTO elapses, the frame is retransmitted but stays tracked.
    #[test]
    fn retransmit_after_rto() {
        let mut rel = LpReliability::from_config(1400, fast_rto_config());

        let _wires = rel.on_send(&small_packet());
        assert_eq!(rel.unacked_count(), 1);

        // Sleep past the 1 ms fixed RTO.
        std::thread::sleep(std::time::Duration::from_millis(5));

        let retx = rel.check_retransmit();
        assert_eq!(retx.len(), 1);
        assert_eq!(rel.unacked_count(), 1);
    }

    // Once max_retries is exhausted, the entry is dropped instead of resent.
    #[test]
    fn max_retries_drops_entry() {
        let mut rel = LpReliability::from_config(
            1400,
            ReliabilityConfig {
                max_retries: 1,
                ..fast_rto_config()
            },
        );

        let _wires = rel.on_send(&small_packet());
        std::thread::sleep(std::time::Duration::from_millis(5));

        // First expiry: one retransmission allowed.
        let retx = rel.check_retransmit();
        assert_eq!(retx.len(), 1);

        std::thread::sleep(std::time::Duration::from_millis(5));

        // Second expiry: retries exhausted, entry dropped.
        let retx = rel.check_retransmit();
        assert!(retx.is_empty());
        assert_eq!(rel.unacked_count(), 0);
    }

    // With consistent small RTT samples the RFC 6298 RTO converges down
    // toward its minimum clamp.
    #[test]
    fn rto_converges_with_measurements() {
        let mut rel = LpReliability::new(1400);
        assert_eq!(rel.rto_us, RFC6298_INITIAL_RTO_US);

        for _ in 0..10 {
            rel.update_rto(500.0);
        }
        assert!(rel.rto_us <= RFC6298_MIN_RTO_US + RFC6298_GRANULARITY_US);
    }

    // No queued acks means no ack-only frame is produced.
    #[test]
    fn flush_acks_returns_none_when_empty() {
        let mut rel = LpReliability::new(1400);
        assert!(rel.flush_acks().is_none());
    }

    // An ack for the peer's frame must ride on the next outgoing packet.
    #[test]
    fn piggybacked_acks_in_outgoing_packet() {
        let mut sender = LpReliability::new(1400);
        let mut receiver = LpReliability::new(1400);

        let wires = sender.on_send(&small_packet());

        receiver.on_receive(&wires[0]);
        let reply = receiver.on_send(&small_packet());

        let (_, acks) = extract_acks(&reply[0]);
        assert!(!acks.is_empty());
        assert_eq!(acks[0], 0);
    }

    // QUIC preset starts with a lower initial RTO than RFC 6298.
    #[test]
    fn quic_strategy_lower_initial_rto() {
        let cfg = ReliabilityConfig {
            rto_strategy: RtoStrategy::Quic,
            ..Default::default()
        };
        let rel = LpReliability::from_config(1400, cfg);
        assert_eq!(rel.rto_us, QUIC_INITIAL_RTO_US);
        assert!(rel.rto_us < RFC6298_INITIAL_RTO_US);
    }

    // QUIC's lower clamp lets the RTO converge below the RFC 6298 minimum.
    #[test]
    fn quic_strategy_converges_tighter() {
        let cfg = ReliabilityConfig {
            rto_strategy: RtoStrategy::Quic,
            ..Default::default()
        };
        let mut rel = LpReliability::from_config(1400, cfg);
        for _ in 0..10 {
            rel.update_rto(500.0);
        }
        assert!(rel.rto_us < RFC6298_MIN_RTO_US);
    }

    // A Fixed strategy must ignore every RTT sample.
    #[test]
    fn fixed_strategy_never_changes() {
        let cfg = ReliabilityConfig {
            rto_strategy: RtoStrategy::Fixed { rto_us: 50_000 },
            ..Default::default()
        };
        let mut rel = LpReliability::from_config(1400, cfg);
        assert_eq!(rel.rto_us, 50_000);
        for _ in 0..20 {
            rel.update_rto(1_000.0);
        }
        assert_eq!(rel.rto_us, 50_000);
    }

    // MinRtt tracks the lowest sample seen, plus the configured margin.
    #[test]
    fn min_rtt_strategy_tracks_minimum() {
        let cfg = ReliabilityConfig {
            rto_strategy: RtoStrategy::MinRtt { margin_us: 5_000 },
            ..Default::default()
        };
        let mut rel = LpReliability::from_config(1400, cfg);
        rel.update_rto(10_000.0);
        rel.update_rto(8_000.0);
        rel.update_rto(15_000.0);
        assert_eq!(rel.rto_us, 8_000 + 5_000);
    }

    // apply_config must reset SRTT/min-RTT state and adopt the new RTO.
    #[test]
    fn apply_config_resets_state() {
        let mut rel = LpReliability::new(1400);
        for _ in 0..10 {
            rel.update_rto(500.0);
        }
        assert_ne!(rel.srtt_us, 0.0);

        rel.apply_config(ReliabilityConfig {
            rto_strategy: RtoStrategy::Fixed { rto_us: 100_000 },
            ..Default::default()
        });
        assert_eq!(rel.rto_us, 100_000);
        assert_eq!(rel.srtt_us, 0.0);
        assert_eq!(rel.min_rtt_us, u64::MAX);
    }

    // Sanity-check the relative ordering of the built-in presets.
    #[test]
    fn presets_are_consistent() {
        let local = LpReliability::from_config(1400, ReliabilityConfig::local());
        let eth = LpReliability::from_config(1400, ReliabilityConfig::ethernet());
        let wifi = LpReliability::from_config(1400, ReliabilityConfig::wifi());

        assert!(local.rto_us < eth.rto_us);
        assert!(wifi.config().max_retries > eth.config().max_retries);
    }

    // The unacked map must never grow past the configured cap.
    #[test]
    fn unacked_map_capped_at_max() {
        let mut rel = LpReliability::new(1400);
        for _ in 0..(MAX_UNACKED + 100) {
            rel.on_send(&small_packet());
        }
        assert!(rel.unacked_count() <= MAX_UNACKED);
    }
}