1#[cfg(feature = "websocket-tls")]
24use std::path::PathBuf;
25
26use bytes::Bytes;
27use futures::{SinkExt, StreamExt};
28use tokio::net::TcpStream;
29use tokio::sync::Mutex;
30use tokio_tungstenite::tungstenite::Message;
31use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
32use tracing::trace;
33
34use ndn_transport::{Face, FaceError, FaceId, FaceKind};
35
36type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
37
38pub struct WebSocketFace {
43 id: FaceId,
44 remote_addr: String,
45 local_addr: String,
46 reader: Mutex<futures::stream::SplitStream<WsStream>>,
47 writer: Mutex<futures::stream::SplitSink<WsStream, Message>>,
48}
49
50impl WebSocketFace {
51 pub async fn connect(
53 id: FaceId,
54 url: &str,
55 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
56 let (ws, _response) = tokio_tungstenite::connect_async(url).await?;
57
58 let (remote_addr, local_addr) = match ws.get_ref() {
60 MaybeTlsStream::Plain(tcp) => (
61 tcp.peer_addr().map(|a| a.to_string()).unwrap_or_default(),
62 tcp.local_addr().map(|a| a.to_string()).unwrap_or_default(),
63 ),
64 _ => (url.to_string(), String::new()),
65 };
66
67 let (writer, reader) = ws.split();
68 Ok(Self {
69 id,
70 remote_addr,
71 local_addr,
72 reader: Mutex::new(reader),
73 writer: Mutex::new(writer),
74 })
75 }
76
77 pub fn from_stream(
79 id: FaceId,
80 ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
81 remote_addr: String,
82 local_addr: String,
83 ) -> Self {
84 let (writer, reader) = ws.split();
85 Self {
86 id,
87 remote_addr,
88 local_addr,
89 reader: Mutex::new(reader),
90 writer: Mutex::new(writer),
91 }
92 }
93
94 pub fn remote_addr(&self) -> &str {
95 &self.remote_addr
96 }
97 pub fn local_addr(&self) -> &str {
98 &self.local_addr
99 }
100}
101
102impl Face for WebSocketFace {
103 fn id(&self) -> FaceId {
104 self.id
105 }
106 fn kind(&self) -> FaceKind {
107 FaceKind::WebSocket
108 }
109
110 fn remote_uri(&self) -> Option<String> {
111 Some(self.remote_addr.clone())
112 }
113
114 fn local_uri(&self) -> Option<String> {
115 Some(self.local_addr.clone())
116 }
117
118 async fn recv(&self) -> Result<Bytes, FaceError> {
119 let mut reader = self.reader.lock().await;
120 loop {
121 let msg = reader
122 .next()
123 .await
124 .ok_or(FaceError::Closed)?
125 .map_err(|e| FaceError::Io(std::io::Error::other(e)))?;
126
127 match msg {
128 Message::Binary(data) => {
129 trace!(face=%self.id, len=data.len(), "ws: recv binary");
130 return Ok(data);
131 }
132 Message::Close(_) => return Err(FaceError::Closed),
133 _ => continue,
135 }
136 }
137 }
138
139 async fn send(&self, pkt: Bytes) -> Result<(), FaceError> {
140 let wire = ndn_packet::lp::encode_lp_packet(&pkt);
141 trace!(face=%self.id, len=wire.len(), "ws: send binary");
142 let mut writer = self.writer.lock().await;
143 writer
144 .send(Message::Binary(wire.to_vec().into()))
145 .await
146 .map_err(|e| FaceError::Io(std::io::Error::other(e)))
147 }
148}
149
/// Where the TLS certificate and private key for [`WebSocketFace::listen_tls`] come from.
#[cfg(feature = "websocket-tls")]
#[derive(Debug, Clone)]
pub enum TlsConfig {
    /// Generate a new self-signed certificate for the name `localhost` each time a listener's
    /// TLS configuration is built.
    SelfSigned,
    /// Load the certificate chain and private key from PEM files on disk.
    UserSupplied {
        /// Path to a PEM file containing the certificate chain.
        cert_pem: PathBuf,
        /// Path to a PEM file containing the private key.
        key_pem: PathBuf,
    },
}
174
/// WebSocket stream running over a server-side rustls TLS session on TCP.
#[cfg(feature = "websocket-tls")]
type TlsWsStream =
    tokio_tungstenite::WebSocketStream<tokio_rustls::server::TlsStream<TcpStream>>;

