1use std::collections::HashMap;
13use std::time::{Duration, Instant};
14
15use bytes::Bytes;
16use ndn_tlv::TlvWriter;
17
18use crate::encode::nni;
19use crate::tlv_type;
20
/// Default maximum transmission unit, in bytes, for UDP links.
pub const DEFAULT_UDP_MTU: usize = 1_400;

/// How long an incomplete packet may wait in a `ReassemblyBuffer` built via `Default`
/// before `purge_expired` discards it.
const DEFAULT_REASSEMBLY_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Bytes reserved in every fragment for the LpPacket framing around the payload slice.
///
/// `fragment_packet` subtracts this from the MTU to size each payload chunk.
/// NOTE(review): presumably sized to cover the outer LpPacket TLV header plus the
/// Sequence, FragIndex and FragCount fields (8-byte NNI worst case) and the Fragment
/// TLV header — confirm against the TLV-TYPE encodings in `tlv_type`.
pub const FRAG_OVERHEAD: usize = 50;
30
31pub fn fragment_packet(packet: &[u8], mtu: usize, base_seq: u64) -> Vec<Bytes> {
47 let payload_cap = mtu
48 .checked_sub(FRAG_OVERHEAD)
49 .expect("MTU too small for fragmentation overhead");
50 assert!(payload_cap > 0, "MTU too small");
51
52 let frag_count = packet.len().div_ceil(payload_cap);
53
54 let mut fragments = Vec::with_capacity(frag_count);
55 for i in 0..frag_count {
56 let start = i * payload_cap;
57 let end = (start + payload_cap).min(packet.len());
58 let chunk = &packet[start..end];
59
60 let mut w = TlvWriter::new();
61 w.write_nested(tlv_type::LP_PACKET, |w| {
62 let (buf, len) = nni(base_seq + i as u64);
63 w.write_tlv(tlv_type::LP_SEQUENCE, &buf[..len]);
64 let (buf, len) = nni(i as u64);
65 w.write_tlv(tlv_type::LP_FRAG_INDEX, &buf[..len]);
66 let (buf, len) = nni(frag_count as u64);
67 w.write_tlv(tlv_type::LP_FRAG_COUNT, &buf[..len]);
68 w.write_tlv(tlv_type::LP_FRAGMENT, chunk);
69 });
70 fragments.push(w.finish());
71 }
72
73 fragments
74}
75
76struct Pending {
78 fragments: Vec<Option<Bytes>>,
79 frag_count: usize,
80 received: usize,
81 created: Instant,
82}
83
84pub struct ReassemblyBuffer {
89 pending: HashMap<u64, Pending>,
90 timeout: Duration,
91}
92
93impl ReassemblyBuffer {
94 pub fn new(timeout: Duration) -> Self {
95 Self {
96 pending: HashMap::new(),
97 timeout,
98 }
99 }
100
101 pub fn process(
107 &mut self,
108 seq: u64,
109 frag_index: u64,
110 frag_count: u64,
111 fragment: Bytes,
112 ) -> Option<Bytes> {
113 let count = frag_count as usize;
114 let idx = frag_index as usize;
115
116 if count == 0 || idx >= count {
117 return None;
118 }
119
120 let entry = self.pending.entry(seq).or_insert_with(|| Pending {
121 fragments: vec![None; count],
122 frag_count: count,
123 received: 0,
124 created: Instant::now(),
125 });
126
127 if entry.frag_count != count || idx >= entry.frag_count {
129 return None;
130 }
131
132 if entry.fragments[idx].is_none() {
134 entry.received += 1;
135 }
136 entry.fragments[idx] = Some(fragment);
137
138 if entry.received == entry.frag_count {
139 let entry = self.pending.remove(&seq).unwrap();
140 let total_len: usize = entry
141 .fragments
142 .iter()
143 .map(|f| f.as_ref().unwrap().len())
144 .sum();
145 let mut buf = Vec::with_capacity(total_len);
146 for frag in &entry.fragments {
147 buf.extend_from_slice(frag.as_ref().unwrap());
148 }
149 Some(Bytes::from(buf))
150 } else {
151 None
152 }
153 }
154
155 pub fn purge_expired(&mut self) {
157 let timeout = self.timeout;
158 self.pending.retain(|_, v| v.created.elapsed() < timeout);
159 }
160
161 pub fn pending_count(&self) -> usize {
163 self.pending.len()
164 }
165}
166
167impl Default for ReassemblyBuffer {
168 fn default() -> Self {
169 Self::new(DEFAULT_REASSEMBLY_TIMEOUT)
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes one encoded fragment produced by `fragment_packet`.
    fn decode(frag: &Bytes) -> crate::lp::LpPacket {
        crate::lp::LpPacket::decode(frag.clone()).unwrap()
    }

    /// Sequence number of fragment 0 of the packet that `lp` belongs to.
    ///
    /// Wrapping arithmetic keeps this correct if sequence numbers roll over.
    fn base_seq(lp: &crate::lp::LpPacket) -> u64 {
        lp.sequence.unwrap().wrapping_sub(lp.frag_index.unwrap())
    }

    /// Decodes `frag` and feeds it to `buf`, returning whatever `process` returns.
    fn feed(buf: &mut ReassemblyBuffer, frag: &Bytes) -> Option<Bytes> {
        let lp = decode(frag);
        buf.process(
            base_seq(&lp),
            lp.frag_index.unwrap(),
            lp.frag_count.unwrap(),
            lp.fragment.unwrap(),
        )
    }

    /// A `len`-byte repeating 0..=255 ramp, so misordered chunks are detectable.
    fn sample_packet(len: usize) -> Vec<u8> {
        (0..len).map(|i| (i % 256) as u8).collect()
    }

    #[test]
    fn single_fragment_roundtrip() {
        let data = vec![0x06, 0x03, 0xAA, 0xBB, 0xCC];
        let frags = fragment_packet(&data, DEFAULT_UDP_MTU, 100);
        assert_eq!(frags.len(), 1);

        let lp = decode(&frags[0]);
        assert_eq!(lp.sequence, Some(100));
        assert_eq!(lp.frag_index, Some(0));
        assert_eq!(lp.frag_count, Some(1));
        assert_eq!(lp.fragment.as_deref().unwrap(), &data[..]);
    }

    #[test]
    fn multi_fragment_roundtrip() {
        let data = sample_packet(3000);
        let frags = fragment_packet(&data, 200, 42);
        assert!(
            frags.len() > 1,
            "expected multiple fragments, got {}",
            frags.len()
        );

        let mut buf = ReassemblyBuffer::default();
        let mut result = None;
        for (i, frag) in frags.iter().enumerate() {
            let lp = decode(frag);
            assert_eq!(lp.sequence, Some(42 + i as u64));
            assert!(lp.is_fragmented());

            result = feed(&mut buf, frag);
            // Only the final fragment may complete the packet.
            assert_eq!(result.is_some(), i + 1 == frags.len());
        }

        let reassembled = result.expect("reassembly should complete");
        assert_eq!(reassembled.as_ref(), &data[..]);
        assert_eq!(buf.pending_count(), 0);
    }

    #[test]
    fn out_of_order_reassembly() {
        let data = sample_packet(3000);
        let frags = fragment_packet(&data, 200, 7);
        assert!(frags.len() > 2);

        let mut buf = ReassemblyBuffer::default();
        let mut result = None;
        for frag in frags.iter().rev() {
            result = feed(&mut buf, frag);
        }

        let reassembled = result.expect("out-of-order reassembly should complete");
        assert_eq!(reassembled.as_ref(), &data[..]);
    }

    #[test]
    fn duplicate_fragment_handled() {
        let data = sample_packet(3000);
        let frags = fragment_packet(&data, 200, 1);
        let (last, rest) = frags.split_last().unwrap();

        let mut buf = ReassemblyBuffer::default();
        for frag in rest {
            assert!(feed(&mut buf, frag).is_none());
        }

        // Re-delivering fragment 0 must not count as the missing fragment.
        assert!(feed(&mut buf, &frags[0]).is_none());

        // The genuinely missing last fragment completes the packet intact.
        let reassembled = feed(&mut buf, last).expect("last fragment should complete the packet");
        assert_eq!(reassembled.as_ref(), &data[..]);
    }

    #[test]
    fn purge_expired() {
        let data = sample_packet(3000);
        let frags = fragment_packet(&data, 200, 1);

        let mut buf = ReassemblyBuffer::new(Duration::from_millis(0));
        assert!(feed(&mut buf, &frags[0]).is_none());
        assert_eq!(buf.pending_count(), 1);

        buf.purge_expired();
        assert_eq!(buf.pending_count(), 0);
    }

    #[test]
    fn invalid_fragment_metadata_is_rejected() {
        let mut buf = ReassemblyBuffer::default();
        let frag = Bytes::from_static(b"x");

        // A FragCount of zero and a FragIndex past the end are both malformed and must
        // not create pending state.
        assert!(buf.process(1, 0, 0, frag.clone()).is_none());
        assert!(buf.process(1, 3, 3, frag).is_none());
        assert_eq!(buf.pending_count(), 0);
    }

    #[test]
    fn mismatched_frag_count_is_ignored() {
        let mut buf = ReassemblyBuffer::default();
        assert!(buf.process(5, 0, 2, Bytes::from_static(b"ab")).is_none());

        // Same key but a different FragCount: dropped without disturbing the partial packet.
        assert!(buf.process(5, 1, 3, Bytes::from_static(b"zz")).is_none());
        assert_eq!(buf.pending_count(), 1);

        let whole = buf
            .process(5, 1, 2, Bytes::from_static(b"cd"))
            .expect("matching fragment should complete the packet");
        assert_eq!(whole, Bytes::from_static(b"abcd"));
        assert_eq!(buf.pending_count(), 0);
    }

    #[test]
    fn each_fragment_within_mtu() {
        let data = sample_packet(5000);
        let mtu = 500;
        let frags = fragment_packet(&data, mtu, 0);
        for (i, frag) in frags.iter().enumerate() {
            assert!(
                frag.len() <= mtu,
                "fragment {i} is {} bytes, exceeds MTU {mtu}",
                frag.len()
            );
        }
    }
}