ndn_discovery/
composite.rs

//! `CompositeDiscovery` — runs multiple discovery protocols side by side.
//!
//! Validates at construction time that no two protocols claim overlapping name
//! prefixes, then routes inbound packets to the correct protocol by prefix
//! match.  Face lifecycle and tick hooks are delivered to all protocols in
//! registration order.

8use std::sync::Arc;
9use std::time::Instant;
10
11use bytes::Bytes;
12use ndn_packet::Name;
13use ndn_transport::FaceId;
14
15use crate::{DiscoveryContext, DiscoveryProtocol, InboundMeta, ProtocolId};
16
17/// Wrapper that runs multiple [`DiscoveryProtocol`] implementations in parallel.
18///
19/// # Namespace safety
20///
21/// [`CompositeDiscovery::new`] returns an error if any two protocols claim
22/// overlapping name prefixes (one is a prefix of the other).  Each protocol
23/// must use a distinct sub-tree of `/ndn/local/`.
24///
25/// # Inbound routing
26///
27/// When a raw packet arrives, `CompositeDiscovery` tries to parse its top-level
28/// NDN name and routes it to the first protocol whose `claimed_prefixes` contains
29/// a matching prefix.  If the name cannot be parsed or no protocol matches, the
30/// packet is not consumed (returns `false`).
31///
32/// # Tick delivery
33///
34/// All protocols receive every `on_tick` call.  Order is not guaranteed.
35pub struct CompositeDiscovery {
36    protocols: Vec<Arc<dyn DiscoveryProtocol>>,
37}
38
39impl CompositeDiscovery {
40    /// Construct a composite from a list of protocols.
41    ///
42    /// Returns `Err` with a human-readable message if any two protocols claim
43    /// overlapping prefixes.
44    pub fn new(protocols: Vec<Arc<dyn DiscoveryProtocol>>) -> Result<Self, String> {
45        // Collect all (prefix, protocol_id) pairs and check for overlaps.
46        let mut all_prefixes: Vec<(Name, ProtocolId)> = Vec::new();
47        for proto in &protocols {
48            for prefix in proto.claimed_prefixes() {
49                // Check against all previously registered prefixes.
50                for (existing, existing_id) in &all_prefixes {
51                    if prefixes_overlap(existing, prefix) {
52                        return Err(format!(
53                            "protocol '{}' prefix '{}' overlaps with protocol '{}' prefix '{}'",
54                            proto.protocol_id(),
55                            prefix,
56                            existing_id,
57                            existing,
58                        ));
59                    }
60                }
61                all_prefixes.push((prefix.clone(), proto.protocol_id()));
62            }
63        }
64        Ok(Self { protocols })
65    }
66
67    /// Number of contained protocols.
68    pub fn len(&self) -> usize {
69        self.protocols.len()
70    }
71
72    pub fn is_empty(&self) -> bool {
73        self.protocols.is_empty()
74    }
75
76    /// Collect all prefixes claimed by any child protocol.
77    ///
78    /// Unlike `claimed_prefixes()` (which returns the composite's own
79    /// top-level claims), this method flattens the claims of all children.
80    /// Use this to enumerate the full set of prefixes owned by the discovery
81    /// stack (e.g. for management security enforcement).
82    pub fn all_claimed_prefixes(&self) -> Vec<Name> {
83        self.protocols
84            .iter()
85            .flat_map(|p| p.claimed_prefixes().iter().cloned())
86            .collect()
87    }
88}
89
90impl DiscoveryProtocol for CompositeDiscovery {
91    fn protocol_id(&self) -> ProtocolId {
92        ProtocolId("composite")
93    }
94
95    fn claimed_prefixes(&self) -> &[Name] {
96        // CompositeDiscovery doesn't claim additional prefixes beyond
97        // what its children claim — return empty here since children
98        // are already registered and checked.
99        &[]
100    }
101
102    fn on_face_up(&self, face_id: FaceId, ctx: &dyn DiscoveryContext) {
103        for proto in &self.protocols {
104            proto.on_face_up(face_id, ctx);
105        }
106    }
107
108    fn on_face_down(&self, face_id: FaceId, ctx: &dyn DiscoveryContext) {
109        for proto in &self.protocols {
110            proto.on_face_down(face_id, ctx);
111        }
112    }
113
114    fn on_inbound(
115        &self,
116        raw: &Bytes,
117        incoming_face: FaceId,
118        meta: &InboundMeta,
119        ctx: &dyn DiscoveryContext,
120    ) -> bool {
121        // Try to parse the packet name for prefix-based routing.
122        if let Some(name) = parse_first_name(raw) {
123            for proto in &self.protocols {
124                for prefix in proto.claimed_prefixes() {
125                    if name.has_prefix(prefix) {
126                        return proto.on_inbound(raw, incoming_face, meta, ctx);
127                    }
128                }
129            }
130            // Name parsed but no protocol claimed it — not consumed.
131            return false;
132        }
133
134        // Name parse failed — try all protocols in order (fallback).
135        for proto in &self.protocols {
136            if proto.on_inbound(raw, incoming_face, meta, ctx) {
137                return true;
138            }
139        }
140        false
141    }
142
143    fn on_tick(&self, now: Instant, ctx: &dyn DiscoveryContext) {
144        for proto in &self.protocols {
145            proto.on_tick(now, ctx);
146        }
147    }
148}
149
// ── Helpers ───────────────────────────────────────────────────────────────────
151
152/// Returns true if `a` is a prefix of `b` or vice-versa (i.e. they overlap
153/// in the name tree).
154fn prefixes_overlap(a: &Name, b: &Name) -> bool {
155    b.has_prefix(a) || a.has_prefix(b)
156}
157
158/// Try to parse the first NDN name out of a raw TLV packet.
159///
160/// NDN packet TLV: Interest (0x05) or Data (0x06), then a Name TLV (0x07)
161/// immediately as the first child.  This does a minimal parse — just enough
162/// to route the packet to the correct sub-protocol.
163///
164/// Bytes arrive LP-unwrapped from the pipeline (TlvDecodeStage strips LP
165/// before on_inbound is called), so no LP handling is needed here.
166fn parse_first_name(raw: &Bytes) -> Option<Name> {
167    // Require at least a 2-byte TLV header.
168    if raw.len() < 4 {
169        return None;
170    }
171    let pkt_type = raw[0];
172    if pkt_type != 0x05 && pkt_type != 0x06 {
173        return None; // Not an Interest or Data
174    }
175    // Skip packet type + length (variable-length varint).
176    let (_, inner) = skip_tlv_header(raw)?;
177    // First child should be a Name TLV (type 0x07).
178    if inner.is_empty() || inner[0] != 0x07 {
179        return None;
180    }
181    // inner begins with the Name TLV; skip its type+length to get just the
182    // component bytes that Name::decode expects.
183    let (_, name_value) = skip_tlv_header(inner)?;
184    let name_bytes = bytes::Bytes::copy_from_slice(name_value);
185    Name::decode(name_bytes).ok()
186}
187
188/// Skip a TLV type+length prefix, returning a slice of the value bytes.
189/// Returns `(type, value_bytes)` or `None` on truncation.
190fn skip_tlv_header(buf: &[u8]) -> Option<(u8, &[u8])> {
191    if buf.is_empty() {
192        return None;
193    }
194    let t = buf[0];
195    let (len, hdr_size) = read_varu(buf.get(1..)?)?;
196    let end = 1 + hdr_size + len;
197    Some((t, buf.get(1 + hdr_size..end)?))
198}
199
/// Decode an NDN TLV variable-length number from the front of `buf`.
///
/// Supports the 1-, 3- and 5-byte forms and returns
/// `(value, bytes_consumed)`.  Returns `None` when `buf` is empty, when it is
/// too short for the form announced by its first byte, or when it uses the
/// 0xFF (8-byte value) form, which discovery packets do not need.
fn read_varu(buf: &[u8]) -> Option<(usize, usize)> {
    let (&marker, tail) = buf.split_first()?;
    match marker {
        0..=252 => Some((usize::from(marker), 1)),
        253 => match tail {
            [hi, lo, ..] => Some((usize::from(u16::from_be_bytes([*hi, *lo])), 3)),
            _ => None,
        },
        254 => match tail {
            [b1, b2, b3, b4, ..] => {
                Some((u32::from_be_bytes([*b1, *b2, *b3, *b4]) as usize, 5))
            }
            _ => None,
        },
        255 => None,
    }
}
219
#[cfg(test)]
mod tests {
    // Unit tests for overlap validation, inbound routing, hook fan-out, and
    // the minimal TLV parsing helpers used for routing.
    use super::*;
    use crate::{
        DiscoveryContext, InboundMeta, NeighborTable, NeighborTableView, NeighborUpdate,
        NoDiscovery, ProtocolId,
    };
    use std::str::FromStr;
    use std::sync::atomic::{AtomicBool, Ordering};

