1use std::collections::HashMap;
2
3use tokio::sync::RwLock;
4
5use ndn_packet::Name;
6
/// One entry of a state vector: a node key paired with a sequence number.
///
/// Values of this type are produced by `SvsNode::snapshot`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StateVectorEntry {
    /// Key identifying the node this entry belongs to.
    pub node: String,
    /// Sequence number recorded for `node`.
    pub seq: u64,
}
14
15pub struct SvsNode {
25 local_key: String,
26 vector: RwLock<HashMap<String, u64>>,
27}
28
29impl SvsNode {
30 pub fn new(local_name: &Name) -> Self {
31 let key = local_name.to_string();
32 let mut map = HashMap::new();
33 map.insert(key.clone(), 0u64);
34 Self {
35 local_key: key,
36 vector: RwLock::new(map),
37 }
38 }
39
40 pub fn local_key(&self) -> &str {
41 &self.local_key
42 }
43
44 pub async fn local_seq(&self) -> u64 {
46 *self.vector.read().await.get(&self.local_key).unwrap_or(&0)
47 }
48
49 pub async fn advance(&self) -> u64 {
51 let mut map = self.vector.write().await;
52 let seq = map.entry(self.local_key.clone()).or_insert(0);
53 *seq += 1;
54 *seq
55 }
56
57 pub async fn merge(&self, received: &[(String, u64)]) -> Vec<(String, u64, u64)> {
64 let mut gaps = Vec::new();
65 let mut map = self.vector.write().await;
66 for (node, remote_seq) in received {
67 let local_seq = map.entry(node.clone()).or_insert(0);
68 if *remote_seq > *local_seq {
69 gaps.push((node.clone(), *local_seq + 1, *remote_seq));
70 *local_seq = *remote_seq;
71 }
72 }
73 gaps
74 }
75
76 pub async fn snapshot(&self) -> Vec<StateVectorEntry> {
78 self.vector
79 .read()
80 .await
81 .iter()
82 .map(|(k, &seq)| StateVectorEntry {
83 node: k.clone(),
84 seq,
85 })
86 .collect()
87 }
88
89 pub async fn seq_for(&self, node_key: &str) -> u64 {
91 *self.vector.read().await.get(node_key).unwrap_or(&0)
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use ndn_packet::NameComponent;

    /// Builds a `Name` consisting of a single generic component holding `s`.
    fn name(s: &'static str) -> Name {
        let component = NameComponent::generic(Bytes::from_static(s.as_bytes()));
        Name::from_components([component])
    }

    // A freshly created node reports 0 for its own sequence number.
    #[tokio::test]
    async fn new_node_starts_at_seq_zero() {
        let svs = SvsNode::new(&name("a"));
        assert_eq!(svs.local_seq().await, 0);
    }

    // Each `advance` returns the next number and the result is persisted.
    #[tokio::test]
    async fn advance_increments_seq() {
        let svs = SvsNode::new(&name("a"));
        let first = svs.advance().await;
        assert_eq!(first, 1);
        let second = svs.advance().await;
        assert_eq!(second, 2);
        assert_eq!(svs.local_seq().await, 2);
    }

    // A higher remote number raises the stored value and reports the gap.
    #[tokio::test]
    async fn merge_updates_higher_seq() {
        let svs = SvsNode::new(&name("a"));
        let incoming = [(String::from("b"), 3)];
        let gaps = svs.merge(&incoming).await;
        assert_eq!(gaps, vec![(String::from("b"), 1, 3)]);
        assert_eq!(svs.seq_for("b").await, 3);
    }

    // A lower remote number neither reports a gap nor lowers the stored one.
    #[tokio::test]
    async fn merge_ignores_equal_or_lower_seq() {
        let svs = SvsNode::new(&name("a"));
        svs.merge(&[(String::from("b"), 5)]).await;
        let gaps = svs.merge(&[(String::from("b"), 3)]).await;
        assert_eq!(gaps.len(), 0);
        assert_eq!(svs.seq_for("b").await, 5);
    }

    // A stale remote view of the local node must not roll it back.
    #[tokio::test]
    async fn merge_does_not_downgrade_local_seq() {
        let svs = SvsNode::new(&name("a"));
        svs.advance().await;
        let own_key = svs.local_key().to_owned();
        let gaps = svs.merge(&[(own_key, 0)]).await;
        assert_eq!(gaps.len(), 0);
        assert_eq!(svs.local_seq().await, 1);
    }

    // The snapshot of a fresh node holds only the local entry, at 0.
    #[tokio::test]
    async fn snapshot_contains_local_entry() {
        let svs = SvsNode::new(&name("a"));
        let entries = svs.snapshot().await;
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].seq, 0);
    }

    // Several peers in one merge are each recorded and each yield a gap.
    #[tokio::test]
    async fn merge_multiple_peers() {
        let svs = SvsNode::new(&name("a"));
        let incoming = [(String::from("b"), 2), (String::from("c"), 4)];
        let gaps = svs.merge(&incoming).await;
        assert_eq!(gaps.len(), 2);
        assert_eq!(svs.seq_for("b").await, 2);
        assert_eq!(svs.seq_for("c").await, 4);
    }
}