ndn_face_net/
websocket.rs

//! NDN face over WebSocket (binary frames).
//!
//! WebSocket provides its own message framing, so no `TlvCodec` is needed:
//! each WebSocket binary message carries exactly one NDN packet, wrapped in an
//! NDNLPv2 `LpPacket`. The wrapping is done on `send`; `recv` returns each
//! received binary message unchanged.
//!
//! Both client-initiated ([`WebSocketFace::connect`]) and server-accepted
//! ([`WebSocketFace::from_stream`]) connections are supported. Compatible with
//! NFD's WebSocket face.
9
10use bytes::Bytes;
11use futures::{SinkExt, StreamExt};
12use tokio::net::TcpStream;
13use tokio::sync::Mutex;
14use tokio_tungstenite::tungstenite::Message;
15use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
16use tracing::trace;
17
18use ndn_transport::{Face, FaceError, FaceId, FaceKind};
19
20type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
21
22/// NDN face over WebSocket with binary message framing.
23///
24/// The WebSocket stream is split into independent read and write halves, each
25/// behind its own `Mutex` — mirroring the `TcpFace` pattern.
26pub struct WebSocketFace {
27    id: FaceId,
28    remote_addr: String,
29    local_addr: String,
30    reader: Mutex<futures::stream::SplitStream<WsStream>>,
31    writer: Mutex<futures::stream::SplitSink<WsStream, Message>>,
32}
33
34impl WebSocketFace {
35    /// Connect to a WebSocket endpoint (client side).
36    pub async fn connect(
37        id: FaceId,
38        url: &str,
39    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
40        let (ws, _response) = tokio_tungstenite::connect_async(url).await?;
41
42        // Extract addresses from the underlying TCP stream before splitting.
43        let (remote_addr, local_addr) = match ws.get_ref() {
44            MaybeTlsStream::Plain(tcp) => (
45                tcp.peer_addr().map(|a| a.to_string()).unwrap_or_default(),
46                tcp.local_addr().map(|a| a.to_string()).unwrap_or_default(),
47            ),
48            _ => (url.to_string(), String::new()),
49        };
50
51        let (writer, reader) = ws.split();
52        Ok(Self {
53            id,
54            remote_addr,
55            local_addr,
56            reader: Mutex::new(reader),
57            writer: Mutex::new(writer),
58        })
59    }
60
61    /// Wrap an already-accepted WebSocket stream (server side).
62    pub fn from_stream(
63        id: FaceId,
64        ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
65        remote_addr: String,
66        local_addr: String,
67    ) -> Self {
68        let (writer, reader) = ws.split();
69        Self {
70            id,
71            remote_addr,
72            local_addr,
73            reader: Mutex::new(reader),
74            writer: Mutex::new(writer),
75        }
76    }
77
78    pub fn remote_addr(&self) -> &str {
79        &self.remote_addr
80    }
81    pub fn local_addr(&self) -> &str {
82        &self.local_addr
83    }
84}
85
86impl Face for WebSocketFace {
87    fn id(&self) -> FaceId {
88        self.id
89    }
90    fn kind(&self) -> FaceKind {
91        FaceKind::WebSocket
92    }
93
94    fn remote_uri(&self) -> Option<String> {
95        Some(self.remote_addr.clone())
96    }
97
98    fn local_uri(&self) -> Option<String> {
99        Some(self.local_addr.clone())
100    }
101
102    async fn recv(&self) -> Result<Bytes, FaceError> {
103        let mut reader = self.reader.lock().await;
104        loop {
105            let msg = reader
106                .next()
107                .await
108                .ok_or(FaceError::Closed)?
109                .map_err(|e| FaceError::Io(std::io::Error::other(e)))?;
110
111            match msg {
112                Message::Binary(data) => {
113                    trace!(face=%self.id, len=data.len(), "ws: recv binary");
114                    return Ok(data);
115                }
116                Message::Close(_) => return Err(FaceError::Closed),
117                // Skip text, ping, pong frames.
118                _ => continue,
119            }
120        }
121    }
122
123    async fn send(&self, pkt: Bytes) -> Result<(), FaceError> {
124        let wire = ndn_packet::lp::encode_lp_packet(&pkt);
125        trace!(face=%self.id, len=wire.len(), "ws: send binary");
126        let mut writer = self.writer.lock().await;
127        writer
128            .send(Message::Binary(wire.to_vec().into()))
129            .await
130            .map_err(|e| FaceError::Io(std::io::Error::other(e)))
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;
    use tokio_tungstenite::accept_async;

    /// Build a connected `(client, server)` face pair over a loopback TCP socket.
    async fn loopback_pair() -> (WebSocketFace, WebSocketFace) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let url = format!("ws://127.0.0.1:{}", addr.port());

        let accept_fut = async {
            let (stream, peer) = listener.accept().await.unwrap();
            let ws = accept_async(MaybeTlsStream::Plain(stream)).await.unwrap();
            WebSocketFace::from_stream(FaceId(1), ws, peer.to_string(), addr.to_string())
        };

        let connect_fut = WebSocketFace::connect(FaceId(0), &url);

        let (server, client) = tokio::join!(accept_fut, connect_fut);
        (client.unwrap(), server)
    }

    /// Encode a single TLV element with the given type and value.
    fn make_tlv(tag: u8, value: &[u8]) -> Bytes {
        use ndn_tlv::TlvWriter;
        let mut w = TlvWriter::new();
        w.write_tlv(tag as u64, value);
        w.finish()
    }

    /// The bytes `send` puts on the wire for `pkt` (its LpPacket encoding).
    fn expected_on_wire(pkt: &Bytes) -> Bytes {
        ndn_packet::lp::encode_lp_packet(pkt)
    }

    #[tokio::test]
    async fn send_recv_single_packet() {
        let (client, server) = loopback_pair().await;
        let pkt = make_tlv(0x05, b"hello");
        client.send(pkt.clone()).await.unwrap();
        assert_eq!(server.recv().await.unwrap(), expected_on_wire(&pkt));
    }

    #[tokio::test]
    async fn bidirectional_exchange() {
        let (client, server) = loopback_pair().await;
        client.send(make_tlv(0x05, b"interest")).await.unwrap();
        server.send(make_tlv(0x06, b"data")).await.unwrap();
        assert_eq!(
            server.recv().await.unwrap(),
            expected_on_wire(&make_tlv(0x05, b"interest"))
        );
        assert_eq!(
            client.recv().await.unwrap(),
            expected_on_wire(&make_tlv(0x06, b"data"))
        );
    }

    #[tokio::test]
    async fn concurrent_sends_arrive_intact() {
        use std::sync::Arc;
        let (client, server) = loopback_pair().await;
        let client = Arc::new(client);

        let handles: Vec<_> = (0u8..8)
            .map(|i| {
                let c = Arc::clone(&client);
                tokio::spawn(async move {
                    c.send(make_tlv(0x05, &[i])).await.unwrap();
                })
            })
            .collect();
        for h in handles {
            h.await.unwrap();
        }

        let mut received = Vec::with_capacity(8);
        for _ in 0u8..8 {
            received.push(server.recv().await.unwrap());
        }

        // Task scheduling makes arrival order nondeterministic, so compare the
        // sorted frames: every packet must arrive exactly once, byte-for-byte.
        let mut expected: Vec<Bytes> = (0u8..8)
            .map(|i| expected_on_wire(&make_tlv(0x05, &[i])))
            .collect();
        received.sort();
        expected.sort();
        assert_eq!(received, expected);
    }

    #[tokio::test]
    async fn close_detection() {
        let (client, server) = loopback_pair().await;
        drop(client);
        // Server should detect the close.
        let result = server.recv().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn graceful_close_reports_closed() {
        let (client, server) = loopback_pair().await;
        // Closing the sink sends a WebSocket Close frame, which `recv` must
        // surface as `FaceError::Closed` rather than as an I/O error.
        client.writer.lock().await.close().await.unwrap();
        assert!(matches!(server.recv().await, Err(FaceError::Closed)));
    }
}