    // ── Helpers ───────────────────────────────────────────────────────────────

    /// Build a minimal, parseable Interest TLV for `name`.
    ///
    /// Wire: `0x05 <len> 0x07 <name_len> <components...>`
    fn minimal_interest(name: &Name) -> Bytes {
        use ndn_tlv::TlvWriter;
        let mut w = TlvWriter::new();
        w.write_nested(0x05u64, |w: &mut TlvWriter| {
            w.write_nested(0x07u64, |w: &mut TlvWriter| {
                for comp in name.components() {
                    w.write_tlv(comp.typ, &comp.value);
                }
            });
        });
        w.finish()
    }

    /// Protocol double that claims one prefix, consumes every packet it is
    /// offered, and records whether `on_inbound` was called.
    struct MockProto {
        id: ProtocolId,
        prefixes: Vec<Name>,
        called: AtomicBool,
    }

    impl MockProto {
        fn new(id: &'static str, prefix: &str) -> Arc<Self> {
            Arc::new(Self {
                id: ProtocolId(id),
                prefixes: vec![Name::from_str(prefix).unwrap()],
                called: AtomicBool::new(false),
            })
        }
    }

    impl DiscoveryProtocol for MockProto {
        fn protocol_id(&self) -> ProtocolId {
            self.id
        }
        fn claimed_prefixes(&self) -> &[Name] {
            &self.prefixes
        }
        fn on_face_up(&self, _: FaceId, _: &dyn DiscoveryContext) {}
        fn on_face_down(&self, _: FaceId, _: &dyn DiscoveryContext) {}
        fn on_inbound(
            &self,
            _: &Bytes,
            _: FaceId,
            _: &InboundMeta,
            _: &dyn DiscoveryContext,
        ) -> bool {
            self.called.store(true, Ordering::SeqCst);
            true
        }
        fn on_tick(&self, _: Instant, _: &dyn DiscoveryContext) {}
    }

