ndn_engine/stages/
decode.rs

1use std::sync::Arc;
2
3use bytes::Bytes;
4use dashmap::DashMap;
5use tracing::trace;
6
7use crate::pipeline::{Action, DecodedPacket, DropReason, PacketContext};
8use ndn_packet::encode::ensure_nonce;
9use ndn_packet::fragment::ReassemblyBuffer;
10use ndn_packet::lp::{LpPacket, extract_fragment, is_lp_packet};
11use ndn_packet::{Data, Interest, Nack, Name, tlv_type};
12use ndn_store::NameHashes;
13use ndn_transport::{FaceId, FaceScope, FaceTable};
14
15/// Check if a name starts with `/localhost`.
16fn is_localhost_name(name: &Name) -> bool {
17    name.components()
18        .first()
19        .is_some_and(|c| c.value.as_ref() == b"localhost")
20}
21
/// Tag carrying the NDNLPv2 CongestionMark header value; inserted into
/// `PacketContext::tags` when an incoming LpPacket carries the field.
#[derive(Debug, Clone, Copy)]
pub struct CongestionMark(pub u64);
25
/// Tag carrying the NDNLPv2 NextHopFaceId header (app→forwarder); inserted
/// into `PacketContext::tags` when an incoming LpPacket carries the field.
#[derive(Debug, Clone, Copy)]
pub struct NextHopFaceId(pub u64);
29
30/// NDNLPv2 CachePolicy from LP header, stored as a tag.
31#[derive(Clone, Copy, Debug)]
32pub struct LpCachePolicy(pub ndn_packet::CachePolicyType);
33
34/// Decodes the raw bytes in `ctx` into an `Interest`, `Data`, or `Nack`.
35///
36/// Handles both bare Interest/Data packets and NDNLPv2 LpPacket-wrapped
37/// packets. LpPackets with a Nack header produce a `DecodedPacket::Nack`.
38///
39/// On success sets `ctx.packet` and `ctx.name`. On any parse failure returns
40/// `Action::Drop(MalformedPacket)`.
41///
42/// Enforces `/localhost` scope: packets with names starting with `/localhost`
43/// arriving on non-local faces are dropped.
44pub struct TlvDecodeStage {
45    pub face_table: Arc<FaceTable>,
46    /// Per-face NDNLPv2 fragment reassembly buffers.
47    ///
48    /// Keyed by FaceId so fragments from different faces are reassembled
49    /// independently.  Buffers are created on first fragmented packet and
50    /// cleaned up lazily via `purge_expired()`.
51    pub(crate) reassembly: DashMap<FaceId, ReassemblyBuffer>,
52}
53
54impl TlvDecodeStage {
55    /// Create a new decode stage.
56    pub fn new(face_table: Arc<FaceTable>) -> Self {
57        Self {
58            face_table,
59            reassembly: DashMap::new(),
60        }
61    }
62
63    /// Fast-path fragment collection that bypasses `PacketContext` creation.
64    ///
65    /// If the raw bytes are a fragmented LpPacket, parses just the header fields
66    /// and feeds the fragment to the per-face `ReassemblyBuffer`.
67    ///
68    /// Returns:
69    /// - `Ok(Some(bytes))` — reassembly completed, `bytes` is the full packet
70    /// - `Ok(None)` — fragment buffered, waiting for more
71    /// - `Err(bytes)` — not a fragment (bare packet, unfragmented LpPacket, or
72    ///   Nack); caller should process through the full pipeline. The original
73    ///   bytes are returned back.
74    pub fn try_collect_fragment(
75        &self,
76        face_id: FaceId,
77        raw: Bytes,
78    ) -> Result<Option<Bytes>, Bytes> {
79        // Lightweight parse: only extract fragmentation fields, no Bytes
80        // allocation, no Nack/CongestionMark parsing.
81        let hdr = match extract_fragment(&raw) {
82            Some(h) => h,
83            None => return Err(raw), // Not a multi-fragment LpPacket.
84        };
85        let fragment = raw.slice(hdr.frag_start..hdr.frag_end);
86        let base_seq = hdr.sequence - hdr.frag_index;
87        let mut rb = self.reassembly.entry(face_id).or_default();
88        Ok(rb.process(base_seq, hdr.frag_index, hdr.frag_count, fragment))
89    }
90
91    pub fn process(&self, mut ctx: PacketContext) -> Action {
92        let first_byte = match ctx.raw_bytes.first() {
93            Some(&b) => b as u64,
94            None => {
95                trace!(face=%ctx.face_id, "decode: empty packet");
96                return Action::Drop(DropReason::MalformedPacket);
97            }
98        };
99
100        // NDNLPv2: unwrap LpPacket if present.
101        if is_lp_packet(&ctx.raw_bytes) {
102            trace!(face=%ctx.face_id, len=ctx.raw_bytes.len(), "decode: LpPacket");
103            return self.process_lp(ctx);
104        }
105
106        match first_byte {
107            t if t == tlv_type::INTEREST => self.decode_interest(ctx),
108            t if t == tlv_type::DATA => match Data::decode(ctx.raw_bytes.clone()) {
109                Ok(data) => {
110                    trace!(face=%ctx.face_id, name=%data.name, "decode: Data");
111                    if data.name.len() > 3 {
112                        ctx.name_hashes = Some(NameHashes::compute(&data.name));
113                    }
114                    ctx.name = Some(data.name.clone());
115                    ctx.packet = DecodedPacket::Data(Box::new(data));
116                    if let Some(drop) = self.check_scope(&ctx) {
117                        return drop;
118                    }
119                    Action::Continue(ctx)
120                }
121                Err(e) => {
122                    trace!(face=%ctx.face_id, error=%e, "decode: malformed Data");
123                    Action::Drop(DropReason::MalformedPacket)
124                }
125            },
126            _ => {
127                trace!(face=%ctx.face_id, tlv_type=first_byte, "decode: unknown TLV type");
128                Action::Drop(DropReason::MalformedPacket)
129            }
130        }
131    }
132
133    /// Decode a bare Interest, enforcing HopLimit and inserting Nonce.
134    fn decode_interest(&self, mut ctx: PacketContext) -> Action {
135        match Interest::decode(ctx.raw_bytes.clone()) {
136            Ok(interest) => {
137                if interest.hop_limit() == Some(0) {
138                    trace!(face=%ctx.face_id, name=%interest.name, "decode: HopLimit=0, dropping");
139                    return Action::Drop(DropReason::HopLimitExceeded);
140                }
141                trace!(face=%ctx.face_id, name=%interest.name, nonce=?interest.nonce(), "decode: Interest");
142                ctx.raw_bytes = ensure_nonce(&ctx.raw_bytes);
143                // Pre-compute cumulative prefix hashes only when names are long
144                // enough to recoup the cost.  For ≤3 components the per-probe
145                // re-hashing in PitMatchStage is cheaper than the upfront work.
146                if interest.name.len() > 3 {
147                    ctx.name_hashes = Some(NameHashes::compute(&interest.name));
148                }
149                ctx.name = Some(interest.name.clone());
150                ctx.packet = DecodedPacket::Interest(Box::new(interest));
151                if let Some(drop) = self.check_scope(&ctx) {
152                    return drop;
153                }
154                Action::Continue(ctx)
155            }
156            Err(e) => {
157                trace!(face=%ctx.face_id, error=%e, "decode: malformed Interest");
158                Action::Drop(DropReason::MalformedPacket)
159            }
160        }
161    }
162
163    /// Drop packets with `/localhost` names arriving on non-local faces.
164    fn check_scope(&self, ctx: &PacketContext) -> Option<Action> {
165        if let Some(ref name) = ctx.name
166            && is_localhost_name(name)
167        {
168            let is_non_local = self
169                .face_table
170                .get(ctx.face_id)
171                .is_some_and(|f| f.kind().scope() == FaceScope::NonLocal);
172            if is_non_local {
173                trace!(face=%ctx.face_id, name=%name, "decode: /localhost on non-local face, dropping");
174                return Some(Action::Drop(DropReason::ScopeViolation));
175            }
176        }
177        None
178    }
179
180    /// Process an NDNLPv2 LpPacket.
181    ///
182    /// Handles fragment reassembly: if the LpPacket is a fragment, it is
183    /// buffered per-face until all fragments arrive.  Returns `Action::Drop`
184    /// for incomplete reassemblies (waiting for more fragments) and
185    /// re-enters `process()` when the complete packet is available.
186    fn process_lp(&self, mut ctx: PacketContext) -> Action {
187        let lp = match LpPacket::decode(ctx.raw_bytes.clone()) {
188            Ok(lp) => lp,
189            Err(e) => {
190                trace!(face=%ctx.face_id, error=%e, "decode: malformed LpPacket");
191                return Action::Drop(DropReason::MalformedPacket);
192            }
193        };
194
195        // Propagate LP header fields through the pipeline via tags/context.
196        if let Some(mark) = lp.congestion_mark {
197            ctx.tags.insert(CongestionMark(mark));
198        }
199        if let Some(token) = lp.pit_token.clone() {
200            ctx.lp_pit_token = Some(token);
201        }
202        if let Some(face_id) = lp.next_hop_face_id {
203            ctx.tags.insert(NextHopFaceId(face_id));
204        }
205        if let Some(ref policy) = lp.cache_policy {
206            ctx.tags.insert(LpCachePolicy(*policy));
207        }
208
209        // Bare Ack-only packets have no payload to process.
210        if lp.is_ack_only() {
211            return Action::Drop(DropReason::FragmentCollect);
212        }
213
214        let is_fragmented = lp.is_fragmented();
215        let sequence = lp.sequence;
216        let frag_index = lp.frag_index;
217        let frag_count = lp.frag_count;
218        let nack = lp.nack;
219
220        let fragment = match lp.fragment {
221            Some(f) => f,
222            None => return Action::Drop(DropReason::MalformedPacket),
223        };
224
225        // Fragment reassembly: buffer until all fragments arrive.
226        if is_fragmented {
227            let face_id = ctx.face_id;
228            let complete = {
229                let mut rb = self.reassembly.entry(face_id).or_default();
230                let seq = sequence.unwrap_or(0);
231                let idx = frag_index.unwrap_or(0);
232                let base_seq = seq - idx;
233                rb.process(base_seq, idx, frag_count.unwrap_or(1), fragment)
234            };
235            match complete {
236                Some(packet) => {
237                    trace!(face=%ctx.face_id, len=packet.len(), "decode: reassembled");
238                    ctx.raw_bytes = packet;
239                    return self.process(ctx);
240                }
241                None => {
242                    // Still waiting for more fragments — not an error.
243                    return Action::Drop(DropReason::FragmentCollect);
244                }
245            }
246        }
247
248        if let Some(reason) = nack {
249            // LpPacket with Nack header: fragment is the nacked Interest.
250            match Interest::decode(fragment) {
251                Ok(interest) => {
252                    trace!(face=%ctx.face_id, name=%interest.name, reason=?reason, "decode: Nack");
253                    let nack = Nack::new(interest, reason);
254                    if nack.interest.name.len() > 3 {
255                        ctx.name_hashes = Some(NameHashes::compute(&nack.interest.name));
256                    }
257                    ctx.name = Some(nack.interest.name.clone());
258                    ctx.packet = DecodedPacket::Nack(Box::new(nack));
259                    if let Some(drop) = self.check_scope(&ctx) {
260                        return drop;
261                    }
262                    Action::Continue(ctx)
263                }
264                Err(e) => {
265                    trace!(face=%ctx.face_id, error=%e, "decode: malformed nacked Interest");
266                    Action::Drop(DropReason::MalformedPacket)
267                }
268            }
269        } else {
270            // Plain LpPacket wrapping Interest or Data.
271            ctx.raw_bytes = fragment;
272            self.process(ctx)
273        }
274    }
275}