/// Server-side face for a WebSocket connection accepted over TLS (`wss://`).
///
/// Instances are produced by [`WebSocketListener::accept`]. The stream is split into a read half
/// and a write half, each behind its own async mutex.
#[cfg(feature = "websocket-tls")]
pub struct TlsWebSocketFace {
    /// Identifier returned by `Face::id`.
    id: FaceId,
    /// Peer socket address recorded at accept time.
    remote_addr: String,
    /// Local socket address recorded at accept time; empty if it could not be queried.
    local_addr: String,
    /// Receive half; locked for the whole of each `recv` call.
    reader: Mutex<futures::stream::SplitStream<TlsWsStream>>,
    /// Send half; locked for the whole of each `send` call.
    writer: Mutex<futures::stream::SplitSink<TlsWsStream, Message>>,
}
201
#[cfg(feature = "websocket-tls")]
impl Face for TlsWebSocketFace {
    fn id(&self) -> FaceId {
        self.id
    }

    fn kind(&self) -> FaceKind {
        FaceKind::WebSocket
    }

    fn remote_uri(&self) -> Option<String> {
        Some(self.remote_addr.clone())
    }

    fn local_uri(&self) -> Option<String> {
        Some(self.local_addr.clone())
    }

    /// Waits for the next binary WebSocket message and returns its payload as received.
    ///
    /// Non-binary messages (text, ping, pong) are skipped. A Close frame or the end of the
    /// stream yields [`FaceError::Closed`], and transport/protocol errors become
    /// [`FaceError::Io`].
    async fn recv(&self) -> Result<Bytes, FaceError> {
        let mut reader = self.reader.lock().await;
        loop {
            let msg = reader
                .next()
                .await
                .ok_or(FaceError::Closed)?
                .map_err(|e| FaceError::Io(std::io::Error::other(e)))?;

            match msg {
                Message::Binary(data) => {
                    trace!(face=%self.id, len=data.len(), "wss: recv binary");
                    // The payload is already `Bytes`; no conversion needed.
                    return Ok(data);
                }
                Message::Close(_) => return Err(FaceError::Closed),
                // Not an NDN packet: keep waiting for a binary frame.
                _ => continue,
            }
        }
    }