    /// Context double: ignores every request and returns placeholder values.
    struct NullCtx;

    impl DiscoveryContext for NullCtx {
        fn alloc_face_id(&self) -> FaceId {
            FaceId(0)
        }
        fn add_face(&self, _: Arc<dyn ndn_transport::ErasedFace>) -> FaceId {
            FaceId(0)
        }
        fn remove_face(&self, _: FaceId) {}
        fn add_fib_entry(&self, _: &Name, _: FaceId, _: u32, _: ProtocolId) {}
        fn remove_fib_entry(&self, _: &Name, _: FaceId, _: ProtocolId) {}
        fn remove_fib_entries_by_owner(&self, _: ProtocolId) {}
        fn neighbors(&self) -> Arc<dyn NeighborTableView> {
            NeighborTable::new()
        }
        fn update_neighbor(&self, _: NeighborUpdate) {}
        fn send_on(&self, _: FaceId, _: Bytes) {}
        fn now(&self) -> Instant {
            Instant::now()
        }
    }

    // ── Construction tests ────────────────────────────────────────────────────

    #[test]
    fn no_overlap_is_ok() {
        let p1 = MockProto::new("nd", "/ndn/local/nd");
        let p2 = MockProto::new("sd", "/ndn/local/sd");
        assert!(CompositeDiscovery::new(vec![p1, p2]).is_ok());
    }

