ndn_faces/net/tcp.rs

1use std::net::SocketAddr;
2
3use tokio::net::TcpStream;
4use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
5
6use ndn_transport::{FaceId, FaceKind, StreamFace, TlvCodec};
7
8/// NDN face over TCP with TLV length-prefix framing.
9///
10/// Uses [`StreamFace`] with TCP read/write halves and [`TlvCodec`].
11/// LP-encoding is enabled for network transport.
12pub type TcpFace = StreamFace<OwnedReadHalf, OwnedWriteHalf, TlvCodec>;
13
14/// Wrap an accepted or connected `TcpStream` into a [`TcpFace`].
15pub fn tcp_face_from_stream(id: FaceId, stream: TcpStream) -> TcpFace {
16    let remote_addr = stream
17        .peer_addr()
18        .unwrap_or_else(|_| ([0, 0, 0, 0], 0).into());
19    let local_addr = stream
20        .local_addr()
21        .unwrap_or_else(|_| ([0, 0, 0, 0], 0).into());
22    let (r, w) = stream.into_split();
23    StreamFace::new(
24        id,
25        FaceKind::Tcp,
26        true,
27        Some(format!(
28            "tcp4://{}:{}",
29            remote_addr.ip(),
30            remote_addr.port()
31        )),
32        Some(format!("tcp4://{}:{}", local_addr.ip(), local_addr.port())),
33        r,
34        w,
35        TlvCodec,
36    )
37}
38
39/// Open a new TCP connection to `addr` and return a [`TcpFace`].
40pub async fn tcp_face_connect(id: FaceId, addr: SocketAddr) -> std::io::Result<TcpFace> {
41    let stream = TcpStream::connect(addr).await?;
42    Ok(tcp_face_from_stream(id, stream))
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use ndn_transport::{Face, FaceError};
    use tokio::net::TcpListener;

    /// Build a single TLV element with type `tag` and the given value.
    fn make_tlv(tag: u8, value: &[u8]) -> Bytes {
        use ndn_tlv::TlvWriter;
        let mut w = TlvWriter::new();
        w.write_tlv(tag as u64, value);
        w.finish()
    }

    /// The bytes the receiving face should yield for `pkt`: the packet wrapped
    /// by `encode_lp_packet`, since `TcpFace` sends LP-encoded packets.
    fn expected_on_wire(pkt: &Bytes) -> Bytes {
        ndn_packet::lp::encode_lp_packet(pkt)
    }

    /// A connected (client, server) face pair over the IPv4 loopback.
    async fn loopback_pair() -> (TcpFace, TcpFace) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_fut = tcp_face_connect(FaceId(0), addr);
        let accept_fut = listener.accept();
        let (client, accepted) = tokio::join!(connect_fut, accept_fut);
        let (accepted_stream, _) = accepted.unwrap();
        (
            client.unwrap(),
            tcp_face_from_stream(FaceId(1), accepted_stream),
        )
    }

    #[tokio::test]
    async fn send_recv_single_packet() {
        let (client, server) = loopback_pair().await;
        let pkt = make_tlv(0x05, b"hello");
        client.send(pkt.clone()).await.unwrap();
        assert_eq!(server.recv().await.unwrap(), expected_on_wire(&pkt));
    }

    #[tokio::test]
    async fn framing_large_packet() {
        let (client, server) = loopback_pair().await;
        let payload = vec![0xABu8; 1000];
        let pkt = make_tlv(0x06, &payload);
        client.send(pkt.clone()).await.unwrap();
        assert_eq!(server.recv().await.unwrap(), expected_on_wire(&pkt));
    }

    #[tokio::test]
    async fn framing_multiple_sequential() {
        let (client, server) = loopback_pair().await;
        let pkts: Vec<Bytes> = (0u8..5).map(|i| make_tlv(0x05, &[i])).collect();
        for pkt in &pkts {
            client.send(pkt.clone()).await.unwrap();
        }
        for expected in &pkts {
            assert_eq!(server.recv().await.unwrap(), expected_on_wire(expected));
        }
    }

    #[tokio::test]
    async fn recv_eof_returns_closed() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_fut = TcpStream::connect(addr);
        let accept_fut = listener.accept();
        let (stream, accepted) = tokio::join!(connect_fut, accept_fut);
        let (accepted_stream, _) = accepted.unwrap();
        let server = tcp_face_from_stream(FaceId(1), accepted_stream);
        // Closing the raw client socket must surface as a clean EOF on the face.
        drop(stream.unwrap());
        assert!(matches!(server.recv().await, Err(FaceError::Closed)));
    }

    #[tokio::test]
    async fn bidirectional_exchange() {
        let (client, server) = loopback_pair().await;
        client.send(make_tlv(0x05, b"interest")).await.unwrap();
        server.send(make_tlv(0x06, b"data")).await.unwrap();
        assert_eq!(
            server.recv().await.unwrap(),
            expected_on_wire(&make_tlv(0x05, b"interest"))
        );
        assert_eq!(
            client.recv().await.unwrap(),
            expected_on_wire(&make_tlv(0x06, b"data"))
        );
    }

    #[tokio::test]
    async fn concurrent_sends_arrive_intact() {
        use std::sync::Arc;
        const N: u8 = 8;
        let (client, server) = loopback_pair().await;
        let client = Arc::new(client);

        let handles: Vec<_> = (0..N)
            .map(|i| {
                let c = Arc::clone(&client);
                tokio::spawn(async move {
                    c.send(make_tlv(0x05, &[i])).await.unwrap();
                })
            })
            .collect();
        for h in handles {
            h.await.unwrap();
        }

        let mut received = Vec::with_capacity(N as usize);
        for _ in 0..N {
            received.push(server.recv().await.unwrap());
        }

        // The order in which the spawned tasks write is unspecified, so compare
        // sorted collections. Every packet must arrive exactly once and
        // byte-for-byte intact; interleaved writes would corrupt the frames and
        // break the equality.
        let mut expected: Vec<Bytes> = (0..N)
            .map(|i| expected_on_wire(&make_tlv(0x05, &[i])))
            .collect();
        received.sort();
        expected.sort();
        assert_eq!(received, expected);
    }
}