1use ndn_packet::Name;
9use ndn_security::Certificate;
10
11use crate::{
12 ca::deserialize_cert,
13 ecdh::{EcdhKeypair, SessionKey},
14 error::CertError,
15 tlv::{
16 ChallengeRequestTlv, ChallengeResponseTlv, NewRequestTlv, NewResponseTlv, STATUS_FAILURE,
17 STATUS_PENDING, STATUS_SUCCESS,
18 },
19};
20
/// Progress of an [`EnrollmentSession`] through the NEW / CHALLENGE exchange.
#[derive(Debug, Clone, PartialEq)]
enum SessionState {
    /// Freshly created; no NEW response has been processed yet.
    Init,
    /// A NEW response was accepted and the CA's offered challenges are known.
    AwaitingChallenge {
        /// Lower-case hex rendering of the CA-assigned request ID.
        request_id: String,
        /// Challenge type names offered by the CA (`handle_new_response` rejects an
        /// empty list, so this is never empty).
        challenges: Vec<String>,
    },
    /// The CA reported the challenge as pending, i.e. another round is needed.
    Challenging {
        /// Lower-case hex rendering of the CA-assigned request ID.
        request_id: String,
        /// Selected challenge type.
        /// NOTE(review): nothing in this file ever stores a non-empty value here —
        /// `challenge_request_body` takes `&self` and does not record the selection —
        /// so as written this is always an empty string.
        challenge_type: String,
        /// Status text sent by the CA, or "Challenge in progress" if it sent none.
        status_message: String,
        /// Remaining attempts as reported by the CA (0 if the response omitted it).
        remaining_tries: u8,
        /// Remaining time in seconds as reported by the CA (0 if the response omitted it).
        remaining_time_secs: u32,
    },
    /// A certificate was received, decoded and stored.
    Complete,
}
39
/// Client side of a certificate enrollment exchange with a CA.
///
/// The methods are meant to be driven in this order (the error messages in the
/// `impl` enforce the first two steps):
/// 1. [`new_request_body`](Self::new_request_body) — body of the NEW request;
/// 2. [`handle_new_response`](Self::handle_new_response) — the CA's reply;
/// 3. [`challenge_request_body`](Self::challenge_request_body) — body of a CHALLENGE request;
/// 4. [`handle_challenge_response`](Self::handle_challenge_response) — the CA's reply;
///    repeat 3–4 while [`needs_another_round`](Self::needs_another_round) is true;
/// 5. once [`is_complete`](Self::is_complete), read [`certificate`](Self::certificate).
pub struct EnrollmentSession {
    /// Identity name embedded (via `to_string()`) in the certificate request.
    name: Name,
    /// Public key bytes embedded verbatim in the certificate request.
    public_key: Vec<u8>,
    /// Requested lifetime in seconds; converted to milliseconds and added to the
    /// current time to form the request's `not_after`.
    validity_secs: u64,
    /// Current protocol state.
    state: SessionState,
    /// Issued certificate, set when the CA reports success.
    certificate: Option<Certificate>,
    /// Ephemeral ECDH keypair created by `new_request_body` and taken (consumed)
    /// by `handle_new_response`.
    ecdh_keypair: Option<EcdhKeypair>,
    /// Key derived from the ECDH exchange; encrypts challenge parameters.
    session_key: Option<SessionKey>,
    /// Raw request ID from the NEW response, echoed in every challenge request.
    request_id_bytes: Option<[u8; 8]>,
}
63
64impl EnrollmentSession {
65 pub fn new(name: Name, public_key: Vec<u8>, validity_secs: u64) -> Self {
66 Self {
67 name,
68 public_key,
69 validity_secs,
70 state: SessionState::Init,
71 certificate: None,
72 ecdh_keypair: None,
73 session_key: None,
74 request_id_bytes: None,
75 }
76 }
77
78 pub fn new_request_body(&mut self) -> Result<Vec<u8>, CertError> {
83 let kp = EcdhKeypair::generate();
84 let ecdh_pub_bytes = kp.public_key_bytes();
85 self.ecdh_keypair = Some(kp);
86
87 let now_ms = now_ms();
88 let cert_request = encode_cert_request_bytes(
89 now_ms,
90 now_ms + self.validity_secs * 1000,
91 &self.public_key,
92 &self.name.to_string(),
93 );
94
95 let tlv = NewRequestTlv {
96 ecdh_pub: bytes::Bytes::from(ecdh_pub_bytes),
97 cert_request: bytes::Bytes::from(cert_request),
98 };
99 Ok(tlv.encode().to_vec())
100 }
101
102 pub fn handle_new_response(&mut self, body: &[u8]) -> Result<(), CertError> {
107 let resp = NewResponseTlv::decode(bytes::Bytes::copy_from_slice(body))?;
108
109 if resp.challenges.is_empty() {
110 return Err(CertError::InvalidRequest(
111 "no challenges offered".to_string(),
112 ));
113 }
114
115 let kp = self.ecdh_keypair.take().ok_or_else(|| {
116 CertError::InvalidRequest("no ECDH keypair — call new_request_body first".into())
117 })?;
118
119 let session_key = kp.derive_session_key(&resp.ecdh_pub, &resp.salt, &resp.request_id)?;
120
121 let request_id_hex: String = resp.request_id.iter().map(|b| format!("{b:02x}")).collect();
122
123 self.session_key = Some(session_key);
124 self.request_id_bytes = Some(resp.request_id);
125 self.state = SessionState::AwaitingChallenge {
126 request_id: request_id_hex,
127 challenges: resp.challenges,
128 };
129 Ok(())
130 }
131
132 pub fn request_id(&self) -> Option<&str> {
134 match &self.state {
135 SessionState::AwaitingChallenge { request_id, .. }
136 | SessionState::Challenging { request_id, .. } => Some(request_id),
137 _ => None,
138 }
139 }
140
141 pub fn offered_challenges(&self) -> &[String] {
143 match &self.state {
144 SessionState::AwaitingChallenge { challenges, .. } => challenges,
145 _ => &[],
146 }
147 }
148
149 pub fn challenge_status_message(&self) -> Option<&str> {
151 match &self.state {
152 SessionState::Challenging { status_message, .. } => Some(status_message),
153 _ => None,
154 }
155 }
156
157 pub fn remaining_tries(&self) -> Option<u8> {
159 match &self.state {
160 SessionState::Challenging {
161 remaining_tries, ..
162 } => Some(*remaining_tries),
163 _ => None,
164 }
165 }
166
167 pub fn challenge_request_body(
171 &self,
172 challenge_type: &str,
173 parameters: serde_json::Map<String, serde_json::Value>,
174 ) -> Result<Vec<u8>, CertError> {
175 let request_id_bytes = self
176 .request_id_bytes
177 .ok_or_else(|| CertError::InvalidRequest("not in challenge state".to_string()))?;
178
179 let session_key = self.session_key.as_ref().ok_or_else(|| {
180 CertError::InvalidRequest("no session key — call handle_new_response first".into())
181 })?;
182
183 let params_json = serde_json::to_vec(¶meters)?;
184 let (iv, encrypted_payload, auth_tag) =
185 session_key.encrypt(¶ms_json, &request_id_bytes)?;
186
187 let tlv = ChallengeRequestTlv {
188 request_id: request_id_bytes,
189 selected_challenge: challenge_type.to_string(),
190 iv,
191 encrypted_payload,
192 auth_tag,
193 };
194 Ok(tlv.encode().to_vec())
195 }
196
197 pub fn handle_challenge_response(&mut self, body: &[u8]) -> Result<(), CertError> {
203 let resp = ChallengeResponseTlv::decode(bytes::Bytes::copy_from_slice(body))?;
204 match resp.status {
205 STATUS_FAILURE => {
206 let reason = resp
207 .error_info
208 .unwrap_or_else(|| "challenge denied".to_string());
209 Err(CertError::ChallengeFailed(reason))
210 }
211 STATUS_PENDING => {
212 let request_id = self.request_id().unwrap_or_default().to_string();
213 let challenge_type = match &self.state {
214 SessionState::Challenging { challenge_type, .. } => challenge_type.clone(),
215 _ => String::new(),
216 };
217 self.state = SessionState::Challenging {
218 request_id,
219 challenge_type,
220 status_message: resp
221 .challenge_status
222 .unwrap_or_else(|| "Challenge in progress".to_string()),
223 remaining_tries: resp.remaining_tries.unwrap_or(0),
224 remaining_time_secs: resp.remaining_time_secs.unwrap_or(0),
225 };
226 Ok(())
227 }
228 STATUS_SUCCESS => {
229 let cert_bytes = resp.encrypted_payload.ok_or_else(|| {
230 CertError::InvalidRequest("approved but no certificate returned".to_string())
231 })?;
232 let cert = deserialize_cert(&cert_bytes).ok_or_else(|| {
233 CertError::InvalidRequest("could not decode certificate".to_string())
234 })?;
235 self.certificate = Some(cert);
236 self.state = SessionState::Complete;
237 Ok(())
238 }
239 other => Err(CertError::InvalidRequest(format!(
240 "unexpected challenge response status: {other}"
241 ))),
242 }
243 }
244
245 pub fn is_complete(&self) -> bool {
247 self.state == SessionState::Complete
248 }
249
250 pub fn needs_another_round(&self) -> bool {
252 matches!(self.state, SessionState::Challenging { .. })
253 }
254
255 pub fn certificate(&self) -> Option<&Certificate> {
257 self.certificate.as_ref()
258 }
259
260 pub fn into_certificate(self) -> Option<Certificate> {
262 self.certificate
263 }
264}
265
/// Serialises a certificate request into a compact binary layout.
///
/// All integers are big-endian:
///
/// | size | field                          |
/// |------|--------------------------------|
/// | 8    | `not_before`                   |
/// | 8    | `not_after`                    |
/// | 4    | public-key length `k`          |
/// | k    | public key                     |
/// | 4    | name length `n` (UTF-8 bytes)  |
/// | n    | name                           |
///
/// Callers in this module pass timestamps in milliseconds since the Unix epoch.
/// Length prefixes are produced with `as u32`, so an input longer than `u32::MAX`
/// bytes would get a truncated prefix.
fn encode_cert_request_bytes(
    not_before: u64,
    not_after: u64,
    public_key: &[u8],
    name: &str,
) -> Vec<u8> {
    /// Appends `field` preceded by its byte length as a big-endian `u32`.
    fn push_len_prefixed(buf: &mut Vec<u8>, field: &[u8]) {
        buf.extend_from_slice(&(field.len() as u32).to_be_bytes());
        buf.extend_from_slice(field);
    }

    let name = name.as_bytes();
    let mut buf = Vec::with_capacity(8 + 8 + 4 + public_key.len() + 4 + name.len());
    for stamp in [not_before, not_after] {
        buf.extend_from_slice(&stamp.to_be_bytes());
    }
    push_len_prefixed(&mut buf, public_key);
    push_len_prefixed(&mut buf, name);
    buf
}
285
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A clock set before the epoch yields `0` instead of an error.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH);
    since_epoch.map(|elapsed| elapsed.as_millis() as u64).unwrap_or(0)
}
293
#[cfg(test)]
mod tests {
    use base64::Engine as _;

