1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4
5use bytes::Bytes;
6use tokio::net::UdpSocket;
7
8use tracing::trace;
9
10use ndn_packet::fragment::{DEFAULT_UDP_MTU, fragment_packet};
11use ndn_transport::{Face, FaceError, FaceId, FaceKind};
12
13pub struct UdpFace {
28 id: FaceId,
29 socket: Arc<UdpSocket>,
30 peer: SocketAddr,
31 mtu: usize,
32 seq: AtomicU64,
33}
34
35impl UdpFace {
36 pub async fn bind(local: SocketAddr, peer: SocketAddr, id: FaceId) -> std::io::Result<Self> {
46 let local = if local.ip().is_unspecified() {
47 let resolved = resolve_local_addr(peer, local.port())?;
48 trace!(peer=%peer, resolved=%resolved, "udp: resolved local addr for peer");
49 resolved
50 } else {
51 local
52 };
53 let socket = UdpSocket::bind(local).await?;
54 trace!(face=%id, local=%socket.local_addr().unwrap_or(local), peer=%peer, "udp: bound");
55 Ok(Self {
56 id,
57 socket: Arc::new(socket),
58 peer,
59 mtu: DEFAULT_UDP_MTU,
60 seq: AtomicU64::new(0),
61 })
62 }
63
64 pub fn from_socket(id: FaceId, socket: UdpSocket, peer: SocketAddr) -> Self {
66 Self {
67 id,
68 socket: Arc::new(socket),
69 peer,
70 mtu: DEFAULT_UDP_MTU,
71 seq: AtomicU64::new(0),
72 }
73 }
74
75 pub fn from_shared_socket(id: FaceId, socket: Arc<UdpSocket>, peer: SocketAddr) -> Self {
81 Self {
82 id,
83 socket,
84 peer,
85 mtu: DEFAULT_UDP_MTU,
86 seq: AtomicU64::new(0),
87 }
88 }
89
90 pub fn peer(&self) -> SocketAddr {
91 self.peer
92 }
93}
94
95impl Face for UdpFace {
96 fn id(&self) -> FaceId {
97 self.id
98 }
99 fn kind(&self) -> FaceKind {
100 FaceKind::Udp
101 }
102
103 fn remote_uri(&self) -> Option<String> {
104 Some(format!("udp4://{}:{}", self.peer.ip(), self.peer.port()))
105 }
106
107 fn local_uri(&self) -> Option<String> {
108 self.socket
109 .local_addr()
110 .ok()
111 .map(|a| format!("udp4://{}:{}", a.ip(), a.port()))
112 }
113
114 async fn recv(&self) -> Result<Bytes, FaceError> {
122 let mut buf = [0u8; 9000];
125 loop {
126 let (n, src) = self.socket.recv_from(&mut buf).await?;
127 if src == self.peer {
128 trace!(face=%self.id, peer=%self.peer, len=n, "udp: recv");
129 return Ok(Bytes::copy_from_slice(&buf[..n]));
130 }
131 trace!(face=%self.id, expected=%self.peer, actual=%src, len=n, "udp: recv ignored (wrong source)");
132 }
133 }
134
135 async fn send(&self, pkt: Bytes) -> Result<(), FaceError> {
136 if ndn_packet::lp::is_lp_packet(&pkt) {
139 trace!(face=%self.id, peer=%self.peer, len=pkt.len(), "udp: send (passthrough)");
140 self.socket.send_to(&pkt, self.peer).await?;
141 return Ok(());
142 }
143
144 if pkt.len() + 4 <= self.mtu {
148 let wire = ndn_packet::lp::encode_lp_packet(&pkt);
149 trace!(face=%self.id, peer=%self.peer, len=wire.len(), "udp: send");
150 self.socket.send_to(&wire, self.peer).await?;
151 } else {
152 let seq = self.seq.fetch_add(1, Ordering::Relaxed);
153 let fragments = fragment_packet(&pkt, self.mtu, seq);
154 trace!(face=%self.id, peer=%self.peer, frags=fragments.len(), seq, "udp: send fragmented");
155 for frag in &fragments {
158 match self.socket.try_send_to(frag, self.peer) {
159 Ok(_) => {}
160 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
161 self.socket.send_to(frag, self.peer).await?;
163 }
164 Err(e) => return Err(e.into()),
165 }
166 }
167 }
168 Ok(())
169 }
170}
171
/// Find the local IP address the OS selects for reaching `peer`, paired with
/// `port`.
///
/// A throwaway std UDP socket is bound to the wildcard address of the same
/// family as `peer` (`0.0.0.0:0` or `[::]:0`) and connected to `peer`. The
/// address the OS then reports for that socket is returned, with its port
/// replaced by `port`.
///
/// # Errors
/// Propagates I/O errors from binding, connecting, or querying the probe.
fn resolve_local_addr(peer: SocketAddr, port: u16) -> std::io::Result<SocketAddr> {
    let wildcard = match peer {
        SocketAddr::V4(_) => "0.0.0.0:0",
        SocketAddr::V6(_) => "[::]:0",
    };
    let probe = std::net::UdpSocket::bind(wildcard)?;
    probe.connect(peer)?;
    probe.local_addr().map(|mut chosen| {
        chosen.set_port(port);
        chosen
    })
}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal TLV packet (type 0x05 with a single-byte value `tag`).
    fn test_packet(tag: u8) -> Bytes {
        use ndn_tlv::TlvWriter;
        let mut w = TlvWriter::new();
        w.write_tlv(0x05, &[tag]);
        w.finish()
    }

    /// The datagram a non-LP packet becomes once `send` wraps it in LP framing.
    fn expected_on_wire(pkt: &Bytes) -> Bytes {
        ndn_packet::lp::encode_lp_packet(pkt)
    }

    /// Two faces on loopback sockets, each pointing at the other.
    async fn face_pair() -> (UdpFace, UdpFace) {
        let sock_a = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let sock_b = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let addr_a = sock_a.local_addr().unwrap();
        let addr_b = sock_b.local_addr().unwrap();

        let face_a = UdpFace::from_socket(FaceId(0), sock_a, addr_b);
        let face_b = UdpFace::from_socket(FaceId(1), sock_b, addr_a);
        (face_a, face_b)
    }

    #[tokio::test]
    async fn udp_roundtrip() {
        let (face_a, face_b) = face_pair().await;

        let pkt = test_packet(0xAB);
        face_a.send(pkt.clone()).await.unwrap();
        let received = face_b.recv().await.unwrap();
        assert_eq!(received, expected_on_wire(&pkt));
    }

    #[tokio::test]
    async fn udp_bidirectional() {
        let (face_a, face_b) = face_pair().await;

        face_a.send(test_packet(1)).await.unwrap();
        face_b.send(test_packet(2)).await.unwrap();

        assert_eq!(
            face_b.recv().await.unwrap(),
            expected_on_wire(&test_packet(1))
        );
        assert_eq!(
            face_a.recv().await.unwrap(),
            expected_on_wire(&test_packet(2))
        );
    }

    #[tokio::test]
    async fn udp_multiple_sequential() {
        let (face_a, face_b) = face_pair().await;

        for i in 0u8..5 {
            face_a.send(test_packet(i)).await.unwrap();
            assert_eq!(
                face_b.recv().await.unwrap(),
                expected_on_wire(&test_packet(i))
            );
        }
    }

    /// A packet that is already LP-framed must be sent without re-wrapping.
    #[tokio::test]
    async fn udp_lp_passthrough() {
        let (face_a, face_b) = face_pair().await;

        let framed = expected_on_wire(&test_packet(3));
        face_a.send(framed.clone()).await.unwrap();
        assert_eq!(face_b.recv().await.unwrap(), framed);
    }

    /// Datagrams from an address other than the face's peer are dropped.
    #[tokio::test]
    async fn udp_recv_ignores_foreign_source() {
        let (face_a, face_b) = face_pair().await;

        let stranger = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        // face_a's peer is face_b's local address.
        stranger.send_to(b"junk", face_a.peer()).await.unwrap();

        face_a.send(test_packet(9)).await.unwrap();
        assert_eq!(
            face_b.recv().await.unwrap(),
            expected_on_wire(&test_packet(9))
        );
    }

    /// Exercise the face's accessors on a real (IPv4 loopback) face.
    #[tokio::test]
    async fn accessors() {
        let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let local = sock.local_addr().unwrap();
        let peer: SocketAddr = "127.0.0.1:9999".parse().unwrap();
        let face = UdpFace::from_socket(FaceId(7), sock, peer);

        assert_eq!(face.id().0, 7);
        assert_eq!(face.kind(), FaceKind::Udp);
        assert_eq!(face.peer(), peer);
        assert_eq!(face.remote_uri().as_deref(), Some("udp4://127.0.0.1:9999"));
        assert_eq!(face.local_uri(), Some(format!("udp4://{local}")));
    }
}