1use std::net::SocketAddr;
2
3use tokio::net::TcpStream;
4use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
5
6use ndn_transport::{FaceId, FaceKind, StreamFace, TlvCodec};
7
/// A TCP face: a [`StreamFace`] over the owned read and write halves of a
/// Tokio TCP stream, parameterised with the [`TlvCodec`] codec.
pub type TcpFace = StreamFace<OwnedReadHalf, OwnedWriteHalf, TlvCodec>;
13
14pub fn tcp_face_from_stream(id: FaceId, stream: TcpStream) -> TcpFace {
16 let remote_addr = stream
17 .peer_addr()
18 .unwrap_or_else(|_| ([0, 0, 0, 0], 0).into());
19 let local_addr = stream
20 .local_addr()
21 .unwrap_or_else(|_| ([0, 0, 0, 0], 0).into());
22 let (r, w) = stream.into_split();
23 StreamFace::new(
24 id,
25 FaceKind::Tcp,
26 true,
27 Some(format!(
28 "tcp4://{}:{}",
29 remote_addr.ip(),
30 remote_addr.port()
31 )),
32 Some(format!("tcp4://{}:{}", local_addr.ip(), local_addr.port())),
33 r,
34 w,
35 TlvCodec,
36 )
37}
38
39pub async fn tcp_face_connect(id: FaceId, addr: SocketAddr) -> std::io::Result<TcpFace> {
41 let stream = TcpStream::connect(addr).await?;
42 Ok(tcp_face_from_stream(id, stream))
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use ndn_transport::{Face, FaceError};
    use tokio::net::TcpListener;

    /// Builds a single TLV element with the given type `tag` and `value`.
    fn make_tlv(tag: u8, value: &[u8]) -> Bytes {
        use ndn_tlv::TlvWriter;
        let mut w = TlvWriter::new();
        w.write_tlv(u64::from(tag), value);
        w.finish()
    }

    /// Returns the bytes the receiving face is expected to yield for `pkt`,
    /// i.e. `pkt` passed through `encode_lp_packet`.
    fn expected_on_wire(pkt: &Bytes) -> Bytes {
        ndn_packet::lp::encode_lp_packet(pkt)
    }

    /// Creates a connected `(client, server)` face pair over a loopback
    /// listener bound to an ephemeral port.
    async fn loopback_pair() -> (TcpFace, TcpFace) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_fut = tcp_face_connect(FaceId(0), addr);
        let accept_fut = listener.accept();
        // Drive connect and accept together so neither side waits on the other.
        let (client, accepted) = tokio::join!(connect_fut, accept_fut);
        let (accepted_stream, _) = accepted.unwrap();
        (
            client.unwrap(),
            tcp_face_from_stream(FaceId(1), accepted_stream),
        )
    }

    #[tokio::test]
    async fn send_recv_single_packet() {
        let (client, server) = loopback_pair().await;
        let pkt = make_tlv(0x05, b"hello");
        client.send(pkt.clone()).await.unwrap();
        assert_eq!(server.recv().await.unwrap(), expected_on_wire(&pkt));
    }

    #[tokio::test]
    async fn framing_large_packet() {
        let (client, server) = loopback_pair().await;
        let payload = vec![0xABu8; 1000];
        let pkt = make_tlv(0x06, &payload);
        client.send(pkt.clone()).await.unwrap();
        assert_eq!(server.recv().await.unwrap(), expected_on_wire(&pkt));
    }

    #[tokio::test]
    async fn framing_multiple_sequential() {
        let (client, server) = loopback_pair().await;
        let pkts: Vec<Bytes> = (0u8..5).map(|i| make_tlv(0x05, &[i])).collect();
        for pkt in &pkts {
            client.send(pkt.clone()).await.unwrap();
        }
        for expected in &pkts {
            assert_eq!(server.recv().await.unwrap(), expected_on_wire(expected));
        }
    }

    #[tokio::test]
    async fn recv_eof_returns_closed() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_fut = TcpStream::connect(addr);
        let accept_fut = listener.accept();
        let (stream, accepted) = tokio::join!(connect_fut, accept_fut);
        let (accepted_stream, _) = accepted.unwrap();
        let server = tcp_face_from_stream(FaceId(1), accepted_stream);
        // Dropping the client stream closes the connection from the peer side.
        drop(stream.unwrap());
        assert!(matches!(server.recv().await, Err(FaceError::Closed)));
    }

    #[tokio::test]
    async fn bidirectional_exchange() {
        let (client, server) = loopback_pair().await;
        client.send(make_tlv(0x05, b"interest")).await.unwrap();
        server.send(make_tlv(0x06, b"data")).await.unwrap();
        assert_eq!(
            server.recv().await.unwrap(),
            expected_on_wire(&make_tlv(0x05, b"interest"))
        );
        assert_eq!(
            client.recv().await.unwrap(),
            expected_on_wire(&make_tlv(0x06, b"data"))
        );
    }

    #[tokio::test]
    async fn concurrent_sends_arrive_intact() {
        use std::sync::Arc;
        let (client, server) = loopback_pair().await;
        let client = Arc::new(client);
        let pkts: Vec<Bytes> = (0u8..8).map(|i| make_tlv(0x05, &[i])).collect();

        let handles: Vec<_> = pkts
            .iter()
            .cloned()
            .map(|pkt| {
                let c = Arc::clone(&client);
                tokio::spawn(async move {
                    c.send(pkt).await.unwrap();
                })
            })
            .collect();
        for h in handles {
            h.await.unwrap();
        }

        let mut received = Vec::with_capacity(pkts.len());
        for _ in 0..pkts.len() {
            received.push(server.recv().await.unwrap());
        }

        // Concurrent senders may complete in any order, so compare the frames
        // as a multiset: every expected frame must arrive exactly once and
        // byte-for-byte unchanged (no interleaving between writers).
        let mut expected: Vec<Bytes> = pkts.iter().map(expected_on_wire).collect();
        received.sort();
        expected.sort();
        assert_eq!(received, expected);
    }
}