    use super::*;

    /// Reads a big-endian `u64` from `buf[at..at + 8]`.
    fn be_u64(buf: &[u8], at: usize) -> u64 {
        u64::from_be_bytes(buf[at..at + 8].try_into().unwrap())
    }

    /// Reads a big-endian `u32` length prefix from `buf[at..at + 4]`.
    fn be_len(buf: &[u8], at: usize) -> usize {
        u32::from_be_bytes(buf[at..at + 4].try_into().unwrap()) as usize
    }

    #[test]
    fn new_request_body_is_valid_tlv() {
        let name: Name = "/com/acme/alice/KEY/v=0".parse().unwrap();
        let pubkey = vec![0x42u8; 32];
        let validity_secs = 86_400u64;
        let mut session = EnrollmentSession::new(name, pubkey.clone(), validity_secs);
        let body = session.new_request_body().unwrap();
        let req = NewRequestTlv::decode(bytes::Bytes::from(body)).unwrap();

        // 65 bytes with a 0x04 prefix: uncompressed EC point encoding.
        assert_eq!(req.ecdh_pub.len(), 65);
        assert_eq!(req.ecdh_pub[0], 0x04);

        // Layout: not_before u64 | not_after u64 | pk_len u32 | pk | name_len u32 | name.
        let cr = &req.cert_request;
        let not_before = be_u64(cr, 0);
        let not_after = be_u64(cr, 8);
        assert_eq!(not_after - not_before, validity_secs * 1000);

        let pk_len = be_len(cr, 16);
        assert_eq!(pk_len, pubkey.len());
        assert_eq!(&cr[20..20 + pk_len], pubkey.as_slice());

        let name_at = 20 + pk_len;
        let name_len = be_len(cr, name_at);
        let name_str = std::str::from_utf8(&cr[name_at + 4..name_at + 4 + name_len]).unwrap();
        assert_eq!(name_str, "/com/acme/alice/KEY/v=0");
        // Nothing trails the name.
        assert_eq!(cr.len(), name_at + 4 + name_len);
    }

    #[test]
    fn encode_decode_cert_request_bytes_roundtrip() {
        let name = "/com/acme/bob/KEY/v=1";
        let pubkey = vec![0xAAu8; 32];
        let not_before = 1_700_000_000_000u64;
        let not_after = 1_700_086_400_000u64;
        let encoded = encode_cert_request_bytes(not_before, not_after, &pubkey, name);
        let decoded = crate::ca::decode_cert_request_bytes_pub(&encoded).unwrap();
        assert_eq!(decoded.name, name);
        assert_eq!(decoded.not_before, not_before);
        assert_eq!(decoded.not_after, not_after);
        // The CA-side decoder exposes the key as unpadded URL-safe base64.
        let raw = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(&decoded.public_key)
            .unwrap();
        assert_eq!(raw, pubkey);
    }
}