    /// Encodes `pkt` with `ndn_packet::lp::encode_lp_packet` and sends it as one binary message.
    ///
    /// # Errors
    ///
    /// Returns [`FaceError::Io`] if the WebSocket sink rejects or fails to flush the message.
    async fn send(&self, pkt: Bytes) -> Result<(), FaceError> {
        let wire = ndn_packet::lp::encode_lp_packet(&pkt);
        trace!(face=%self.id, len=wire.len(), "wss: send binary");
        let mut writer = self.writer.lock().await;
        // Move the encoded `Bytes` into the message rather than copying it into a new Vec.
        writer
            .send(Message::Binary(wire))
            .await
            .map_err(|e| FaceError::Io(std::io::Error::other(e)))
    }
}
242
/// Accepts WebSocket-over-TLS connections; created by [`WebSocketFace::listen_tls`].
#[cfg(feature = "websocket-tls")]
pub struct WebSocketListener {
    /// Bound TCP socket on which incoming connections arrive.
    inner: tokio::net::TcpListener,
    /// Runs the server side of the TLS handshake on each accepted TCP connection.
    acceptor: tokio_rustls::TlsAcceptor,
}
249
#[cfg(feature = "websocket-tls")]
impl WebSocketListener {
    /// Accepts the next TCP connection and completes the TLS and WebSocket handshakes on it.
    ///
    /// Both handshakes are awaited inside this call and no timeout is applied here. A peer
    /// that stalls mid-handshake therefore keeps this call pending. Callers that need bounded
    /// waits should wrap it themselves.
    ///
    /// # Errors
    ///
    /// Returns [`FaceError::Io`] if the TCP accept, the TLS handshake, or the WebSocket
    /// handshake fails.
    pub async fn accept(&self, id: FaceId) -> Result<TlsWebSocketFace, FaceError> {
        let (tcp, peer) = self.inner.accept().await.map_err(FaceError::Io)?;
        let local_addr = match tcp.local_addr() {
            Ok(addr) => addr.to_string(),
            Err(_) => String::new(),
        };

        let tls = self.acceptor.accept(tcp).await.map_err(FaceError::Io)?;
        let ws = tokio_tungstenite::accept_async(tls)
            .await
            .map_err(|e| FaceError::Io(std::io::Error::other(e)))?;

        let (sink, stream) = ws.split();
        Ok(TlsWebSocketFace {
            id,
            remote_addr: peer.to_string(),
            local_addr,
            reader: Mutex::new(stream),
            writer: Mutex::new(sink),
        })
    }
}
274
#[cfg(feature = "websocket-tls")]
impl WebSocketFace {
    /// Starts a WebSocket-over-TLS (`wss://`) server bound to `addr`.
    ///
    /// The TLS server configuration is built from `tls` first. The TCP socket is bound only if
    /// that succeeds. Call [`WebSocketListener::accept`] on the result to obtain faces.
    ///
    /// # Errors
    ///
    /// Returns [`FaceError::Io`] if the TLS configuration cannot be built or `addr` cannot be
    /// bound.
    pub async fn listen_tls(
        addr: std::net::SocketAddr,
        tls: TlsConfig,
    ) -> Result<WebSocketListener, FaceError> {
        let server_config = build_tls_server_config(tls).await?;
        let listener = tokio::net::TcpListener::bind(addr)
            .await
            .map_err(FaceError::Io)?;
        Ok(WebSocketListener {
            acceptor: tokio_rustls::TlsAcceptor::from(std::sync::Arc::new(server_config)),
            inner: listener,
        })
    }
}
311
/// Builds the rustls server configuration described by `tls`.
///
/// No client authentication is requested in either mode.
///
/// # Errors
///
/// Returns [`FaceError::Io`] in any of these cases:
/// - certificate generation fails;
/// - a PEM file cannot be read or parsed;
/// - the certificate PEM holds no certificate, or the key PEM holds no private key;
/// - rustls rejects the certificate/key pair.
#[cfg(feature = "websocket-tls")]
async fn build_tls_server_config(
    tls: TlsConfig,
) -> Result<tokio_rustls::rustls::ServerConfig, FaceError> {
    match tls {
        TlsConfig::SelfSigned => self_signed_server_config(),
        TlsConfig::UserSupplied { cert_pem, key_pem } => pem_server_config(&cert_pem, &key_pem),
    }
}

/// Generates a fresh self-signed certificate for `localhost` and builds a config around it.
#[cfg(feature = "websocket-tls")]
fn self_signed_server_config() -> Result<tokio_rustls::rustls::ServerConfig, FaceError> {
    use tokio_rustls::rustls::{self, pki_types};

    let cert = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()])
        .map_err(|e| FaceError::Io(std::io::Error::other(e)))?;
    let cert_der = pki_types::CertificateDer::from(cert.cert.der().to_vec());
    let key_der = pki_types::PrivateKeyDer::try_from(cert.key_pair.serialize_der())
        .map_err(|e| FaceError::Io(std::io::Error::other(e)))?;

    rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der)
        .map_err(|e| FaceError::Io(std::io::Error::other(e)))
}

/// Loads a certificate chain and a private key from PEM files and builds a config around them.
#[cfg(feature = "websocket-tls")]
fn pem_server_config(
    cert_pem: &std::path::Path,
    key_pem: &std::path::Path,
) -> Result<tokio_rustls::rustls::ServerConfig, FaceError> {
    use tokio_rustls::rustls;

    // These are blocking `std::fs` reads on the async caller's thread. That is tolerated
    // because the work runs once, at listener setup.
    let cert_bytes = read_pem_file(cert_pem)?;
    let key_bytes = read_pem_file(key_pem)?;

    // `certs` already yields the pki-types `CertificateDer<'static>` that rustls accepts (the
    // private key below is passed through unconverted), so no per-certificate copy is needed.
    let certs = rustls_pemfile::certs(&mut cert_bytes.as_slice())
        .collect::<Result<Vec<_>, _>>()
        .map_err(FaceError::Io)?;
    if certs.is_empty() {
        return Err(FaceError::Io(std::io::Error::other(format!(
            "no certificates found in {}",
            cert_pem.display()
        ))));
    }

    let key = rustls_pemfile::private_key(&mut key_bytes.as_slice())
        .map_err(FaceError::Io)?
        .ok_or_else(|| {
            FaceError::Io(std::io::Error::other(format!(
                "no private key in PEM {}",
                key_pem.display()
            )))
        })?;

    rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .map_err(|e| FaceError::Io(std::io::Error::other(e)))
}

