ndn_security/
cert_fetcher.rs

//! Asynchronous certificate fetcher for NDN trust chain resolution.
//!
//! `CertFetcher` retrieves certificates over NDN by expressing Interests
//! for certificate names. It deduplicates concurrent requests for the same
//! certificate and caches results in the shared `CertCache`.
6
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use ndn_packet::{Data, Name};
use tokio::sync::broadcast;

use crate::TrustError;
use crate::cert_cache::{CertCache, Certificate};
18
19/// Type alias for the async fetch callback.
20///
21/// The callback takes a certificate Name and returns `Option<Data>`.
22/// `None` means the fetch failed (timeout, no route, etc.).
23pub type FetchFn =
24    Arc<dyn Fn(Name) -> Pin<Box<dyn Future<Output = Option<Data>> + Send>> + Send + Sync>;
25
26/// Fetches certificates over NDN with deduplication and caching.
27pub struct CertFetcher {
28    cert_cache: Arc<CertCache>,
29    fetch_fn: FetchFn,
30    /// In-flight fetch deduplication.
31    /// When a cert is being fetched, subsequent requests wait on a broadcast.
32    in_flight: DashMap<Arc<Name>, broadcast::Sender<Option<Certificate>>>,
33    timeout: Duration,
34}
35
36impl CertFetcher {
37    /// Create a new fetcher.
38    ///
39    /// `fetch_fn` is called to express an Interest and return the Data response.
40    /// It is typically wired to an `AppHandle` in the engine.
41    pub fn new(cert_cache: Arc<CertCache>, fetch_fn: FetchFn, timeout: Duration) -> Self {
42        Self {
43            cert_cache,
44            fetch_fn,
45            in_flight: DashMap::new(),
46            timeout,
47        }
48    }
49
50    /// Fetch a certificate by name.
51    ///
52    /// Returns immediately if the cert is already cached. Otherwise, expresses
53    /// an Interest, decodes the response, and caches it. Concurrent requests
54    /// for the same name are deduplicated (only one Interest is sent).
55    pub async fn fetch(&self, cert_name: &Arc<Name>) -> Result<Certificate, TrustError> {
56        // Fast path: already cached.
57        if let Some(cert) = self.cert_cache.get(cert_name) {
58            return Ok(cert);
59        }
60
61        // Check if someone is already fetching this cert.
62        if let Some(entry) = self.in_flight.get(cert_name) {
63            let mut rx = entry.subscribe();
64            drop(entry);
65            return match rx.recv().await {
66                Ok(Some(cert)) => Ok(cert),
67                _ => Err(TrustError::CertNotFound {
68                    name: cert_name.to_string(),
69                }),
70            };
71        }
72
73        // We're the first — initiate the fetch.
74        let (tx, _) = broadcast::channel(1);
75        self.in_flight.insert(Arc::clone(cert_name), tx.clone());
76
77        let result = self.do_fetch(cert_name).await;
78
79        // Notify waiters and clean up.
80        let cert = result.as_ref().ok().cloned();
81        let _ = tx.send(cert);
82        self.in_flight.remove(cert_name);
83
84        result
85    }
86
87    async fn do_fetch(&self, cert_name: &Arc<Name>) -> Result<Certificate, TrustError> {
88        let name = cert_name.as_ref().clone();
89
90        let data = tokio::time::timeout(self.timeout, (self.fetch_fn)(name))
91            .await
92            .map_err(|_| TrustError::CertNotFound {
93                name: format!("timeout fetching {}", cert_name),
94            })?
95            .ok_or_else(|| TrustError::CertNotFound {
96                name: cert_name.to_string(),
97            })?;
98
99        let cert = Certificate::decode(&data)?;
100        self.cert_cache.insert(cert.clone());
101        Ok(cert)
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use ndn_packet::NameComponent;
    use ndn_tlv::TlvWriter;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Build the `/<id>/KEY/k1` name used for test certificates.
    fn make_cert_name(id: &str) -> Arc<Name> {
        let components = [
            NameComponent::generic(Bytes::copy_from_slice(id.as_bytes())),
            NameComponent::generic(Bytes::from_static(b"KEY")),
            NameComponent::generic(Bytes::from_static(b"k1")),
        ];
        Arc::new(Name::from_components(components))
    }

    /// Encode a minimal certificate Data packet carrying `pk` as its key.
    ///
    /// The packet is `Data(0x06)` containing the signed region
    /// [`Name(0x07)`, `Content(0x15) { 0x00: pk }`,
    /// `SignatureInfo(0x16) { 0x1b: 5 }` (minimal Ed25519)],
    /// followed by `SignatureValue(0x17)` of 64 zero bytes.
    fn make_cert_data(name: &Name, pk: &[u8]) -> Data {
        let mut region_writer = TlvWriter::new();
        region_writer.write_nested(0x07, |w| {
            for component in name.components() {
                w.write_tlv(component.typ, &component.value);
            }
        });
        region_writer.write_nested(0x15, |w| {
            w.write_tlv(0x00, pk);
        });
        region_writer.write_nested(0x16, |w| {
            w.write_tlv(0x1b, &[5u8]);
        });
        let mut body = region_writer.finish().to_vec();

        let mut sig_writer = TlvWriter::new();
        sig_writer.write_tlv(0x17, &[0u8; 64]);
        body.extend_from_slice(&sig_writer.finish());

        let mut packet = TlvWriter::new();
        packet.write_tlv(0x06, &body);
        Data::decode(packet.finish()).unwrap()
    }

    #[tokio::test]
    async fn cache_hit_skips_fetch() {
        let cert_name = make_cert_name("alice");
        let cache = Arc::new(CertCache::new());
        cache.insert(Certificate {
            name: Arc::clone(&cert_name),
            public_key: Bytes::from_static(&[1; 32]),
            valid_from: 0,
            valid_until: u64::MAX,
            issuer: None,
            signed_region: None,
            sig_value: None,
        });

        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_fn = Arc::clone(&calls);
        let fetch_fn: FetchFn = Arc::new(move |_| {
            calls_in_fn.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { None })
        });

        let fetcher = CertFetcher::new(cache, fetch_fn, Duration::from_secs(1));
        let cert = fetcher.fetch(&cert_name).await.unwrap();

        assert_eq!(cert.public_key.as_ref(), &[1; 32]);
        // The cached copy must have been served without touching the network.
        assert_eq!(calls.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn successful_fetch_caches_result() {
        let cert_name = make_cert_name("bob");
        let cache = Arc::new(CertCache::new());

        let served_name = Arc::clone(&cert_name);
        let fetch_fn: FetchFn = Arc::new(move |_| {
            let data = make_cert_data(&served_name, &[2; 32]);
            Box::pin(async move { Some(data) })
        });

        let fetcher = CertFetcher::new(Arc::clone(&cache), fetch_fn, Duration::from_secs(1));
        let cert = fetcher.fetch(&cert_name).await.unwrap();

        assert_eq!(cert.public_key.as_ref(), &[2; 32]);
        // The fetched cert must now be available from the shared cache.
        assert!(cache.get(&cert_name).is_some());
    }

    #[tokio::test]
    async fn fetch_timeout_returns_error() {
        let cert_name = make_cert_name("slow");
        let cache = Arc::new(CertCache::new());

        // Responds far later than the fetcher's 50 ms budget.
        let fetch_fn: FetchFn = Arc::new(|_| {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                None
            })
        });

        let fetcher = CertFetcher::new(cache, fetch_fn, Duration::from_millis(50));
        let outcome = fetcher.fetch(&cert_name).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn deduplication_sends_one_interest() {
        let cert_name = make_cert_name("carol");
        let cache = Arc::new(CertCache::new());

        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_fn = Arc::clone(&calls);
        let served_name = Arc::clone(&cert_name);
        let fetch_fn: FetchFn = Arc::new(move |_| {
            calls_in_fn.fetch_add(1, Ordering::Relaxed);
            let data = make_cert_data(&served_name, &[3; 32]);
            Box::pin(async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                Some(data)
            })
        });

        let fetcher = Arc::new(CertFetcher::new(cache, fetch_fn, Duration::from_secs(1)));

        // Spawn two concurrent fetches for the same certificate.
        let handles: Vec<_> = (0..2)
            .map(|_| {
                let fetcher = Arc::clone(&fetcher);
                let name = Arc::clone(&cert_name);
                tokio::spawn(async move { fetcher.fetch(&name).await })
            })
            .collect();

        for handle in handles {
            assert!(handle.await.unwrap().is_ok());
        }
        // Both callers must have been served by a single underlying fetch.
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }
}