1use aes_gcm::{
10 Aes128Gcm,
11 aead::{Aead, AeadCore, KeyInit, OsRng, Payload},
12};
13use bytes::Bytes;
14use hkdf::Hkdf;
15use p256::{
16 EncodedPoint, NistP256, PublicKey, ecdh::EphemeralSecret,
17 elliptic_curve::sec1::FromEncodedPoint,
18};
19use sha2::Sha256;
20
21use crate::error::CertError;
22
/// An ephemeral P-256 ECDH keypair.
///
/// The public half is derived on demand by [`EcdhKeypair::public_key_bytes`];
/// [`EcdhKeypair::derive_session_key`] takes the keypair by value, so each keypair
/// yields at most one session key.
pub struct EcdhKeypair {
    /// Ephemeral private scalar; the public key is computed from it when requested.
    secret: EphemeralSecret,
}
31
32impl EcdhKeypair {
33 pub fn generate() -> Self {
35 Self {
36 secret: EphemeralSecret::random(&mut OsRng),
37 }
38 }
39
40 pub fn public_key_bytes(&self) -> Vec<u8> {
43 let pub_key: PublicKey = (&self.secret).into();
44 EncodedPoint::from(&pub_key).as_bytes().to_vec()
45 }
46
47 pub fn random_salt() -> [u8; 32] {
49 use ring::rand::{SecureRandom, SystemRandom};
50 let rng = SystemRandom::new();
51 let mut salt = [0u8; 32];
52 rng.fill(&mut salt).unwrap_or(());
53 salt
54 }
55
56 pub fn derive_session_key(
63 self,
64 peer_pub_bytes: &[u8],
65 salt: &[u8; 32],
66 request_id: &[u8; 8],
67 ) -> Result<SessionKey, CertError> {
68 let peer_point = EncodedPoint::from_bytes(peer_pub_bytes)
69 .map_err(|_| CertError::InvalidRequest("invalid peer ECDH public key".into()))?;
70 let peer_pub = Option::<PublicKey>::from(
71 <PublicKey as FromEncodedPoint<NistP256>>::from_encoded_point(&peer_point),
72 )
73 .ok_or_else(|| CertError::InvalidRequest("invalid P-256 point".into()))?;
74
75 let shared = self.secret.diffie_hellman(&peer_pub);
77
78 let hk = Hkdf::<Sha256>::new(Some(salt), shared.raw_secret_bytes());
80 let mut aes_key = [0u8; 16];
81 hk.expand(request_id, &mut aes_key)
82 .map_err(|_| CertError::InvalidRequest("HKDF expand failed".into()))?;
83
84 Ok(SessionKey { key: aes_key })
85 }
86}
87
/// A 128-bit AES-GCM key produced by [`EcdhKeypair::derive_session_key`].
///
/// Cloning copies the raw key bytes.
// NOTE(review): no zeroize-on-drop is visible for this key material; consider it if
// keys may outlive the request.
#[derive(Clone)]
pub struct SessionKey {
    /// Raw AES-128 key bytes (crate-visible; the unit tests compare it directly).
    pub(crate) key: [u8; 16],
}
95
96impl SessionKey {
97 pub fn encrypt(
102 &self,
103 plaintext: &[u8],
104 aad: &[u8],
105 ) -> Result<([u8; 12], Bytes, [u8; 16]), CertError> {
106 let cipher = Aes128Gcm::new_from_slice(&self.key)
107 .map_err(|_| CertError::InvalidRequest("AES key init failed".into()))?;
108
109 let nonce = Aes128Gcm::generate_nonce(&mut OsRng);
110 let nonce_arr: [u8; 12] = nonce.into();
111
112 let ciphertext_with_tag = cipher
113 .encrypt(
114 &nonce,
115 Payload {
116 msg: plaintext,
117 aad,
118 },
119 )
120 .map_err(|_| CertError::InvalidRequest("AES-GCM encryption failed".into()))?;
121
122 let split_at = ciphertext_with_tag.len() - 16;
124 let (ct, tag) = ciphertext_with_tag.split_at(split_at);
125 let mut tag_arr = [0u8; 16];
126 tag_arr.copy_from_slice(tag);
127
128 Ok((nonce_arr, Bytes::copy_from_slice(ct), tag_arr))
129 }
130
131 pub fn decrypt(
135 &self,
136 iv: &[u8; 12],
137 ciphertext: &[u8],
138 auth_tag: &[u8; 16],
139 aad: &[u8],
140 ) -> Result<Vec<u8>, CertError> {
141 use aes_gcm::aead::generic_array::GenericArray;
142
143 let cipher = Aes128Gcm::new_from_slice(&self.key)
144 .map_err(|_| CertError::InvalidRequest("AES key init failed".into()))?;
145
146 let mut ct_with_tag = Vec::with_capacity(ciphertext.len() + 16);
148 ct_with_tag.extend_from_slice(ciphertext);
149 ct_with_tag.extend_from_slice(auth_tag);
150
151 let nonce = GenericArray::from_slice(iv);
152 let plaintext = cipher
153 .decrypt(
154 nonce,
155 Payload {
156 msg: &ct_with_tag,
157 aad,
158 },
159 )
160 .map_err(|_| CertError::InvalidRequest("AES-GCM decryption failed (bad tag)".into()))?;
161
162 Ok(plaintext)
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs ECDH in both directions between two fresh keypairs.
    ///
    /// Side A derives with `request_id_a` and side B with `request_id_b`; both use
    /// `salt`. Returns `(key_a, key_b)`.
    fn derive_pair(
        salt: &[u8; 32],
        request_id_a: &[u8; 8],
        request_id_b: &[u8; 8],
    ) -> (SessionKey, SessionKey) {
        let kp_a = EcdhKeypair::generate();
        let kp_b = EcdhKeypair::generate();
        let pub_a = kp_a.public_key_bytes();
        let pub_b = kp_b.public_key_bytes();

        let key_a = kp_a.derive_session_key(&pub_b, salt, request_id_a).unwrap();
        let key_b = kp_b.derive_session_key(&pub_a, salt, request_id_b).unwrap();
        (key_a, key_b)
    }

    /// Same as [`derive_pair`] with both sides using the same request id.
    fn matched_pair(salt: &[u8; 32], request_id: &[u8; 8]) -> (SessionKey, SessionKey) {
        derive_pair(salt, request_id, request_id)
    }

    /// Returns true when a fresh keypair refuses to derive a key from `peer`.
    fn rejects_peer(peer: &[u8]) -> bool {
        EcdhKeypair::generate()
            .derive_session_key(peer, &[0x10u8; 32], &[0x20u8; 8])
            .is_err()
    }

    #[test]
    fn ecdh_key_agreement_produces_same_session_key() {
        let (client_session, ca_session) = matched_pair(&[0x42u8; 32], &[0x01u8; 8]);
        assert_eq!(client_session.key, ca_session.key);
    }

    #[test]
    fn mismatched_request_ids_yield_incompatible_keys() {
        let (key_a, key_b) = derive_pair(&[0xBBu8; 32], &[0x01u8; 8], &[0x02u8; 8]);
        assert_ne!(key_a.key, key_b.key);

        let (iv, ct, tag) = key_a.encrypt(b"secret", b"").unwrap();
        assert!(key_b.decrypt(&iv, &ct, &tag, b"").is_err());
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let request_id = [0x22u8; 8];
        let (key_a, key_b) = matched_pair(&[0x11u8; 32], &request_id);

        let plaintext = b"{\"code\":\"123456\"}";
        let aad = &request_id[..];

        let (iv, ct, tag) = key_a.encrypt(plaintext, aad).unwrap();
        let decrypted = key_b.decrypt(&iv, &ct, &tag, aad).unwrap();

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn encrypt_decrypt_roundtrip_empty_plaintext() {
        let request_id = [0x55u8; 8];
        let (key_a, key_b) = matched_pair(&[0x66u8; 32], &request_id);

        let (iv, ct, tag) = key_a.encrypt(b"", &request_id).unwrap();
        assert!(ct.is_empty());
        assert!(key_b.decrypt(&iv, &ct, &tag, &request_id).unwrap().is_empty());
    }

    #[test]
    fn encrypt_uses_fresh_nonce_each_call() {
        let (key, _) = matched_pair(&[0xCCu8; 32], &[0xDDu8; 8]);
        let (iv1, _, _) = key.encrypt(b"same", b"").unwrap();
        let (iv2, _, _) = key.encrypt(b"same", b"").unwrap();
        assert_ne!(iv1, iv2);
    }

    #[test]
    fn decrypt_fails_with_wrong_tag() {
        let request_id = [0x44u8; 8];
        let (key_a, key_b) = matched_pair(&[0x33u8; 32], &request_id);

        let (iv, ct, mut tag) = key_a.encrypt(b"secret", &request_id).unwrap();
        tag[0] ^= 0xFF;
        assert!(key_b.decrypt(&iv, &ct, &tag, &request_id).is_err());
    }

    #[test]
    fn decrypt_fails_with_tampered_ciphertext() {
        let request_id = [0x77u8; 8];
        let (key_a, key_b) = matched_pair(&[0x88u8; 32], &request_id);

        let (iv, ct, tag) = key_a.encrypt(b"secret", &request_id).unwrap();
        let mut tampered = ct.to_vec();
        tampered[0] ^= 0x01;
        assert!(key_b.decrypt(&iv, &tampered, &tag, &request_id).is_err());
    }

    #[test]
    fn decrypt_fails_with_wrong_aad() {
        let request_id = [0x99u8; 8];
        let (key_a, key_b) = matched_pair(&[0xAAu8; 32], &request_id);

        let (iv, ct, tag) = key_a.encrypt(b"secret", &request_id).unwrap();
        assert!(key_b.decrypt(&iv, &ct, &tag, b"other-aad").is_err());
    }

    #[test]
    fn derive_rejects_malformed_peer_keys() {
        // Not a SEC1 encoding at all.
        assert!(rejects_peer(&[]));
        // Uncompressed tag but truncated to 64 bytes.
        assert!(rejects_peer(&[0x04; 64]));
        // SEC1 identity encoding is not a usable public key.
        assert!(rejects_peer(&[0x00]));

        // Flipping the low byte of a valid Y gives y' with y + y' = 512k + 255, which
        // can never equal p (p = 511 mod 512). So y' is neither y nor p - y, and the
        // point is off the curve.
        let mut off_curve = EcdhKeypair::generate().public_key_bytes();
        off_curve[64] ^= 0xFF;
        assert!(rejects_peer(&off_curve));
    }

    #[test]
    fn random_salt_is_fresh() {
        let a = EcdhKeypair::random_salt();
        let b = EcdhKeypair::random_salt();
        assert_ne!(a, [0u8; 32]);
        assert_ne!(a, b);
    }

    #[test]
    fn public_key_is_65_bytes() {
        let kp = EcdhKeypair::generate();
        let pub_bytes = kp.public_key_bytes();
        assert_eq!(pub_bytes.len(), 65);
        assert_eq!(pub_bytes[0], 0x04);
    }
}