ndn_face_net/
websocket.rs1use bytes::Bytes;
11use futures::{SinkExt, StreamExt};
12use tokio::net::TcpStream;
13use tokio::sync::Mutex;
14use tokio_tungstenite::tungstenite::Message;
15use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
16use tracing::trace;
17
18use ndn_transport::{Face, FaceError, FaceId, FaceKind};
19
20type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
21
22pub struct WebSocketFace {
27 id: FaceId,
28 remote_addr: String,
29 local_addr: String,
30 reader: Mutex<futures::stream::SplitStream<WsStream>>,
31 writer: Mutex<futures::stream::SplitSink<WsStream, Message>>,
32}
33
34impl WebSocketFace {
35 pub async fn connect(
37 id: FaceId,
38 url: &str,
39 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
40 let (ws, _response) = tokio_tungstenite::connect_async(url).await?;
41
42 let (remote_addr, local_addr) = match ws.get_ref() {
44 MaybeTlsStream::Plain(tcp) => (
45 tcp.peer_addr().map(|a| a.to_string()).unwrap_or_default(),
46 tcp.local_addr().map(|a| a.to_string()).unwrap_or_default(),
47 ),
48 _ => (url.to_string(), String::new()),
49 };
50
51 let (writer, reader) = ws.split();
52 Ok(Self {
53 id,
54 remote_addr,
55 local_addr,
56 reader: Mutex::new(reader),
57 writer: Mutex::new(writer),
58 })
59 }
60
61 pub fn from_stream(
63 id: FaceId,
64 ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
65 remote_addr: String,
66 local_addr: String,
67 ) -> Self {
68 let (writer, reader) = ws.split();
69 Self {
70 id,
71 remote_addr,
72 local_addr,
73 reader: Mutex::new(reader),
74 writer: Mutex::new(writer),
75 }
76 }
77
78 pub fn remote_addr(&self) -> &str {
79 &self.remote_addr
80 }
81 pub fn local_addr(&self) -> &str {
82 &self.local_addr
83 }
84}
85
86impl Face for WebSocketFace {
87 fn id(&self) -> FaceId {
88 self.id
89 }
90 fn kind(&self) -> FaceKind {
91 FaceKind::WebSocket
92 }
93
94 fn remote_uri(&self) -> Option<String> {
95 Some(self.remote_addr.clone())
96 }
97
98 fn local_uri(&self) -> Option<String> {
99 Some(self.local_addr.clone())
100 }
101
102 async fn recv(&self) -> Result<Bytes, FaceError> {
103 let mut reader = self.reader.lock().await;
104 loop {
105 let msg = reader
106 .next()
107 .await
108 .ok_or(FaceError::Closed)?
109 .map_err(|e| FaceError::Io(std::io::Error::other(e)))?;
110
111 match msg {
112 Message::Binary(data) => {
113 trace!(face=%self.id, len=data.len(), "ws: recv binary");
114 return Ok(data);
115 }
116 Message::Close(_) => return Err(FaceError::Closed),
117 _ => continue,
119 }
120 }
121 }
122
123 async fn send(&self, pkt: Bytes) -> Result<(), FaceError> {
124 let wire = ndn_packet::lp::encode_lp_packet(&pkt);
125 trace!(face=%self.id, len=wire.len(), "ws: send binary");
126 let mut writer = self.writer.lock().await;
127 writer
128 .send(Message::Binary(wire.to_vec().into()))
129 .await
130 .map_err(|e| FaceError::Io(std::io::Error::other(e)))
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;
    use tokio_tungstenite::accept_async;

    /// Builds a connected `(client, server)` face pair over a loopback TCP listener.
    async fn loopback_pair() -> (WebSocketFace, WebSocketFace) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let url = format!("ws://127.0.0.1:{}", addr.port());

        let accept_fut = async {
            let (stream, peer) = listener.accept().await.unwrap();
            let ws = accept_async(MaybeTlsStream::Plain(stream)).await.unwrap();
            WebSocketFace::from_stream(FaceId(1), ws, peer.to_string(), addr.to_string())
        };

        let connect_fut = WebSocketFace::connect(FaceId(0), &url);

        let (server, client) = tokio::join!(accept_fut, connect_fut);
        (client.unwrap(), server)
    }

    /// Encodes a single TLV element with type `tag` and the given value.
    fn make_tlv(tag: u8, value: &[u8]) -> Bytes {
        use ndn_tlv::TlvWriter;
        let mut w = TlvWriter::new();
        w.write_tlv(tag as u64, value);
        w.finish()
    }

    /// The bytes a peer's `recv` yields for `pkt`, since `send` LP-encodes it.
    fn expected_on_wire(pkt: &Bytes) -> Bytes {
        ndn_packet::lp::encode_lp_packet(pkt)
    }

    #[tokio::test]
    async fn send_recv_single_packet() {
        let (client, server) = loopback_pair().await;
        let pkt = make_tlv(0x05, b"hello");
        client.send(pkt.clone()).await.unwrap();
        assert_eq!(server.recv().await.unwrap(), expected_on_wire(&pkt));
    }

    #[tokio::test]
    async fn bidirectional_exchange() {
        let (client, server) = loopback_pair().await;
        client.send(make_tlv(0x05, b"interest")).await.unwrap();
        server.send(make_tlv(0x06, b"data")).await.unwrap();
        assert_eq!(
            server.recv().await.unwrap(),
            expected_on_wire(&make_tlv(0x05, b"interest"))
        );
        assert_eq!(
            client.recv().await.unwrap(),
            expected_on_wire(&make_tlv(0x06, b"data"))
        );
    }

    #[tokio::test]
    async fn concurrent_sends_arrive_intact() {
        use std::sync::Arc;
        let (client, server) = loopback_pair().await;
        let client = Arc::new(client);

        let handles: Vec<_> = (0u8..8)
            .map(|i| {
                let c = Arc::clone(&client);
                tokio::spawn(async move {
                    c.send(make_tlv(0x05, &[i])).await.unwrap();
                })
            })
            .collect();
        for h in handles {
            h.await.unwrap();
        }

        let mut received = Vec::with_capacity(8);
        for _ in 0u8..8 {
            received.push(server.recv().await.unwrap());
        }

        // Spawned sends may complete in any order, so compare sorted contents: every
        // packet must arrive exactly once and byte-for-byte intact (no interleaving).
        let mut expected: Vec<Bytes> = (0u8..8)
            .map(|i| expected_on_wire(&make_tlv(0x05, &[i])))
            .collect();
        received.sort();
        expected.sort();
        assert_eq!(received, expected);
    }

    #[tokio::test]
    async fn close_detection() {
        let (client, server) = loopback_pair().await;
        drop(client);
        let result = server.recv().await;
        assert!(result.is_err());
    }
}
224}