/// Reads a PEM file. Any I/O error keeps its kind but gains the offending path in its message.
#[cfg(feature = "websocket-tls")]
fn read_pem_file(path: &std::path::Path) -> Result<Vec<u8>, FaceError> {
    std::fs::read(path).map_err(|e| {
        FaceError::Io(std::io::Error::new(
            e.kind(),
            format!("failed to read {}: {e}", path.display()),
        ))
    })
}
352
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;
    use tokio_tungstenite::accept_async;

    /// Returns a connected `(client, server)` pair over a loopback plain-TCP WebSocket.
    async fn loopback_pair() -> (WebSocketFace, WebSocketFace) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let url = format!("ws://127.0.0.1:{}", addr.port());

        let accept_fut = async {
            let (stream, peer) = listener.accept().await.unwrap();
            let ws = accept_async(MaybeTlsStream::Plain(stream)).await.unwrap();
            WebSocketFace::from_stream(FaceId(1), ws, peer.to_string(), addr.to_string())
        };

        let connect_fut = WebSocketFace::connect(FaceId(0), &url);

        let (server, client) = tokio::join!(accept_fut, connect_fut);
        (client.unwrap(), server)
    }

    /// Builds a single TLV element with the given type and value.
    fn make_tlv(tag: u8, value: &[u8]) -> Bytes {
        use ndn_tlv::TlvWriter;
        let mut w = TlvWriter::new();
        w.write_tlv(tag as u64, value);
        w.finish()
    }

    /// What `recv` on the far side should return for a packet passed to `send`.
    fn expected_on_wire(pkt: &Bytes) -> Bytes {
        ndn_packet::lp::encode_lp_packet(pkt)
    }

    #[tokio::test]
    async fn send_recv_single_packet() {
        let (client, server) = loopback_pair().await;
        let pkt = make_tlv(0x05, b"hello");
        client.send(pkt.clone()).await.unwrap();
        assert_eq!(server.recv().await.unwrap(), expected_on_wire(&pkt));
    }

    #[tokio::test]
    async fn bidirectional_exchange() {
        let (client, server) = loopback_pair().await;
        client.send(make_tlv(0x05, b"interest")).await.unwrap();
        server.send(make_tlv(0x06, b"data")).await.unwrap();
        assert_eq!(
            server.recv().await.unwrap(),
            expected_on_wire(&make_tlv(0x05, b"interest"))
        );
        assert_eq!(
            client.recv().await.unwrap(),
            expected_on_wire(&make_tlv(0x06, b"data"))
        );
    }

    #[tokio::test]
    async fn concurrent_sends_arrive_intact() {
        use std::collections::HashSet;
        use std::sync::Arc;

        const COUNT: u8 = 8;

        let (client, server) = loopback_pair().await;
        let client = Arc::new(client);

        let handles: Vec<_> = (0..COUNT)
            .map(|i| {
                let c = Arc::clone(&client);
                tokio::spawn(async move {
                    c.send(make_tlv(0x05, &[i])).await.unwrap();
                })
            })
            .collect();
        for h in handles {
            h.await.unwrap();
        }

        // Arrival order across tasks is unspecified, so compare as sets. Every packet must
        // arrive exactly once and byte-for-byte intact (no interleaved or truncated frames).
        let mut received = HashSet::new();
        for _ in 0..COUNT {
            let frame = server.recv().await.unwrap();
            assert!(received.insert(frame), "duplicate frame received");
        }
        let expected: HashSet<Bytes> = (0..COUNT)
            .map(|i| expected_on_wire(&make_tlv(0x05, &[i])))
            .collect();
        assert_eq!(received, expected);
    }

    #[tokio::test]
    async fn close_detection() {
        let (client, server) = loopback_pair().await;
        drop(client);
        let result = server.recv().await;
        assert!(result.is_err());
    }
}