ndn_security/
cert_fetcher.rs1use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::time::Duration;
11
12use dashmap::DashMap;
13use ndn_packet::{Data, Name};
14use tokio::sync::broadcast;
15
16use crate::TrustError;
17use crate::cert_cache::{CertCache, Certificate};
18
19pub type FetchFn =
24 Arc<dyn Fn(Name) -> Pin<Box<dyn Future<Output = Option<Data>> + Send>> + Send + Sync>;
25
26pub struct CertFetcher {
28 cert_cache: Arc<CertCache>,
29 fetch_fn: FetchFn,
30 in_flight: DashMap<Arc<Name>, broadcast::Sender<Option<Certificate>>>,
33 timeout: Duration,
34}
35
36impl CertFetcher {
37 pub fn new(cert_cache: Arc<CertCache>, fetch_fn: FetchFn, timeout: Duration) -> Self {
42 Self {
43 cert_cache,
44 fetch_fn,
45 in_flight: DashMap::new(),
46 timeout,
47 }
48 }
49
50 pub async fn fetch(&self, cert_name: &Arc<Name>) -> Result<Certificate, TrustError> {
56 if let Some(cert) = self.cert_cache.get(cert_name) {
58 return Ok(cert);
59 }
60
61 if let Some(entry) = self.in_flight.get(cert_name) {
63 let mut rx = entry.subscribe();
64 drop(entry);
65 return match rx.recv().await {
66 Ok(Some(cert)) => Ok(cert),
67 _ => Err(TrustError::CertNotFound {
68 name: cert_name.to_string(),
69 }),
70 };
71 }
72
73 let (tx, _) = broadcast::channel(1);
75 self.in_flight.insert(Arc::clone(cert_name), tx.clone());
76
77 let result = self.do_fetch(cert_name).await;
78
79 let cert = result.as_ref().ok().cloned();
81 let _ = tx.send(cert);
82 self.in_flight.remove(cert_name);
83
84 result
85 }
86
87 async fn do_fetch(&self, cert_name: &Arc<Name>) -> Result<Certificate, TrustError> {
88 let name = cert_name.as_ref().clone();
89
90 let data = tokio::time::timeout(self.timeout, (self.fetch_fn)(name))
91 .await
92 .map_err(|_| TrustError::CertNotFound {
93 name: format!("timeout fetching {}", cert_name),
94 })?
95 .ok_or_else(|| TrustError::CertNotFound {
96 name: cert_name.to_string(),
97 })?;
98
99 let cert = Certificate::decode(&data)?;
100 self.cert_cache.insert(cert.clone());
101 Ok(cert)
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use ndn_packet::NameComponent;
    use ndn_tlv::TlvWriter;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds the three-component name `/<id>/KEY/k1` used throughout these tests.
    fn make_cert_name(id: &str) -> Arc<Name> {
        let components = [
            NameComponent::generic(Bytes::copy_from_slice(id.as_bytes())),
            NameComponent::generic(Bytes::from_static(b"KEY")),
            NameComponent::generic(Bytes::from_static(b"k1")),
        ];
        Arc::new(Name::from_components(components))
    }

    /// Hand-encodes a minimal Data packet whose content is `pk`.
    ///
    /// TLV layout: Data(0x06) containing Name(0x07), Content(0x15),
    /// SignatureInfo(0x16) and a 64-byte all-zero SignatureValue(0x17).
    fn make_cert_data(name: &Name, pk: &[u8]) -> Data {
        // Signed portion: Name, Content, SignatureInfo.
        let mut signed_writer = TlvWriter::new();
        signed_writer.write_nested(0x07, |w| {
            for component in name.components() {
                w.write_tlv(component.typ, &component.value);
            }
        });
        // Content wraps the raw key bytes in a type-0x00 TLV.
        signed_writer.write_nested(0x15, |w| {
            w.write_tlv(0x00, pk);
        });
        // SignatureInfo holds only SignatureType (0x1b) = 5.
        signed_writer.write_nested(0x16, |w| {
            w.write_tlv(0x1b, &[5u8]);
        });

        let mut data_value = signed_writer.finish().to_vec();

        // SignatureValue follows the signed portion.
        let sig_value = {
            let mut sig_writer = TlvWriter::new();
            sig_writer.write_tlv(0x17, &[0u8; 64]);
            sig_writer.finish()
        };
        data_value.extend_from_slice(&sig_value);

        let mut packet = TlvWriter::new();
        packet.write_tlv(0x06, &data_value);
        Data::decode(packet.finish()).unwrap()
    }

    #[tokio::test]
    async fn cache_hit_skips_fetch() {
        let cert_name = make_cert_name("alice");
        let cache = Arc::new(CertCache::new());
        cache.insert(Certificate {
            name: Arc::clone(&cert_name),
            public_key: Bytes::from_static(&[1; 32]),
            valid_from: 0,
            valid_until: u64::MAX,
            issuer: None,
            signed_region: None,
            sig_value: None,
        });

        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_hook = Arc::clone(&calls);
        let fetch_fn: FetchFn = Arc::new(move |_| {
            calls_in_hook.fetch_add(1, Ordering::Relaxed);
            Box::pin(async { None })
        });

        let fetcher = CertFetcher::new(cache, fetch_fn, Duration::from_secs(1));
        let found = fetcher.fetch(&cert_name).await.unwrap();

        assert_eq!(found.public_key.as_ref(), &[1; 32]);
        // The hook must never run when the cache already holds the certificate.
        assert_eq!(calls.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn successful_fetch_caches_result() {
        let cert_name = make_cert_name("bob");
        let cache = Arc::new(CertCache::new());

        let hook_name = Arc::clone(&cert_name);
        let fetch_fn: FetchFn = Arc::new(move |_| {
            let data = make_cert_data(&hook_name, &[2; 32]);
            Box::pin(async move { Some(data) })
        });

        let fetcher = CertFetcher::new(Arc::clone(&cache), fetch_fn, Duration::from_secs(1));
        let fetched = fetcher.fetch(&cert_name).await.unwrap();

        assert_eq!(fetched.public_key.as_ref(), &[2; 32]);
        // A successful fetch must leave the certificate in the shared cache.
        assert!(cache.get(&cert_name).is_some());
    }

    #[tokio::test]
    async fn fetch_timeout_returns_error() {
        let cert_name = make_cert_name("slow");
        let cache = Arc::new(CertCache::new());

        // The hook takes far longer than the fetcher's 50 ms budget.
        let fetch_fn: FetchFn = Arc::new(|_| {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                None
            })
        });

        let fetcher = CertFetcher::new(cache, fetch_fn, Duration::from_millis(50));
        let outcome = fetcher.fetch(&cert_name).await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn deduplication_sends_one_interest() {
        let cert_name = make_cert_name("carol");
        let cache = Arc::new(CertCache::new());

        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_hook = Arc::clone(&calls);
        let hook_name = Arc::clone(&cert_name);
        let fetch_fn: FetchFn = Arc::new(move |_| {
            calls_in_hook.fetch_add(1, Ordering::Relaxed);
            let data = make_cert_data(&hook_name, &[3; 32]);
            Box::pin(async move {
                // Delay so the second request arrives while the first is in flight.
                tokio::time::sleep(Duration::from_millis(50)).await;
                Some(data)
            })
        });

        let fetcher = Arc::new(CertFetcher::new(cache, fetch_fn, Duration::from_secs(1)));

        let spawn_fetch = || {
            let worker = Arc::clone(&fetcher);
            let name = Arc::clone(&cert_name);
            tokio::spawn(async move { worker.fetch(&name).await })
        };
        let first = spawn_fetch();
        let second = spawn_fetch();

        let (first, second) = tokio::join!(first, second);
        assert!(first.unwrap().is_ok());
        assert!(second.unwrap().is_ok());
        // Both callers were satisfied by a single invocation of the hook.
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }
}