    #[test]
    fn overlap_is_rejected() {
        let p1 = MockProto::new("nd", "/ndn/local/nd");
        // /ndn/local/nd/hello is a sub-prefix of /ndn/local/nd → overlap
        let p2 = MockProto::new("nd2", "/ndn/local/nd/hello");
        assert!(CompositeDiscovery::new(vec![p1, p2]).is_err());
    }

    #[test]
    fn empty_composite_works() {
        let c = CompositeDiscovery::new(vec![]).unwrap();
        assert!(c.is_empty());
    }

    #[test]
    fn no_discovery_doesnt_conflict() {
        let nd = Arc::new(NoDiscovery) as Arc<dyn DiscoveryProtocol>;
        let nd2 = Arc::new(NoDiscovery) as Arc<dyn DiscoveryProtocol>;
        // Both claim no prefixes → no conflict.
        assert!(CompositeDiscovery::new(vec![nd, nd2]).is_ok());
    }

    #[test]
    fn all_claimed_prefixes_flattens_children() {
        let p1 = MockProto::new("nd", "/ndn/local/nd");
        let p2 = MockProto::new("sd", "/ndn/local/sd");
        let composite = CompositeDiscovery::new(vec![p1, p2]).unwrap();
        assert_eq!(composite.len(), 2);
        // The composite claims nothing itself; the union comes from children.
        assert!(composite.claimed_prefixes().is_empty());
        assert_eq!(composite.all_claimed_prefixes().len(), 2);
    }

    // ── on_inbound routing tests ──────────────────────────────────────────────

    #[test]
    fn routes_to_matching_protocol() {
        let p1 = MockProto::new("nd", "/ndn/local/nd");
        let p2 = MockProto::new("sd", "/ndn/local/sd");
        let p1_ref = Arc::clone(&p1);
        let p2_ref = Arc::clone(&p2);
        let composite = CompositeDiscovery::new(vec![p1, p2]).unwrap();

        // Build an Interest with name /ndn/local/nd/hello — matches p1's prefix.
        let name = Name::from_str("/ndn/local/nd/hello").unwrap();
        let pkt = minimal_interest(&name);

        let consumed = composite.on_inbound(&pkt, FaceId(0), &InboundMeta::none(), &NullCtx);
        assert!(consumed, "composite should consume packet matching p1");
        assert!(
            p1_ref.called.load(Ordering::SeqCst),
            "p1 should have been called"
        );
        assert!(
            !p2_ref.called.load(Ordering::SeqCst),
            "p2 should NOT have been called"
        );
    }

    #[test]
    fn routes_to_second_protocol() {
        let p1 = MockProto::new("nd", "/ndn/local/nd");
        let p2 = MockProto::new("sd", "/ndn/local/sd");
        let p1_ref = Arc::clone(&p1);
        let p2_ref = Arc::clone(&p2);
        let composite = CompositeDiscovery::new(vec![p1, p2]).unwrap();

        // Build an Interest with name /ndn/local/sd/hello — matches p2's prefix.
        let name = Name::from_str("/ndn/local/sd/hello").unwrap();
        let pkt = minimal_interest(&name);

        let consumed = composite.on_inbound(&pkt, FaceId(0), &InboundMeta::none(), &NullCtx);
        assert!(consumed, "composite should consume packet matching p2");
        assert!(
            !p1_ref.called.load(Ordering::SeqCst),
            "p1 should NOT have been called"
        );
        assert!(
            p2_ref.called.load(Ordering::SeqCst),
            "p2 should have been called"
        );
    }

