1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use dashmap::DashMap;
5use ndn_packet::Name;
6use ndn_transport::FaceId;
7
8use crate::fib::{Fib, FibNexthop};
9
10#[derive(Clone, Debug)]
16pub struct RibRoute {
17 pub face_id: FaceId,
18 pub origin: u64,
21 pub cost: u32,
22 pub flags: u64,
25 pub expires_at: Option<Instant>,
27}
28
29impl RibRoute {
30 pub fn remaining(&self) -> Option<Duration> {
32 self.expires_at
33 .map(|exp| exp.saturating_duration_since(Instant::now()))
34 }
35}
36
37pub struct Rib {
69 routes: DashMap<Name, Vec<RibRoute>>,
70}
71
72impl Rib {
73 pub fn new() -> Self {
74 Self {
75 routes: DashMap::new(),
76 }
77 }
78
79 pub fn add(&self, prefix: &Name, route: RibRoute) -> bool {
85 let mut entry = self.routes.entry(prefix.clone()).or_default();
86 let routes = entry.value_mut();
87 if let Some(existing) = routes
88 .iter_mut()
89 .find(|r| r.face_id == route.face_id && r.origin == route.origin)
90 {
91 let changed = existing.cost != route.cost
92 || existing.flags != route.flags
93 || existing.expires_at != route.expires_at;
94 *existing = route;
95 changed
96 } else {
97 routes.push(route);
98 true
99 }
100 }
101
102 pub fn remove(&self, prefix: &Name, face_id: FaceId, origin: u64) -> bool {
106 let Some(mut entry) = self.routes.get_mut(prefix) else {
107 return false;
108 };
109 let before = entry.len();
110 entry.retain(|r| !(r.face_id == face_id && r.origin == origin));
111 let changed = entry.len() != before;
112 if entry.is_empty() {
113 drop(entry);
114 self.routes.remove(prefix);
115 }
116 changed
117 }
118
119 pub fn remove_nexthop(&self, prefix: &Name, face_id: FaceId) -> bool {
123 let Some(mut entry) = self.routes.get_mut(prefix) else {
124 return false;
125 };
126 let before = entry.len();
127 entry.retain(|r| r.face_id != face_id);
128 let changed = entry.len() != before;
129 if entry.is_empty() {
130 drop(entry);
131 self.routes.remove(prefix);
132 }
133 changed
134 }
135
136 pub fn flush_origin(&self, origin: u64) -> Vec<Name> {
143 let mut affected = Vec::new();
144 self.routes.retain(|name, routes| {
145 let before = routes.len();
146 routes.retain(|r| r.origin != origin);
147 if routes.len() != before {
148 affected.push(name.clone());
149 }
150 !routes.is_empty()
151 });
152 affected
153 }
154
155 pub fn flush_face(&self, face_id: FaceId) -> Vec<Name> {
162 let mut affected = Vec::new();
163 self.routes.retain(|name, routes| {
164 let before = routes.len();
165 routes.retain(|r| r.face_id != face_id);
166 if routes.len() != before {
167 affected.push(name.clone());
168 }
169 !routes.is_empty()
170 });
171 affected
172 }
173
174 pub fn drain_expired(&self) -> Vec<Name> {
181 let now = Instant::now();
182 let mut affected = Vec::new();
183 self.routes.retain(|name, routes| {
184 let before = routes.len();
185 routes.retain(|r| r.expires_at.is_none_or(|exp| exp > now));
186 if routes.len() != before {
187 affected.push(name.clone());
188 }
189 !routes.is_empty()
190 });
191 affected
192 }
193
194 pub fn apply_to_fib(&self, prefix: &Name, fib: &Fib) {
200 let Some(entry) = self.routes.get(prefix) else {
201 fib.set_nexthops(prefix, Vec::new());
202 return;
203 };
204
205 let mut best: HashMap<FaceId, (u32, u64)> = HashMap::new();
206 for route in entry.iter() {
207 let e = best.entry(route.face_id).or_insert((u32::MAX, u64::MAX));
208 if route.cost < e.0 || (route.cost == e.0 && route.origin < e.1) {
209 *e = (route.cost, route.origin);
210 }
211 }
212
213 let nexthops: Vec<FibNexthop> = best
214 .into_iter()
215 .map(|(face_id, (cost, _))| FibNexthop { face_id, cost })
216 .collect();
217
218 fib.set_nexthops(prefix, nexthops);
219 }
220
221 pub fn handle_face_down(&self, face_id: FaceId, fib: &Fib) {
229 let affected = self.flush_face(face_id);
230 for prefix in &affected {
231 self.apply_to_fib(prefix, fib);
232 }
233 }
234
235 pub fn dump(&self) -> Vec<(Name, Vec<RibRoute>)> {
237 self.routes
238 .iter()
239 .map(|e| (e.key().clone(), e.value().clone()))
240 .collect()
241 }
242}
243
244impl Default for Rib {
245 fn default() -> Self {
246 Self::new()
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use ndn_packet::NameComponent;

    /// Builds a single-component name from a static string.
    fn name(s: &'static str) -> Name {
        Name::from_components([NameComponent::generic(Bytes::from_static(s.as_bytes()))])
    }

    /// Builds a permanent (non-expiring) route with no flags set.
    fn route(face_id: u32, origin: u64, cost: u32) -> RibRoute {
        RibRoute {
            face_id: FaceId(face_id),
            origin,
            cost,
            flags: 0,
            expires_at: None,
        }
    }

    #[test]
    fn add_and_dump() {
        let rib = Rib::new();
        rib.add(&name("ndn"), route(1, 128, 5));
        let entries = rib.dump();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].1.len(), 1);
    }

    #[test]
    fn add_updates_existing() {
        let rib = Rib::new();
        rib.add(&name("ndn"), route(1, 128, 5));
        rib.add(&name("ndn"), route(1, 128, 10));
        let entries = rib.dump();
        assert_eq!(entries[0].1.len(), 1);
        assert_eq!(entries[0].1[0].cost, 10);
    }

    #[test]
    fn add_reports_whether_anything_changed() {
        let rib = Rib::new();
        assert!(rib.add(&name("ndn"), route(1, 128, 5)));
        // Re-adding an identical route is not a change.
        assert!(!rib.add(&name("ndn"), route(1, 128, 5)));
        assert!(rib.add(&name("ndn"), route(1, 128, 7)));
    }

    #[test]
    fn multiple_origins_same_face() {
        let rib = Rib::new();
        rib.add(&name("ndn"), route(1, 128, 5));
        rib.add(&name("ndn"), route(1, 255, 100));
        let entries = rib.dump();
        assert_eq!(entries[0].1.len(), 2);
    }

    #[test]
    fn remove_by_face_and_origin() {
        let rib = Rib::new();
        rib.add(&name("ndn"), route(1, 128, 5));
        rib.add(&name("ndn"), route(1, 255, 100));
        assert!(rib.remove(&name("ndn"), FaceId(1), 128));
        let entries = rib.dump();
        assert_eq!(entries[0].1.len(), 1);
        assert_eq!(entries[0].1[0].origin, 255);
    }

    #[test]
    fn removing_last_route_drops_prefix() {
        let rib = Rib::new();
        rib.add(&name("ndn"), route(1, 128, 5));
        assert!(rib.remove(&name("ndn"), FaceId(1), 128));
        assert!(rib.dump().is_empty());
        // A second removal finds nothing.
        assert!(!rib.remove(&name("ndn"), FaceId(1), 128));
    }

    #[test]
    fn remove_unknown_route_reports_no_change() {
        let rib = Rib::new();
        rib.add(&name("ndn"), route(1, 128, 5));
        assert!(!rib.remove(&name("ndn"), FaceId(2), 128));
        assert!(!rib.remove(&name("other"), FaceId(1), 128));
        assert_eq!(rib.dump().len(), 1);
    }

    #[test]
    fn remove_nexthop_removes_every_origin() {
        let rib = Rib::new();
        rib.add(&name("ndn"), route(1, 128, 5));
        rib.add(&name("ndn"), route(1, 255, 100));
        rib.add(&name("ndn"), route(2, 128, 10));
        assert!(rib.remove_nexthop(&name("ndn"), FaceId(1)));
        let entries = rib.dump();
        assert_eq!(entries[0].1.len(), 1);
        assert_eq!(entries[0].1[0].face_id, FaceId(2));
    }

    #[test]
    fn flush_origin_removes_matching() {
        let rib = Rib::new();
        rib.add(&name("a"), route(1, 128, 5));
        rib.add(&name("b"), route(2, 128, 10));
        rib.add(&name("a"), route(1, 255, 100));

        let affected = rib.flush_origin(128);
        assert_eq!(affected.len(), 2);
        let entries = rib.dump();
        // "b" had only origin 128 and is gone; "a" keeps its origin-255 route.
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].1[0].origin, 255);
    }

    #[test]
    fn flush_face_removes_all_for_face() {
        let rib = Rib::new();
        rib.add(&name("a"), route(1, 128, 5));
        rib.add(&name("a"), route(2, 128, 10));
        rib.add(&name("b"), route(1, 128, 3));

        let affected = rib.flush_face(FaceId(1));
        assert_eq!(affected.len(), 2);
        let entries = rib.dump();
        // "b" had only face 1 and is gone; "a" keeps its face-2 route.
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].1[0].face_id, FaceId(2));
    }

    #[test]
    fn drain_expired_removes_stale() {
        let rib = Rib::new();
        let past = Instant::now() - Duration::from_secs(1);
        rib.add(
            &name("a"),
            RibRoute {
                face_id: FaceId(1),
                origin: 128,
                cost: 5,
                flags: 0,
                expires_at: Some(past),
            },
        );
        rib.add(&name("b"), route(2, 128, 10));

        let affected = rib.drain_expired();
        assert_eq!(affected.len(), 1);
        assert_eq!(rib.dump().len(), 1);
    }

    #[test]
    fn remaining_is_none_for_permanent_route() {
        assert_eq!(route(1, 128, 5).remaining(), None);
    }
}