1use bytes::{Buf, BufMut, Bytes, BytesMut};
17use tokio_util::codec::{Decoder, Encoder};
18
/// Frame-length limit, in bytes, that `CobsCodec::new` installs when the caller
/// does not pick one via `CobsCodec::with_max_frame_len`.
const DEFAULT_MAX_FRAME_LEN: usize = 8800;
21
/// Framing codec that delimits COBS-encoded frames with a `0x00` byte.
///
/// The codec carries no per-stream state besides its configured size limit, so
/// cloning it yields an independent, equivalent codec.
#[derive(Debug, Clone)]
pub struct CobsCodec {
    /// Largest decoded payload, in bytes, that the decoder accepts. Twice this
    /// value is also the threshold past which an undelimited receive buffer is
    /// discarded.
    max_frame_len: usize,
}
27
28impl CobsCodec {
29 pub fn new() -> Self {
30 Self {
31 max_frame_len: DEFAULT_MAX_FRAME_LEN,
32 }
33 }
34
35 pub fn with_max_frame_len(max_frame_len: usize) -> Self {
36 Self { max_frame_len }
37 }
38}
39
40impl Default for CobsCodec {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46fn cobs_encode(src: &[u8], dst: &mut BytesMut) {
50 let max_overhead = (src.len() / 254) + 2;
52 dst.reserve(src.len() + max_overhead);
53
54 let mut code_idx = dst.len();
55 dst.put_u8(0); let mut code: u8 = 1;
57
58 for &byte in src {
59 if byte == 0x00 {
60 dst[code_idx] = code;
62 code_idx = dst.len();
63 dst.put_u8(0); code = 1;
65 } else {
66 dst.put_u8(byte);
67 code += 1;
68 if code == 0xFF {
69 dst[code_idx] = code;
71 code_idx = dst.len();
72 dst.put_u8(0); code = 1;
74 }
75 }
76 }
77 dst[code_idx] = code;
79}
80
81fn cobs_decode(src: &[u8], dst: &mut BytesMut) -> Result<(), std::io::Error> {
83 dst.reserve(src.len());
84 let mut i = 0;
85 while i < src.len() {
86 let code = src[i] as usize;
87 i += 1;
88 if code == 0 {
89 return Err(std::io::Error::new(
90 std::io::ErrorKind::InvalidData,
91 "unexpected zero in COBS data",
92 ));
93 }
94 let run_len = code - 1;
95 if i + run_len > src.len() {
96 return Err(std::io::Error::new(
97 std::io::ErrorKind::InvalidData,
98 "COBS run exceeds input",
99 ));
100 }
101 dst.extend_from_slice(&src[i..i + run_len]);
102 i += run_len;
103 if code < 0xFF && i < src.len() {
106 dst.put_u8(0x00);
107 }
108 }
109 Ok(())
110}
111
112impl Decoder for CobsCodec {
115 type Item = Bytes;
116 type Error = std::io::Error;
117
118 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Bytes>, std::io::Error> {
119 let delim_pos = buf.iter().position(|&b| b == 0x00);
121 let delim_pos = match delim_pos {
122 Some(pos) => pos,
123 None => {
124 if buf.len() > self.max_frame_len * 2 {
127 buf.clear();
128 }
129 return Ok(None);
130 }
131 };
132
133 let encoded = buf.split_to(delim_pos);
135 buf.advance(1); if encoded.is_empty() {
139 return Ok(None);
140 }
141
142 let mut decoded = BytesMut::new();
144 match cobs_decode(&encoded, &mut decoded) {
145 Ok(()) => {
146 if decoded.len() > self.max_frame_len {
147 return Err(std::io::Error::new(
148 std::io::ErrorKind::InvalidData,
149 "COBS frame exceeds max length",
150 ));
151 }
152 Ok(Some(decoded.freeze()))
153 }
154 Err(_) => {
155 Ok(None)
157 }
158 }
159 }
160}
161
162impl Encoder<Bytes> for CobsCodec {
165 type Error = std::io::Error;
166
167 fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> Result<(), std::io::Error> {
168 cobs_encode(&item, dst);
169 dst.put_u8(0x00); Ok(())
171 }
172}
173
174#[cfg(test)]
175mod tests {
176 use super::*;
177
178 fn roundtrip(data: &[u8]) -> Vec<u8> {
179 let mut encoded = BytesMut::new();
180 cobs_encode(data, &mut encoded);
181 encoded.put_u8(0x00);
182
183 assert!(
185 !encoded[..encoded.len() - 1].contains(&0x00),
186 "encoded payload must not contain 0x00"
187 );
188
189 let mut codec = CobsCodec::new();
190 let decoded = codec.decode(&mut encoded).unwrap().unwrap();
191 decoded.to_vec()
192 }
193
194 #[test]
195 fn empty_payload() {
196 assert_eq!(roundtrip(&[]), Vec::<u8>::new());
197 }
198
199 #[test]
200 fn single_byte() {
201 assert_eq!(roundtrip(&[0x42]), vec![0x42]);
202 }
203
204 #[test]
205 fn single_zero() {
206 assert_eq!(roundtrip(&[0x00]), vec![0x00]);
207 }
208
209 #[test]
210 fn multiple_zeros() {
211 let data = vec![0x00; 10];
212 assert_eq!(roundtrip(&data), data);
213 }
214
215 #[test]
216 fn no_zeros() {
217 let data: Vec<u8> = (1..=255).collect();
218 assert_eq!(roundtrip(&data), data);
219 }
220
221 #[test]
222 fn boundary_254_bytes() {
223 let data: Vec<u8> = (1..=254).collect();
224 assert_eq!(roundtrip(&data), data);
225 }
226
227 #[test]
228 fn boundary_255_bytes() {
229 let mut data: Vec<u8> = (1..=254).collect();
230 data.push(0x01);
231 assert_eq!(roundtrip(&data), data);
232 }
233
234 #[test]
235 fn large_payload() {
236 let data: Vec<u8> = (0..8000).map(|i| (i % 256) as u8).collect();
237 assert_eq!(roundtrip(&data), data);
238 }
239
240 #[test]
241 fn zeros_and_data_interleaved() {
242 let data = vec![0x01, 0x00, 0x02, 0x00, 0x03];
243 assert_eq!(roundtrip(&data), data);
244 }
245
246 #[test]
247 fn codec_multiple_frames() {
248 let mut codec = CobsCodec::new();
249 let mut buf = BytesMut::new();
250
251 let frame1 = Bytes::from_static(&[0x01, 0x02, 0x03]);
253 let frame2 = Bytes::from_static(&[0xAA, 0x00, 0xBB]);
254 codec.encode(frame1.clone(), &mut buf).unwrap();
255 codec.encode(frame2.clone(), &mut buf).unwrap();
256
257 let d1 = codec.decode(&mut buf).unwrap().unwrap();
259 let d2 = codec.decode(&mut buf).unwrap().unwrap();
260 assert_eq!(d1, frame1);
261 assert_eq!(d2, frame2);
262 }
263
264 #[test]
265 fn codec_resync_after_garbage() {
266 let mut codec = CobsCodec::new();
267 let mut buf = BytesMut::new();
268
269 buf.extend_from_slice(&[0xFF, 0xFE, 0xFD]);
272 buf.put_u8(0x00); let frame = Bytes::from_static(&[0x42]);
275 codec.encode(frame.clone(), &mut buf).unwrap();
276
277 let result1 = codec.decode(&mut buf).unwrap();
279 assert_eq!(result1, None);
280 let result2 = codec.decode(&mut buf).unwrap();
282 assert_eq!(result2, Some(frame));
283 }
284
285 #[test]
286 fn consecutive_delimiters_skipped() {
287 let mut codec = CobsCodec::new();
288 let mut buf = BytesMut::new();
289
290 buf.put_u8(0x00);
292 buf.put_u8(0x00);
293 let frame = Bytes::from_static(&[0x01]);
295 codec.encode(frame.clone(), &mut buf).unwrap();
296
297 assert_eq!(codec.decode(&mut buf).unwrap(), None);
299 assert_eq!(codec.decode(&mut buf).unwrap(), None);
300 assert_eq!(codec.decode(&mut buf).unwrap(), Some(frame));
302 }
303}