    #[test]
    fn routes_packet_with_multi_byte_tlv_lengths() {
        let p1 = MockProto::new("nd", "/ndn/local/nd");
        let p1_ref = Arc::clone(&p1);
        let composite = CompositeDiscovery::new(vec![p1]).unwrap();

        // A 300-byte component is longer than the 1-byte length form allows,
        // so the component, Name and Interest lengths need a wider form.
        let name = Name::from_str(&format!("/ndn/local/nd/{}", "a".repeat(300))).unwrap();
        let pkt = minimal_interest(&name);

        let consumed = composite.on_inbound(&pkt, FaceId(0), &InboundMeta::none(), &NullCtx);
        assert!(consumed, "long-name packet should still be routed to p1");
        assert!(p1_ref.called.load(Ordering::SeqCst));
    }

    #[test]
    fn no_match_returns_false() {
        let p1 = MockProto::new("nd", "/ndn/local/nd");
        let p2 = MockProto::new("sd", "/ndn/local/sd");
        let composite = CompositeDiscovery::new(vec![p1, p2]).unwrap();

        // Build an Interest with name /ndn/local/other — matches neither.
        let name = Name::from_str("/ndn/local/other/hello").unwrap();
        let pkt = minimal_interest(&name);

        let consumed = composite.on_inbound(&pkt, FaceId(0), &InboundMeta::none(), &NullCtx);
        assert!(!consumed, "composite should NOT consume unmatched packet");
    }

    #[test]
    fn garbage_bytes_not_consumed_when_no_protocol_claims_them() {
        // A protocol that never claims any packet (returns false from on_inbound).
        struct NullProto;
        impl DiscoveryProtocol for NullProto {
            fn protocol_id(&self) -> ProtocolId {
                ProtocolId("null")
            }
            fn claimed_prefixes(&self) -> &[Name] {
                &[]
            }
            fn on_face_up(&self, _: FaceId, _: &dyn DiscoveryContext) {}
            fn on_face_down(&self, _: FaceId, _: &dyn DiscoveryContext) {}
            fn on_inbound(
                &self,
                _: &Bytes,
                _: FaceId,
                _: &InboundMeta,
                _: &dyn DiscoveryContext,
            ) -> bool {
                false
            }
            fn on_tick(&self, _: Instant, _: &dyn DiscoveryContext) {}
        }
        let composite =
            CompositeDiscovery::new(vec![Arc::new(NullProto) as Arc<dyn DiscoveryProtocol>])
                .unwrap();

        let junk = Bytes::from_static(b"\xFF\xFF\xFF");
        let consumed = composite.on_inbound(&junk, FaceId(0), &InboundMeta::none(), &NullCtx);
        assert!(
            !consumed,
            "garbage packet should not be consumed when no protocol claims it"
        );
    }

    #[test]
    fn unparseable_packet_falls_back_to_every_protocol() {
        let p1 = MockProto::new("nd", "/ndn/local/nd");
        let p1_ref = Arc::clone(&p1);
        let composite = CompositeDiscovery::new(vec![p1]).unwrap();

        // Too short to carry a name, so the composite offers it to each child.
        let junk = Bytes::from_static(b"\xFF\xFF\xFF");
        let consumed = composite.on_inbound(&junk, FaceId(0), &InboundMeta::none(), &NullCtx);
        assert!(consumed, "fallback should report the child's verdict");
        assert!(p1_ref.called.load(Ordering::SeqCst));
    }

