ndn_faces/net/
websocket.rs

1//! NDN face over WebSocket (binary frames).
2//!
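//! WebSocket provides its own message framing, so no `TlvCodec` is needed —
//! each WebSocket binary message carries exactly one NDN packet wrapped in an
//! NDNLPv2 `LpPacket`.  `send` performs the wrapping; `recv` returns the
//! binary payload as received, leaving `LpPacket` decoding to the caller.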
6//!
7//! Supports both client-initiated (`connect`) and server-accepted (`from_stream`)
8//! connections.  Compatible with NFD's WebSocket face.
9//!
10//! ## TLS support (feature `websocket-tls`)
11//!
12//! Enable the `websocket-tls` feature to unlock [`TlsConfig`] and
13//! [`WebSocketFace::listen_tls`].  Two modes are available:
14//!
//! - [`TlsConfig::SelfSigned`] — a self-signed ECDSA (P-256) certificate for
//!   `localhost` is generated at runtime using `rcgen`; no external CA needed.
17//! - [`TlsConfig::UserSupplied`] — load a PEM certificate and private key from
18//!   disk (e.g., Let's Encrypt or your own CA).
19//!
20//! Server-side TLS for ACME certificate distribution (SVS fleet cert sync)
21//! is targeted for v0.2.0.
22
23#[cfg(feature = "websocket-tls")]
24use std::path::PathBuf;
25
26use bytes::Bytes;
27use futures::{SinkExt, StreamExt};
28use tokio::net::TcpStream;
29use tokio::sync::Mutex;
30use tokio_tungstenite::tungstenite::Message;
31use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
32use tracing::trace;
33
34use ndn_transport::{Face, FaceError, FaceId, FaceKind};
35
36type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
37
38/// NDN face over WebSocket with binary message framing.
39///
40/// The WebSocket stream is split into independent read and write halves, each
41/// behind its own `Mutex` — mirroring the `TcpFace` pattern.
42pub struct WebSocketFace {
43    id: FaceId,
44    remote_addr: String,
45    local_addr: String,
46    reader: Mutex<futures::stream::SplitStream<WsStream>>,
47    writer: Mutex<futures::stream::SplitSink<WsStream, Message>>,
48}
49
50impl WebSocketFace {
51    /// Connect to a WebSocket endpoint (client side).
52    pub async fn connect(
53        id: FaceId,
54        url: &str,
55    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
56        let (ws, _response) = tokio_tungstenite::connect_async(url).await?;
57
58        // Extract addresses from the underlying TCP stream before splitting.
59        let (remote_addr, local_addr) = match ws.get_ref() {
60            MaybeTlsStream::Plain(tcp) => (
61                tcp.peer_addr().map(|a| a.to_string()).unwrap_or_default(),
62                tcp.local_addr().map(|a| a.to_string()).unwrap_or_default(),
63            ),
64            _ => (url.to_string(), String::new()),
65        };
66
67        let (writer, reader) = ws.split();
68        Ok(Self {
69            id,
70            remote_addr,
71            local_addr,
72            reader: Mutex::new(reader),
73            writer: Mutex::new(writer),
74        })
75    }
76
77    /// Wrap an already-accepted WebSocket stream (server side).
78    pub fn from_stream(
79        id: FaceId,
80        ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
81        remote_addr: String,
82        local_addr: String,
83    ) -> Self {
84        let (writer, reader) = ws.split();
85        Self {
86            id,
87            remote_addr,
88            local_addr,
89            reader: Mutex::new(reader),
90            writer: Mutex::new(writer),
91        }
92    }
93
94    pub fn remote_addr(&self) -> &str {
95        &self.remote_addr
96    }
97    pub fn local_addr(&self) -> &str {
98        &self.local_addr
99    }
100}
101
102impl Face for WebSocketFace {
103    fn id(&self) -> FaceId {
104        self.id
105    }
106    fn kind(&self) -> FaceKind {
107        FaceKind::WebSocket
108    }
109
110    fn remote_uri(&self) -> Option<String> {
111        Some(self.remote_addr.clone())
112    }
113
114    fn local_uri(&self) -> Option<String> {
115        Some(self.local_addr.clone())
116    }
117
118    async fn recv(&self) -> Result<Bytes, FaceError> {
119        let mut reader = self.reader.lock().await;
120        loop {
121            let msg = reader
122                .next()
123                .await
124                .ok_or(FaceError::Closed)?
125                .map_err(|e| FaceError::Io(std::io::Error::other(e)))?;
126
127            match msg {
128                Message::Binary(data) => {
129                    trace!(face=%self.id, len=data.len(), "ws: recv binary");
130                    return Ok(data);
131                }
132                Message::Close(_) => return Err(FaceError::Closed),
133                // Skip text, ping, pong frames.
134                _ => continue,
135            }
136        }
137    }
138
139    async fn send(&self, pkt: Bytes) -> Result<(), FaceError> {
140        let wire = ndn_packet::lp::encode_lp_packet(&pkt);
141        trace!(face=%self.id, len=wire.len(), "ws: send binary");
142        let mut writer = self.writer.lock().await;
143        writer
144            .send(Message::Binary(wire.to_vec().into()))
145            .await
146            .map_err(|e| FaceError::Io(std::io::Error::other(e)))
147    }
148}
149
150// ── TLS support (feature = "websocket-tls") ───────────────────────────────────
151
/// TLS configuration for a WebSocket server listener.
///
/// Used with [`WebSocketFace::listen_tls`].  Requires the `websocket-tls`
/// feature.
#[cfg(feature = "websocket-tls")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TlsConfig {
    /// Generate a self-signed ECDSA certificate at runtime using `rcgen`.
    ///
    /// The certificate is valid for `localhost` only.  Clients must be
    /// configured to trust this certificate explicitly (no CA verification).
    SelfSigned,
    /// Load a PEM-encoded certificate chain and private key from disk.
    ///
    /// Use this for certificates issued by a recognised CA (e.g., Let's
    /// Encrypt or an internal CA).  The files are read once, when the
    /// listener is created, and must be readable at that point.
    UserSupplied {
        /// Path to the PEM certificate chain file (leaf certificate first).
        cert_pem: PathBuf,
        /// Path to the PEM private key file (PKCS#1, PKCS#8 or SEC1 format;
        /// the first key found is used).
        key_pem: PathBuf,
    },
}
174
/// WebSocket stream layered over a server-side rustls session.
#[cfg(feature = "websocket-tls")]
type TlsWsStream = WebSocketStream<tokio_rustls::server::TlsStream<TcpStream>>;

/// An NDN face over TLS WebSocket (server side).
///
/// Created by [`WebSocketFace::listen_tls`] + [`WebSocketListener::accept`].
/// Implements [`Face`] the same way as [`WebSocketFace`]: outgoing packets
/// are wrapped in an NDNLPv2 `LpPacket`, and incoming binary payloads are
/// returned as received.  The difference is that it runs over a server-side
/// rustls session instead of the client-side `MaybeTlsStream`.
#[cfg(feature = "websocket-tls")]
pub struct TlsWebSocketFace {
    id: FaceId,
    /// Peer `ip:port` captured at accept time.
    remote_addr: String,
    /// Local `ip:port` captured at accept time (empty if unavailable).
    local_addr: String,
    /// Receiving half of the split TLS WebSocket stream.
    reader: Mutex<futures::stream::SplitStream<TlsWsStream>>,
    /// Sending half of the split TLS WebSocket stream.
    writer: Mutex<futures::stream::SplitSink<TlsWsStream, Message>>,
}

#[cfg(feature = "websocket-tls")]
impl Face for TlsWebSocketFace {
    fn id(&self) -> FaceId {
        self.id
    }

    fn kind(&self) -> FaceKind {
        FaceKind::WebSocket
    }

    fn remote_uri(&self) -> Option<String> {
        Some(self.remote_addr.clone())
    }

    fn local_uri(&self) -> Option<String> {
        Some(self.local_addr.clone())
    }

    /// Receive the payload of the next binary WebSocket message, skipping
    /// every other message kind.
    ///
    /// # Errors
    /// - [`FaceError::Closed`] when the stream ends or a Close frame arrives.
    /// - [`FaceError::Io`] wrapping any WebSocket protocol/transport error.
    async fn recv(&self) -> Result<Bytes, FaceError> {
        let mut reader = self.reader.lock().await;
        loop {
            let msg = reader
                .next()
                .await
                .ok_or(FaceError::Closed)?
                .map_err(|e| FaceError::Io(std::io::Error::other(e)))?;
            match msg {
                // The payload is already `Bytes`; no conversion needed.
                Message::Binary(data) => {
                    trace!(face=%self.id, len=data.len(), "wss: recv binary");
                    return Ok(data);
                }
                Message::Close(_) => return Err(FaceError::Closed),
                _ => continue,
            }
        }
    }

    /// Wrap `pkt` in an NDNLPv2 `LpPacket` and send it as one binary message.
    ///
    /// # Errors
    /// [`FaceError::Io`] wrapping any WebSocket protocol/transport error.
    async fn send(&self, pkt: Bytes) -> Result<(), FaceError> {
        let wire = ndn_packet::lp::encode_lp_packet(&pkt);
        trace!(face=%self.id, len=wire.len(), "wss: send binary");
        let mut writer = self.writer.lock().await;
        // Pass the encoded `Bytes` straight through instead of copying it
        // into a `Vec` and converting back.
        writer
            .send(Message::Binary(wire))
            .await
            .map_err(|e| FaceError::Io(std::io::Error::other(e)))
    }
}
242
/// A TLS WebSocket server listener, returned by [`WebSocketFace::listen_tls`].
#[cfg(feature = "websocket-tls")]
pub struct WebSocketListener {
    /// Bound TCP listener that incoming connections arrive on.
    inner: tokio::net::TcpListener,
    /// TLS acceptor wrapping the listener's rustls `ServerConfig`.
    acceptor: tokio_rustls::TlsAcceptor,
}

#[cfg(feature = "websocket-tls")]
impl WebSocketListener {
    /// Local socket address this listener is bound to.
    ///
    /// Useful after binding to port `0` to learn the port the OS picked.
    ///
    /// # Errors
    /// [`FaceError::Io`] if the address cannot be queried.
    pub fn local_addr(&self) -> Result<std::net::SocketAddr, FaceError> {
        self.inner.local_addr().map_err(FaceError::Io)
    }

    /// Accept one incoming TLS WebSocket connection.
    ///
    /// Waits for a TCP connection, then performs the TLS handshake and the
    /// WebSocket upgrade on it before returning a [`TlsWebSocketFace`] ready
    /// for NDN traffic.  The handshakes run inline, so a slow or stalled
    /// client delays this call; callers that must stay responsive may want to
    /// wrap it in a timeout.
    ///
    /// # Errors
    /// [`FaceError::Io`] if accepting the TCP connection, the TLS handshake,
    /// or the WebSocket handshake fails.  A handshake failure concerns only
    /// that one connection (which is dropped); the listener is not modified
    /// and can keep accepting.
    pub async fn accept(&self, id: FaceId) -> Result<TlsWebSocketFace, FaceError> {
        let (tcp, peer) = self.inner.accept().await.map_err(FaceError::Io)?;
        // Capture addresses before the TCP stream moves into the TLS layer.
        let local = tcp.local_addr().map(|a| a.to_string()).unwrap_or_default();
        let remote = peer.to_string();

        let tls_stream = self.acceptor.accept(tcp).await.map_err(FaceError::Io)?;
        let ws = tokio_tungstenite::accept_async(tls_stream)
            .await
            .map_err(|e| FaceError::Io(std::io::Error::other(e)))?;
        let (writer, reader) = ws.split();
        Ok(TlsWebSocketFace {
            id,
            remote_addr: remote,
            local_addr: local,
            reader: Mutex::new(reader),
            writer: Mutex::new(writer),
        })
    }
}
274
#[cfg(feature = "websocket-tls")]
impl WebSocketFace {
    /// Bind a TLS WebSocket server listener on `addr`.
    ///
    /// The rustls server configuration is built from `tls` first; the TCP
    /// socket is bound only if that succeeds.  Connections are then taken via
    /// [`WebSocketListener::accept`], each yielding a [`TlsWebSocketFace`]
    /// that implements [`Face`].
    ///
    /// # Errors
    /// [`FaceError::Io`] if the TLS configuration cannot be built or `addr`
    /// cannot be bound.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// # use ndn_faces::net::websocket::{WebSocketFace, TlsConfig};
    /// # use ndn_transport::FaceId;
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let listener = WebSocketFace::listen_tls(
    ///     "0.0.0.0:9696".parse()?,
    ///     TlsConfig::SelfSigned,
    /// ).await?;
    ///
    /// let face = listener.accept(FaceId(0)).await?;
    /// // hand `face` off to the engine
    /// # Ok(())
    /// # }
    /// ```
    pub async fn listen_tls(
        addr: std::net::SocketAddr,
        tls: TlsConfig,
    ) -> Result<WebSocketListener, FaceError> {
        let server_config = std::sync::Arc::new(build_tls_server_config(tls).await?);
        let listener = tokio::net::TcpListener::bind(addr)
            .await
            .map_err(FaceError::Io)?;
        // Building the acceptor is pure, so doing it after `bind` is
        // equivalent to doing it before.
        Ok(WebSocketListener {
            inner: listener,
            acceptor: tokio_rustls::TlsAcceptor::from(server_config),
        })
    }
}
311
/// Build a `rustls::ServerConfig` from the given [`TlsConfig`].
///
/// Client authentication is disabled in both modes.
///
/// The `UserSupplied` branch reads the PEM files with blocking `std::fs::read`.
/// This happens once, when the listener is created, so the brief stall of the
/// executor thread is accepted.
///
/// # Errors
/// Every failure is reported as [`FaceError::Io`]:
/// - an unreadable file (the message names the path);
/// - malformed PEM;
/// - a certificate file containing no certificates;
/// - a key file containing no private key;
/// - a certificate/key pair that rustls rejects.
#[cfg(feature = "websocket-tls")]
async fn build_tls_server_config(
    tls: TlsConfig,
) -> Result<tokio_rustls::rustls::ServerConfig, FaceError> {
    use tokio_rustls::rustls::{self, pki_types};

    // Wrap any error as an opaque `FaceError::Io`.
    fn io_err<E>(e: E) -> FaceError
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        FaceError::Io(std::io::Error::other(e))
    }

    // Read a whole file, keeping the error kind but naming the path, since
    // bare OS errors ("No such file or directory") do not say which file.
    fn read_file(path: &std::path::Path) -> Result<Vec<u8>, FaceError> {
        std::fs::read(path).map_err(|e| {
            FaceError::Io(std::io::Error::new(
                e.kind(),
                format!("{}: {e}", path.display()),
            ))
        })
    }

    let (cert_chain, key) = match tls {
        TlsConfig::SelfSigned => {
            let generated = rcgen::generate_simple_self_signed(vec!["localhost".to_owned()])
                .map_err(io_err)?;
            let cert = pki_types::CertificateDer::from(generated.cert.der().to_vec());
            let key = pki_types::PrivateKeyDer::try_from(generated.key_pair.serialize_der())
                .map_err(io_err)?;
            (vec![cert], key)
        }
        TlsConfig::UserSupplied { cert_pem, key_pem } => {
            let cert_bytes = read_file(&cert_pem)?;
            let key_bytes = read_file(&key_pem)?;

            // `certs` already yields owned `CertificateDer<'static>` values,
            // so no per-certificate copy is needed.
            let certs = rustls_pemfile::certs(&mut cert_bytes.as_slice())
                .collect::<Result<Vec<_>, _>>()
                .map_err(io_err)?;
            // Reject an empty chain here with a clear message rather than
            // letting rustls fail later with an unrelated-sounding error.
            if certs.is_empty() {
                return Err(io_err(format!(
                    "no certificates in PEM file {}",
                    cert_pem.display()
                )));
            }
            let key = rustls_pemfile::private_key(&mut key_bytes.as_slice())
                .map_err(io_err)?
                .ok_or_else(|| {
                    io_err(format!("no private key in PEM file {}", key_pem.display()))
                })?;
            (certs, key)
        }
    };

    rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert_chain, key)
        .map_err(io_err)
}
352
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;
    use tokio_tungstenite::accept_async;

    /// Build a connected (client, server) pair of plain WebSocket faces on
    /// an ephemeral loopback port.
    async fn loopback_pair() -> (WebSocketFace, WebSocketFace) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let url = format!("ws://127.0.0.1:{}", addr.port());

        let accept_fut = async {
            let (stream, peer) = listener.accept().await.unwrap();
            let ws = accept_async(MaybeTlsStream::Plain(stream)).await.unwrap();
            WebSocketFace::from_stream(FaceId(1), ws, peer.to_string(), addr.to_string())
        };

        let connect_fut = WebSocketFace::connect(FaceId(0), &url);

        let (server, client) = tokio::join!(accept_fut, connect_fut);
        (client.unwrap(), server)
    }

    fn make_tlv(tag: u8, value: &[u8]) -> Bytes {
        use ndn_tlv::TlvWriter;
        let mut w = TlvWriter::new();
        w.write_tlv(tag as u64, value);
        w.finish()
    }

    /// What `recv` should yield for `pkt`: `send` LP-wraps, `recv` does not unwrap.
    fn expected_on_wire(pkt: &Bytes) -> Bytes {
        ndn_packet::lp::encode_lp_packet(pkt)
    }

    #[tokio::test]
    async fn send_recv_single_packet() {
        let (client, server) = loopback_pair().await;
        let pkt = make_tlv(0x05, b"hello");
        client.send(pkt.clone()).await.unwrap();
        assert_eq!(server.recv().await.unwrap(), expected_on_wire(&pkt));
    }

    #[tokio::test]
    async fn bidirectional_exchange() {
        let (client, server) = loopback_pair().await;
        client.send(make_tlv(0x05, b"interest")).await.unwrap();
        server.send(make_tlv(0x06, b"data")).await.unwrap();
        assert_eq!(
            server.recv().await.unwrap(),
            expected_on_wire(&make_tlv(0x05, b"interest"))
        );
        assert_eq!(
            client.recv().await.unwrap(),
            expected_on_wire(&make_tlv(0x06, b"data"))
        );
    }

    #[tokio::test]
    async fn concurrent_sends_arrive_intact() {
        use std::sync::Arc;
        let (client, server) = loopback_pair().await;
        let client = Arc::new(client);

        let handles: Vec<_> = (0u8..8)
            .map(|i| {
                let c = Arc::clone(&client);
                tokio::spawn(async move {
                    c.send(make_tlv(0x05, &[i])).await.unwrap();
                })
            })
            .collect();
        for h in handles {
            h.await.unwrap();
        }

        let mut received = Vec::with_capacity(8);
        for _ in 0u8..8 {
            received.push(server.recv().await.unwrap());
        }

        // Arrival order across tasks is unspecified, but every packet must
        // arrive exactly once and byte-for-byte intact.
        let mut expected: Vec<Bytes> = (0u8..8)
            .map(|i| expected_on_wire(&make_tlv(0x05, &[i])))
            .collect();
        received.sort();
        expected.sort();
        assert_eq!(received, expected);
    }

    #[tokio::test]
    async fn close_detection() {
        let (client, server) = loopback_pair().await;
        drop(client);
        // Server should detect the close.
        let result = server.recv().await;
        assert!(result.is_err());
    }
}