ndn_engine/stages/
validation.rs1use std::collections::VecDeque;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::Mutex;
6use tokio::time::Instant;
7use tracing::{debug, trace};
8
9use crate::pipeline::{Action, DecodedPacket, DropReason, PacketContext};
10use ndn_packet::Name;
11use ndn_security::{CertFetcher, ValidationResult, Validator};
12
13struct PendingEntry {
15 ctx: PacketContext,
16 needed_cert: Arc<Name>,
17 deadline: Instant,
18 byte_size: usize,
19}
20
21enum DrainResult {
23 Ready(Box<PacketContext>),
25 Timeout,
27}
28
29struct PendingQueue {
31 entries: VecDeque<PendingEntry>,
32 total_bytes: usize,
33 max_entries: usize,
34 max_bytes: usize,
35 default_timeout: Duration,
36}
37
/// Limits applied to the queue of packets that are waiting for a certificate.
///
/// Derives `Debug` and `Clone` so a configuration can be logged and reused
/// when building more than one stage.
#[derive(Debug, Clone)]
pub struct PendingQueueConfig {
    /// Maximum number of packets held at once.
    pub max_entries: usize,
    /// Maximum combined raw size, in bytes, of the held packets.
    pub max_bytes: usize,
    /// How long a packet may wait before it is reported as timed out.
    pub timeout: Duration,
}
44
45impl Default for PendingQueueConfig {
46 fn default() -> Self {
47 Self {
48 max_entries: 256,
49 max_bytes: 4 * 1024 * 1024, timeout: Duration::from_secs(4),
51 }
52 }
53}
54
55impl PendingQueue {
56 fn new(config: &PendingQueueConfig) -> Self {
57 Self {
58 entries: VecDeque::new(),
59 total_bytes: 0,
60 max_entries: config.max_entries,
61 max_bytes: config.max_bytes,
62 default_timeout: config.timeout,
63 }
64 }
65
66 fn push(&mut self, ctx: PacketContext, needed_cert: Arc<Name>) {
68 let byte_size = ctx.raw_bytes.len();
69
70 while self.entries.len() >= self.max_entries
71 || (self.total_bytes + byte_size > self.max_bytes && !self.entries.is_empty())
72 {
73 if let Some(evicted) = self.entries.pop_front() {
74 self.total_bytes -= evicted.byte_size;
75 debug!("validation pending queue: evicted oldest entry");
76 }
77 }
78
79 self.total_bytes += byte_size;
80 self.entries.push_back(PendingEntry {
81 ctx,
82 needed_cert,
83 deadline: Instant::now() + self.default_timeout,
84 byte_size,
85 });
86 }
87
88 fn drain_ready(&mut self, validator: &Validator) -> Vec<DrainResult> {
90 let mut results = Vec::new();
91 let now = Instant::now();
92 let mut i = 0;
93
94 while i < self.entries.len() {
95 let entry = &self.entries[i];
96
97 if now >= entry.deadline {
98 let entry = self.entries.remove(i).unwrap();
99 self.total_bytes -= entry.byte_size;
100 debug!("validation pending queue: timeout");
101 results.push(DrainResult::Timeout);
102 continue;
103 }
104
105 if validator.cert_cache().get(&entry.needed_cert).is_some() {
106 let entry = self.entries.remove(i).unwrap();
107 self.total_bytes -= entry.byte_size;
108 results.push(DrainResult::Ready(Box::new(entry.ctx)));
109 continue;
110 }
111
112 i += 1;
113 }
114
115 results
116 }
117}
118
119pub struct ValidationStage {
128 pub validator: Option<Arc<Validator>>,
129 pub cert_fetcher: Option<Arc<CertFetcher>>,
130 pending: Arc<Mutex<PendingQueue>>,
131}
132
133impl ValidationStage {
134 pub fn new(
135 validator: Option<Arc<Validator>>,
136 cert_fetcher: Option<Arc<CertFetcher>>,
137 config: PendingQueueConfig,
138 ) -> Self {
139 Self {
140 validator,
141 cert_fetcher,
142 pending: Arc::new(Mutex::new(PendingQueue::new(&config))),
143 }
144 }
145
146 pub fn disabled() -> Self {
148 Self {
149 validator: None,
150 cert_fetcher: None,
151 pending: Arc::new(Mutex::new(
152 PendingQueue::new(&PendingQueueConfig::default()),
153 )),
154 }
155 }
156
157 pub async fn process(&self, ctx: PacketContext) -> Action {
158 let Some(validator) = &self.validator else {
159 return Action::Satisfy(ctx);
160 };
161
162 let data = match &ctx.packet {
163 DecodedPacket::Data(d) => d,
164 _ => return Action::Satisfy(ctx),
165 };
166
167 if data
171 .name
172 .components()
173 .first()
174 .map(|c| c.value.as_ref() == b"localhost")
175 .unwrap_or(false)
176 {
177 trace!(name=%data.name, "validation: skipping /localhost/ management data");
178 return Action::Satisfy(ctx);
179 }
180
181 match validator.validate_chain(data).await {
182 ValidationResult::Valid(_safe) => {
183 trace!(name=%data.name, "validation: valid");
184 Action::Satisfy(ctx)
185 }
186 ValidationResult::Pending => {
187 let needed_cert = data
188 .sig_info()
189 .and_then(|si| si.key_locator.as_ref())
190 .cloned();
191
192 if let Some(cert_name) = needed_cert {
193 trace!(name=%data.name, cert=%cert_name, "validation: pending, queuing");
194
195 if let Some(fetcher) = &self.cert_fetcher {
197 let fetcher = Arc::clone(fetcher);
198 let cn = Arc::clone(&cert_name);
199 tokio::spawn(async move {
200 let _ = fetcher.fetch(&cn).await;
201 });
202 }
203
204 self.pending.lock().await.push(ctx, cert_name);
205 Action::Drop(DropReason::ValidationFailed)
208 } else {
209 debug!(name=%data.name, "validation: pending but no key locator");
210 Action::Drop(DropReason::ValidationFailed)
211 }
212 }
213 ValidationResult::Invalid(e) => {
214 debug!(name=%data.name, error=%e, "validation: FAILED");
215 Action::Drop(DropReason::ValidationFailed)
216 }
217 }
218 }
219
220 pub async fn drain_pending(&self) -> Vec<Action> {
225 let Some(validator) = &self.validator else {
226 return Vec::new();
227 };
228
229 let results = self.pending.lock().await.drain_ready(validator);
230 let mut actions = Vec::with_capacity(results.len());
231
232 for result in results {
233 match result {
234 DrainResult::Timeout => {
235 actions.push(Action::Drop(DropReason::ValidationTimeout));
236 }
237 DrainResult::Ready(ctx) => {
238 let ctx = *ctx;
239 let data = match &ctx.packet {
240 DecodedPacket::Data(d) => d,
241 _ => {
242 actions.push(Action::Satisfy(ctx));
243 continue;
244 }
245 };
246 match validator.validate_chain(data).await {
248 ValidationResult::Valid(_) => {
249 trace!(name=%data.name, "validation: re-validated after cert fetch");
250 actions.push(Action::Satisfy(ctx));
251 }
252 ValidationResult::Pending => {
253 debug!(name=%data.name, "validation: still pending after cert fetch, dropping");
256 actions.push(Action::Drop(DropReason::ValidationFailed));
257 }
258 ValidationResult::Invalid(e) => {
259 debug!(name=%data.name, error=%e, "validation: re-validation FAILED");
260 actions.push(Action::Drop(DropReason::ValidationFailed));
261 }
262 }
263 }
264 }
265 }
266
267 actions
268 }
269}