    #[test]
    fn face_lifecycle_delivered_to_all() {
        let p1 = MockProto::new("nd", "/ndn/local/nd");
        let p2 = MockProto::new("sd", "/ndn/local/sd");
        let p1_ref = Arc::clone(&p1);
        let p2_ref = Arc::clone(&p2);

        // Wrap in a tracking impl for on_face_up.
        struct TrackFaceUp {
            inner: Arc<MockProto>,
            up_called: AtomicBool,
        }
        impl DiscoveryProtocol for TrackFaceUp {
            fn protocol_id(&self) -> ProtocolId {
                self.inner.id
            }
            fn claimed_prefixes(&self) -> &[Name] {
                &self.inner.prefixes
            }
            fn on_face_up(&self, _: FaceId, _: &dyn DiscoveryContext) {
                self.up_called.store(true, Ordering::SeqCst);
            }
            fn on_face_down(&self, _: FaceId, _: &dyn DiscoveryContext) {}
            fn on_inbound(
                &self,
                _: &Bytes,
                _: FaceId,
                _: &InboundMeta,
                _: &dyn DiscoveryContext,
            ) -> bool {
                false
            }
            fn on_tick(&self, _: Instant, _: &dyn DiscoveryContext) {}
        }

        let t1 = Arc::new(TrackFaceUp {
            inner: Arc::clone(&p1_ref),
            up_called: AtomicBool::new(false),
        });
        let t2 = Arc::new(TrackFaceUp {
            inner: Arc::clone(&p2_ref),
            up_called: AtomicBool::new(false),
        });
        let t1_ref = Arc::clone(&t1);
        let t2_ref = Arc::clone(&t2);

        let composite = CompositeDiscovery::new(vec![
            t1 as Arc<dyn DiscoveryProtocol>,
            t2 as Arc<dyn DiscoveryProtocol>,
        ])
        .unwrap();
        composite.on_face_up(FaceId(3), &NullCtx);

        assert!(
            t1_ref.up_called.load(Ordering::SeqCst),
            "p1 should have received on_face_up"
        );
        assert!(
            t2_ref.up_called.load(Ordering::SeqCst),
            "p2 should have received on_face_up"
        );
    }

    // ── TLV helper tests ──────────────────────────────────────────────────────

    #[test]
    fn read_varu_handles_supported_forms() {
        assert_eq!(read_varu(&[]), None);
        assert_eq!(read_varu(&[252]), Some((252, 1)));
        assert_eq!(read_varu(&[253, 0x01, 0x02]), Some((0x0102, 3)));
        assert_eq!(read_varu(&[253, 0x01]), None);
        assert_eq!(read_varu(&[254, 0x01, 0x02, 0x03, 0x04]), Some((0x0102_0304, 5)));
        assert_eq!(read_varu(&[254, 0x01, 0x02, 0x03]), None);
        // The 0xFF (8-byte value) form is deliberately unsupported.
        assert_eq!(read_varu(&[255, 0, 0, 0, 0, 0, 0, 0, 1]), None);
    }

    #[test]
    fn skip_tlv_header_rejects_truncation() {
        let tlv = [0x07u8, 0x02, 0xAA, 0xBB, 0xCC];
        assert_eq!(skip_tlv_header(&tlv), Some((0x07, &tlv[2..4])));
        assert_eq!(skip_tlv_header(&[]), None);
        assert_eq!(skip_tlv_header(&[0x07u8, 0x03, 0x01]), None);
        assert_eq!(skip_tlv_header(&[0x05u8, 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0x00]), None);
    }

    #[test]
    fn parse_first_name_rejects_malformed_packets() {
        // Shorter than the 4-byte minimum.
        assert!(parse_first_name(&Bytes::from_static(&[0x05, 0x00])).is_none());
        // Neither an Interest nor a Data.
        assert!(parse_first_name(&Bytes::from_static(&[0x64, 0x02, 0x07, 0x00])).is_none());
        // Outer length runs past the end of the buffer.
        assert!(parse_first_name(&Bytes::from_static(&[0x05, 0x10, 0x07, 0x00])).is_none());
        // First child is not a Name TLV.
        assert!(parse_first_name(&Bytes::from_static(&[0x05, 0x02, 0x08, 0x00])).is_none());
        // Name length runs past the end of the outer value.
        assert!(parse_first_name(&Bytes::from_static(&[0x05, 0x02, 0x07, 0x05])).is_none());
    }
}