Internal.hs (25737B)
{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}
-- NOTE(review): LambdaCase, RecordWildCards and ViewPatterns do not
-- appear to be used in this module -- confirm before removing.

-- |
-- Module: Lightning.Protocol.BOLT8.Internal
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Internal module exporting all constructors for testing and
-- benchmarking. Prefer "Lightning.Protocol.BOLT8" for general use.
--
-- Because raw constructors (e.g. 'Sec', 'Key32') are exported here,
-- values built directly with them bypass the validation done by
-- 'keypair', 'key32' and 'mkMessagePayload'.

module Lightning.Protocol.BOLT8.Internal (
  -- * Keys
    Sec(..)
  , Pub(..)
  , keypair
  , parse_pub
  , serialize_pub

  -- * Newtypes
  , Key32(..)
  , key32
  , unsafeKey32
  , SessionNonce(..)
  , MessagePayload(..)
  , mkMessagePayload

  -- * Handshake roles
  , Initiator
  , Responder
  , HandshakeFor(..)

  -- * Handshake (initiator)
  , act1
  , act3

  -- * Handshake (responder)
  , act2
  , finalize

  -- * Session
  , Session(..)
  , HandshakeState(..)
  , Handshake(..)
  , encrypt
  , decrypt
  , decrypt_frame
  , decrypt_frame_partial
  , FrameResult(..)

  -- * Errors
  , Error(..)
  ) where

import Control.Monad (guard, unless)
import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
import qualified Crypto.Curve.Secp256k1 as Secp256k1
import qualified Crypto.Hash.SHA256 as SHA256
import qualified Crypto.KDF.HMAC as HKDF
import Data.Bits (unsafeShiftR, (.&.))
import qualified Data.ByteString as BS
import Data.Word (Word16, Word64)
import GHC.Generics (Generic)

-- types -----------------------------------------------------------

-- | Secret key (32 bytes).
--
-- 'keypair' only produces a 'Sec' from exactly 32 bytes of valid
-- entropy; the constructor itself performs no validation. Only
-- 'Eq' and 'Generic' are derived (there is no 'Show' instance).
newtype Sec = Sec BS.ByteString
  deriving (Eq, Generic)

-- | secp256k1 public key, held as a projective curve point. Use
-- 'serialize_pub' for its 33-byte compressed encoding.
newtype Pub = Pub Secp256k1.Projective

-- Public keys are compared by their 33-byte compressed
-- serialization rather than by raw projective coordinates.
instance Eq Pub where
  (Pub a) == (Pub b) =
    Secp256k1.serialize_point a
      == Secp256k1.serialize_point b

-- Renders the compressed serialization of the point.
instance Show Pub where
  show (Pub p) =
    "Pub " ++ show (Secp256k1.serialize_point p)

-- | A 32-byte key, validated at construction by 'key32'.
--
-- Only 'Eq' and 'Generic' are derived; there is no 'Show'
-- instance.
newtype Key32 = Key32 { unKey32 :: BS.ByteString }
  deriving (Eq, Generic)

-- | Construct a 'Key32' from a 32-byte 'BS.ByteString'.
--
-- Returns 'Nothing' if the input is not exactly 32 bytes.
--
-- 'Key32' has no 'Show' instance, so the examples inspect the
-- wrapped bytes rather than printing the key itself.
--
-- >>> fmap (BS.length . unKey32) (key32 (BS.replicate 32 0x00))
-- Just 32
-- >>> fmap unKey32 (key32 (BS.replicate 31 0x00))
-- Nothing
key32 :: BS.ByteString -> Maybe Key32
key32 bs
  | BS.length bs == 32 = Just (Key32 bs)
  | otherwise = Nothing

-- | Construct a 'Key32' without validation.
--
-- For test and benchmark use only; prefer 'key32'. The input
-- length is not checked.
unsafeKey32 :: BS.ByteString -> Key32
unsafeKey32 = Key32

-- | Per-direction AEAD nonce counter, as stored in a 'Session'
-- (one for sending, one for receiving).
newtype SessionNonce =
  SessionNonce { unSessionNonce :: Word64 }
  deriving (Eq, Generic)

-- | Message payload (max 65535 bytes), validated at construction.
newtype MessagePayload =
  MessagePayload { unMessagePayload :: BS.ByteString }
  deriving (Eq, Generic)

-- | Construct a 'MessagePayload' from a 'BS.ByteString'.
--
-- Returns @'Left' 'InvalidLength'@ if the payload exceeds 65535
-- bytes.
--
-- >>> fmap (BS.length . unMessagePayload) (mkMessagePayload (BS.replicate 3 0x00))
-- Right 3
-- >>> fmap (BS.length . unMessagePayload) (mkMessagePayload (BS.replicate 65536 0x00))
-- Left InvalidLength
mkMessagePayload
  :: BS.ByteString -> Either Error MessagePayload
mkMessagePayload bs
  | BS.length bs > 65535 = Left InvalidLength
  | otherwise = Right (MessagePayload bs)

-- | Handshake and transport errors.
data Error =
    InvalidKey
    -- ^ keypair derivation or ECDH failed
  | InvalidPub
    -- ^ a public key could not be parsed, or was unavailable
  | InvalidMAC
    -- ^ an AEAD operation failed (e.g. authentication)
  | InvalidVersion
    -- ^ a handshake message carried a version byte other than 0
  | InvalidLength
    -- ^ a message, frame or payload had an unacceptable length
  | DecryptionFailed
    -- ^ not currently returned by any function in this module
  deriving (Eq, Show, Generic)

-- | Outcome of 'decrypt_frame_partial' on a possibly incomplete
-- buffer.
data FrameResult
  = NeedMore {-# UNPACK #-} !Int
    -- ^ The buffer is too short; the 'Int' is the minimum number
    -- of additional bytes required before a retry can make
    -- progress.
  | FrameOk !BS.ByteString !BS.ByteString !Session
    -- ^ A complete frame was decrypted. Fields, in order: the
    -- plaintext, the unconsumed rest of the buffer, and the
    -- advanced session.
  | FrameError !Error
    -- ^ The frame could not be decrypted; carries the reason.
  deriving (Generic)

-- | Transport state after a completed handshake: a key, a nonce
-- and a chaining key for each direction.
data Session = Session
  { sess_sk  :: !Key32
    -- ^ key for outgoing frames (32 bytes)
  , sess_sn  :: !SessionNonce
    -- ^ nonce for the next outgoing encryption
  , sess_sck :: !Key32
    -- ^ chaining key used when rotating 'sess_sk'
  , sess_rk  :: !Key32
    -- ^ key for incoming frames (32 bytes)
  , sess_rn  :: !SessionNonce
    -- ^ nonce for the next incoming decryption
  , sess_rck :: !Key32
    -- ^ chaining key used when rotating 'sess_rk'
  }
  deriving (Generic)

-- | What a successful handshake yields.
data Handshake = Handshake
  { session       :: !Session
    -- ^ transport state for the new connection
  , remote_static :: !Pub
    -- ^ the peer's authenticated static public key
  }
  deriving (Generic)

-- | Handshake state carried between acts (exported for
-- benchmarking).
data HandshakeState = HandshakeState
  { hs_h      :: {-# UNPACK #-} !BS.ByteString
    -- ^ running handshake hash (32 bytes)
  , hs_ck     :: {-# UNPACK #-} !BS.ByteString
    -- ^ chaining key (32 bytes)
  , hs_temp_k :: {-# UNPACK #-} !BS.ByteString
    -- ^ current temporary AEAD key (32 bytes)
  , hs_e_sec  :: !Sec
    -- ^ local ephemeral secret key
  , hs_e_pub  :: !Pub
    -- ^ local ephemeral public key
  , hs_s_sec  :: !Sec
    -- ^ local static secret key
  , hs_s_pub  :: !Pub
    -- ^ local static public key
  , hs_re     :: !(Maybe Pub)
    -- ^ remote ephemeral public key, once received
  , hs_rs     :: !(Maybe Pub)
    -- ^ remote static public key, when known
  }
  deriving (Generic)

-- handshake roles -------------------------------------------------

-- | Phantom index for the side that starts the handshake.
data Initiator

-- | Phantom index for the side that answers the handshake.
data Responder

-- | Role-indexed handshake state.
--
-- The phantom type parameter prevents passing an initiator's
-- state to a responder function and vice versa: 'act3' accepts
-- only @HandshakeFor Initiator@, and 'finalize' only
-- @HandshakeFor Responder@.
data HandshakeFor a =
  HandshakeFor { unHandshakeFor :: !HandshakeState }

-- protocol constants ----------------------------------------------

-- Noise protocol name. Its SHA-256 hash is both the initial
-- handshake hash and the initial chaining key (see
-- 'init_handshake').
_PROTOCOL_NAME :: BS.ByteString
_PROTOCOL_NAME =
  "Noise_XK_secp256k1_ChaChaPoly_SHA256"

-- Prologue mixed into the handshake hash right after the protocol
-- name.
_PROLOGUE :: BS.ByteString
_PROLOGUE = "lightning"

-- key operations --------------------------------------------------

-- | Derive a keypair from 32 bytes of entropy.
--
-- The entropy is used directly as the secret key. Returns
-- 'Nothing' if it is not exactly 32 bytes long, or if it is not a
-- valid secret (zero or >= curve order).
--
-- 'Sec' has no 'Show' instance, so the failure example shows only
-- the public half of the result.
--
-- >>> let ent = BS.replicate 32 0x11
-- >>> case keypair ent of { Just _ -> "ok"; Nothing -> "fail" }
-- "ok"
-- >>> fmap snd (keypair (BS.replicate 31 0x11)) -- wrong length
-- Nothing
keypair :: BS.ByteString -> Maybe (Sec, Pub)
keypair ent = do
  guard (BS.length ent == 32)
  k <- Secp256k1.parse_int256 ent
  p <- Secp256k1.derive_pub k
  pure (Sec ent, Pub p)

-- | Parse a 33-byte compressed public key.
--
-- Returns 'Nothing' if the input is not 33 bytes long, or if
-- 'Secp256k1.parse_point' rejects it.
--
-- >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
-- >>> let bytes = serialize_pub pub
-- >>> case parse_pub bytes of { Just _ -> "ok"; Nothing -> "fail" }
-- "ok"
-- >>> parse_pub (BS.replicate 32 0x00) -- wrong length
-- Nothing
parse_pub :: BS.ByteString -> Maybe Pub
parse_pub bs = do
  guard (BS.length bs == 33)
  p <- Secp256k1.parse_point bs
  pure (Pub p)

-- | Serialize a public key to 33-byte compressed form.
--
-- >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
-- >>> BS.length (serialize_pub pub)
-- 33
serialize_pub :: Pub -> BS.ByteString
serialize_pub pub = case pub of
  Pub point -> Secp256k1.serialize_point point

-- cryptographic primitives ----------------------------------------

-- BOLT #8 ECDH: the SHA-256 hash of the compressed encoding of the
-- point @sec * pub@. 'Nothing' if the secret does not parse as a
-- 256-bit integer or the scalar multiplication is rejected.
ecdh :: Sec -> Pub -> Maybe BS.ByteString
ecdh (Sec sec) (Pub pub) =
  fmap (SHA256.hash . Secp256k1.serialize_point)
    (Secp256k1.parse_int256 sec >>= Secp256k1.mul pub)

-- Absorb data into the handshake hash: h' = SHA256(h || data).
mix_hash
  :: BS.ByteString -> BS.ByteString -> BS.ByteString
mix_hash h dat = SHA256.hash (BS.append h dat)

-- Derive a new chaining key and a 32-byte key:
-- (ck', k) = HKDF(ck, input_key_material), taking 64 bytes of
-- HKDF-SHA256 output and splitting them in half.
--
-- NB HKDF limits output to 255 * hashlen bytes. For SHA256 that's
-- 8160, well above the 64 bytes requested here, so the 'Nothing'
-- branch is unreachable.
mix_key
  :: BS.ByteString
  -> BS.ByteString
  -> (BS.ByteString, BS.ByteString)
mix_key ck ikm =
  maybe
    (error "ppad-bolt8: internal error, please report a bug!")
    (BS.splitAt 32)
    (HKDF.derive hmac_sha256 ck mempty 64 ikm)
  where
    hmac_sha256 key msg = case SHA256.hmac key msg of
      SHA256.MAC tag -> tag

-- ChaCha20-Poly1305 encryption; the 16-byte tag is appended to
-- the ciphertext. 'Nothing' if the AEAD call reports an error.
encrypt_with_ad
  :: BS.ByteString        -- ^ key (32 bytes)
  -> Word64               -- ^ nonce
  -> BS.ByteString        -- ^ associated data
  -> BS.ByteString        -- ^ plaintext
  -> Maybe BS.ByteString  -- ^ ciphertext || mac (16 bytes)
encrypt_with_ad key n ad pt =
  either (const Nothing) (\(ct, tag) -> Just (BS.append ct tag))
    (AEAD.encrypt ad key (encode_nonce n) pt)

-- ChaCha20-Poly1305 decryption of @ciphertext || 16-byte tag@.
-- 'Nothing' if the input is shorter than a tag, or if the AEAD
-- call reports an error (e.g. a tag mismatch).
decrypt_with_ad
  :: BS.ByteString        -- ^ key (32 bytes)
  -> Word64               -- ^ nonce
  -> BS.ByteString        -- ^ associated data
  -> BS.ByteString        -- ^ ciphertext || mac
  -> Maybe BS.ByteString  -- ^ plaintext
decrypt_with_ad key n ad ctmac
  | ct_len < 0 = Nothing
  | otherwise =
      either (const Nothing) Just
        (AEAD.decrypt ad key (encode_nonce n)
          (BS.splitAt ct_len ctmac))
  where
    ct_len = BS.length ctmac - 16

-- 96-bit AEAD nonce: four zero bytes, then the 64-bit counter in
-- little-endian order.
encode_nonce :: Word64 -> BS.ByteString
encode_nonce n = BS.append (BS.replicate 4 0x00) (encode_le64 n)

-- Little-endian 64-bit encoding (least significant byte first).
encode_le64 :: Word64 -> BS.ByteString
encode_le64 n =
  BS.pack [ fi (unsafeShiftR n sh .&. 0xff) | sh <- [0, 8 .. 56] ]

-- Big-endian 16-bit encoding (most significant byte first).
encode_be16 :: Word16 -> BS.ByteString
encode_be16 n = BS.pack (map fi [unsafeShiftR n 8, n .&. 0xff])

-- Big-endian 16-bit decoding; 'Nothing' unless the input is
-- exactly two bytes.
decode_be16 :: BS.ByteString -> Maybe Word16
decode_be16 bs
  | BS.length bs == 2 =
      Just (fi (BS.index bs 0) * 0x100 + fi (BS.index bs 1))
  | otherwise = Nothing

-- handshake -------------------------------------------------------

-- Initial handshake state, common to both roles:
--
--   h  = SHA256(protocol_name)
--   ck = h
--   h  = SHA256(h || prologue)
--   h  = SHA256(h || responder_static_pubkey)
--
-- The initiator passes the responder's static key as @Just rs@.
-- The responder passes 'Nothing', and its own static key is mixed
-- in. Any other combination skips the last step (not expected to
-- occur). The temporary key starts as 32 zero bytes.
init_handshake
  :: Sec        -- ^ local static secret
  -> Pub        -- ^ local static public
  -> Sec        -- ^ ephemeral secret
  -> Pub        -- ^ ephemeral public
  -> Maybe Pub  -- ^ remote static
  -> Bool       -- ^ True if initiator
  -> HandshakeState
init_handshake s_sec s_pub e_sec e_pub m_rs is_init =
  HandshakeState {
      hs_h      = h_init
    , hs_ck     = ck_init
    , hs_temp_k = BS.replicate 32 0x00
    , hs_e_sec  = e_sec
    , hs_e_pub  = e_pub
    , hs_s_sec  = s_sec
    , hs_s_pub  = s_pub
    , hs_re     = Nothing
    , hs_rs     = m_rs
    }
  where
    ck_init = SHA256.hash _PROTOCOL_NAME
    h_prologue = mix_hash ck_init _PROLOGUE
    h_init =
      maybe h_prologue (mix_hash h_prologue . serialize_pub)
        responder_static
    responder_static = case (is_init, m_rs) of
      (True, rs) -> rs
      (False, Nothing) -> Just s_pub
      (False, Just _) -> Nothing

-- | Initiator: build the 50-byte Act 1 message.
--
-- Takes the local static secret and public keys, the responder's
-- static public key, and 32 bytes of entropy from which the
-- ephemeral keypair is derived.
--
-- Returns the Act 1 message together with the handshake state
-- to hand to 'act3'.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let eph_ent = BS.replicate 32 0x12
-- >>> case act1 i_sec i_pub r_pub eph_ent of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
-- 50
act1
  :: Sec            -- ^ local static secret
  -> Pub            -- ^ local static public
  -> Pub            -- ^ responder's static public
  -> BS.ByteString  -- ^ 32 bytes of entropy for the ephemeral key
  -> Either Error
       (BS.ByteString, HandshakeFor Initiator)
act1 s_sec s_pub rs ent = do
  (e_sec, e_pub) <- note InvalidKey (keypair ent)
  let !st = init_handshake s_sec s_pub e_sec e_pub (Just rs) True
      !e_wire = serialize_pub e_pub
      -- h = SHA256(h || e.pub)
      !h_e = mix_hash (hs_h st) e_wire
  -- es = ECDH(e.priv, rs); (ck, temp_k1) = HKDF(ck, es)
  es <- note InvalidKey (ecdh e_sec rs)
  let !(ck_es, temp_k1) = mix_key (hs_ck st) es
  -- c = encryptWithAD(temp_k1, 0, h, "") -- a bare 16-byte tag
  c <- note InvalidMAC (encrypt_with_ad temp_k1 0 h_e BS.empty)
  let !act1_msg = BS.concat [BS.singleton 0x00, e_wire, c]
      !st' = st {
          hs_h      = mix_hash h_e c
        , hs_ck     = ck_es
        , hs_temp_k = temp_k1
        }
  pure (act1_msg, HandshakeFor st')

-- | Responder: consume Act 1 and build the 50-byte Act 2
-- message.
--
-- Takes the local static secret and public keys, 32 bytes of
-- entropy for the ephemeral keypair, and the initiator's 50-byte
-- Act 1 message.
--
-- Returns the Act 2 message together with the handshake state
-- to hand to 'finalize'.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let Right (msg1, _) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
-- >>> case act2 r_sec r_pub (BS.replicate 32 0x22) msg1 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
-- 50
act2
  :: Sec            -- ^ local static secret
  -> Pub            -- ^ local static public
  -> BS.ByteString  -- ^ 32 bytes of entropy for the ephemeral key
  -> BS.ByteString  -- ^ the initiator's Act 1 message
  -> Either Error
       (BS.ByteString, HandshakeFor Responder)
act2 s_sec s_pub ent msg1 = do
  require (BS.length msg1 == 50) InvalidLength
  -- layout: version (1) || initiator ephemeral key (33) || c (16)
  let !version = BS.index msg1 0
      !(re_bytes, c) = BS.splitAt 33 (BS.drop 1 msg1)
  require (version == 0x00) InvalidVersion
  re <- note InvalidPub (parse_pub re_bytes)
  (e_sec, e_pub) <- note InvalidKey (keypair ent)
  let !st = init_handshake s_sec s_pub e_sec e_pub Nothing False
      -- h = SHA256(h || re)
      !h_re = mix_hash (hs_h st) re_bytes
  -- es = ECDH(s.priv, re); (ck, temp_k1) = HKDF(ck, es)
  es <- note InvalidKey (ecdh s_sec re)
  let !(ck_es, temp_k1) = mix_key (hs_ck st) es
  -- authenticate the initiator's empty Act 1 payload
  _ <- note InvalidMAC (decrypt_with_ad temp_k1 0 h_re c)
  let !e_wire = serialize_pub e_pub
      -- h = SHA256(SHA256(h || c) || e.pub)
      !h_e = mix_hash (mix_hash h_re c) e_wire
  -- ee = ECDH(e.priv, re); (ck, temp_k2) = HKDF(ck, ee)
  ee <- note InvalidKey (ecdh e_sec re)
  let !(ck_ee, temp_k2) = mix_key ck_es ee
  c2 <- note InvalidMAC (encrypt_with_ad temp_k2 0 h_e BS.empty)
  let !act2_msg = BS.concat [BS.singleton 0x00, e_wire, c2]
      !st' = st {
          hs_h      = mix_hash h_e c2
        , hs_ck     = ck_ee
        , hs_temp_k = temp_k2
        , hs_re     = Just re
        }
  pure (act2_msg, HandshakeFor st')

-- | Initiator: consume Act 2 and build the 66-byte Act 3
-- message, completing the handshake.
--
-- Returns the Act 3 message together with the handshake
-- result.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
-- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
-- >>> case act3 i_hs msg2 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
-- 66
act3
  :: HandshakeFor Initiator
  -> BS.ByteString
  -> Either Error (BS.ByteString, Handshake)
act3 (HandshakeFor hs) msg2 = do
  require (BS.length msg2 == 50) InvalidLength
  -- layout: version (1) || responder ephemeral key (33) || c (16)
  let !version = BS.index msg2 0
      !(re_bytes, c) = BS.splitAt 33 (BS.drop 1 msg2)
  require (version == 0x00) InvalidVersion
  re <- note InvalidPub (parse_pub re_bytes)
  -- h = SHA256(h || re)
  let !h_re = mix_hash (hs_h hs) re_bytes
  -- ee = ECDH(e.priv, re); (ck, temp_k2) = HKDF(ck, ee)
  ee <- note InvalidKey (ecdh (hs_e_sec hs) re)
  let !(ck_ee, temp_k2) = mix_key (hs_ck hs) ee
  -- authenticate the responder's empty Act 2 payload
  _ <- note InvalidMAC (decrypt_with_ad temp_k2 0 h_re c)
  let !h_c = mix_hash h_re c
  -- c3 = encryptWithAD(temp_k2, 1, h, s.pub)
  c3 <- note InvalidMAC
    (encrypt_with_ad temp_k2 1 h_c (serialize_pub (hs_s_pub hs)))
  let !h_s = mix_hash h_c c3
  -- se = ECDH(s.priv, re); (ck, temp_k3) = HKDF(ck, se)
  se <- note InvalidKey (ecdh (hs_s_sec hs) re)
  let !(ck_se, temp_k3) = mix_key ck_ee se
  t <- note InvalidMAC (encrypt_with_ad temp_k3 0 h_s BS.empty)
  -- (sk, rk) = HKDF(ck, ""); both directions rotate from this ck
  let !(sk, rk) = mix_key ck_se BS.empty
      !act3_msg = BS.concat [BS.singleton 0x00, c3, t]
      !sess = Session {
          sess_sk  = Key32 sk
        , sess_sn  = SessionNonce 0
        , sess_sck = Key32 ck_se
        , sess_rk  = Key32 rk
        , sess_rn  = SessionNonce 0
        , sess_rck = Key32 ck_se
        }
  rs <- note InvalidPub (hs_rs hs)
  let !result = Handshake {
          session       = sess
        , remote_static = rs
        }
  pure (act3_msg, result)

-- | Responder: consume the 66-byte Act 3 message and complete
-- the handshake.
--
-- Returns the handshake result, including the initiator's
-- authenticated static public key.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
-- >>> let Right (msg3, _) = act3 i_hs msg2
-- >>> case finalize r_hs msg3 of { Right _ -> "ok"; Left e -> show e }
-- "ok"
finalize
  :: HandshakeFor Responder
  -> BS.ByteString
  -> Either Error Handshake
finalize (HandshakeFor hs) msg3 = do
  require (BS.length msg3 == 66) InvalidLength
  -- layout: version (1) || encrypted static key (49) || t (16)
  let !version = BS.index msg3 0
      !(c, t) = BS.splitAt 49 (BS.drop 1 msg3)
  require (version == 0x00) InvalidVersion
  -- rs = decryptWithAD(temp_k2, 1, h, c)
  rs_bytes <- note InvalidMAC
    (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c)
  rs <- note InvalidPub (parse_pub rs_bytes)
  let !h_c = mix_hash (hs_h hs) c
  -- se = ECDH(e.priv, rs); (ck, temp_k3) = HKDF(ck, se)
  se <- note InvalidKey (ecdh (hs_e_sec hs) rs)
  let !(ck_se, temp_k3) = mix_key (hs_ck hs) se
  _ <- note InvalidMAC (decrypt_with_ad temp_k3 0 h_c t)
  -- The two HKDF halves are swapped relative to the initiator:
  -- the initiator's sending key is the responder's receiving key.
  let !(rk, sk) = mix_key ck_se BS.empty
      !sess = Session {
          sess_sk  = Key32 sk
        , sess_sn  = SessionNonce 0
        , sess_sck = Key32 ck_se
        , sess_rk  = Key32 rk
        , sess_rn  = SessionNonce 0
        , sess_rck = Key32 ck_se
        }
      !result = Handshake {
          session       = sess
        , remote_static = rs
        }
  pure result

-- message encryption ----------------------------------------------

-- | Encrypt a message (max 65535 bytes).
--
-- Returns the encrypted packet and the advanced session. Each
-- message uses two nonces (length header and body); the sending
-- key is rotated automatically whenever its nonce reaches 1000.
--
-- Wire format:
--   encrypted_length (2) || MAC (16)
--   || encrypted_body || MAC (16)
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
-- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
-- >>> let Right (_, i_result) = act3 i_hs msg2
-- >>> let sess = session i_result
-- >>> case encrypt sess "hello" of { Right (ct, _) -> BS.length ct; Left _ -> 0 }
-- 39
encrypt
  :: Session
  -> BS.ByteString
  -> Either Error (BS.ByteString, Session)
encrypt sess pt = do
  let !len = BS.length pt
  require (len <= 65535) InvalidLength
  let !k0  = unKey32 (sess_sk sess)
      !n0  = unSessionNonce (sess_sn sess)
      !ck0 = unKey32 (sess_sck sess)
  -- lc = encryptWithAD(sk, sn, "", 2-byte big-endian length)
  lc <- note InvalidMAC
    (encrypt_with_ad k0 n0 BS.empty (encode_be16 (fi len)))
  let !(n1, ck1, k1) = step_nonce n0 ck0 k0
  -- the body uses the next nonce (and possibly a rotated key)
  bc <- note InvalidMAC (encrypt_with_ad k1 n1 BS.empty pt)
  let !(n2, ck2, k2) = step_nonce n1 ck1 k1
      !packet = BS.append lc bc
      !sess' = sess {
          sess_sk  = Key32 k2
        , sess_sn  = SessionNonce n2
        , sess_sck = Key32 ck2
        }
  pure (packet, sess')

-- | Decrypt a message, requiring an exact packet with no
-- trailing bytes.
--
-- Returns the plaintext and the advanced session; the receiving
-- key is rotated automatically when its nonce reaches 1000.
--
-- This is the strict variant: any data after the first frame is
-- rejected with 'InvalidLength'. For streaming use, where a buffer
-- may hold several frames, use 'decrypt_frame' instead.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
-- >>> let Right (msg3, i_result) = act3 i_hs msg2
-- >>> let Right r_result = finalize r_hs msg3
-- >>> let Right (ct, _) = encrypt (session i_result) "hello"
-- >>> case decrypt (session r_result) ct of { Right (pt, _) -> pt; Left _ -> "fail" }
-- "hello"
decrypt
  :: Session
  -> BS.ByteString
  -> Either Error (BS.ByteString, Session)
decrypt sess packet = do
  (pt, remainder, sess') <- decrypt_frame sess packet
  -- Bytes left after the frame invalidate the whole packet. On
  -- any failure no session is returned, so the caller still holds
  -- only its previous one.
  require (BS.null remainder) InvalidLength
  pure (pt, sess')

-- | Decrypt a single frame from a buffer, returning the
-- remainder.
--
-- Returns the plaintext, any unconsumed bytes, and the
-- updated session. Key rotation is handled automatically. Each
-- frame consumes two nonces (length header and body), and the
-- receiving key rotates when its nonce reaches 1000, i.e. every
-- 500 frames.
--
-- This is useful for streaming scenarios where multiple
-- messages may be buffered together. The remainder can be
-- passed to the next call to 'decrypt_frame'.
--
-- Wire format consumed:
--   encrypted_length (18) || encrypted_body (len + 16)
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
-- >>> let Right (msg3, i_result) = act3 i_hs msg2
-- >>> let Right r_result = finalize r_hs msg3
-- >>> let Right (ct, _) = encrypt (session i_result) "hello"
-- >>> case decrypt_frame (session r_result) ct of { Right (pt, rem, _) -> (pt, BS.null rem); Left _ -> ("fail", False) }
-- ("hello",True)
decrypt_frame
  :: Session
  -> BS.ByteString
  -> Either Error
       (BS.ByteString, BS.ByteString, Session)
decrypt_frame sess packet = do
  -- smallest possible frame: 18-byte header + 16-byte empty body
  require (BS.length packet >= 34) InvalidLength
  let !(lc, rest) = BS.splitAt 18 packet
      !k0  = unKey32 (sess_rk sess)
      !n0  = unSessionNonce (sess_rn sess)
      !ck0 = unKey32 (sess_rck sess)
  len_bytes <- note InvalidMAC (decrypt_with_ad k0 n0 BS.empty lc)
  len <- note InvalidLength (decode_be16 len_bytes)
  let !(n1, ck1, k1) = step_nonce n0 ck0 k0
      !body_len = fi len + 16
  require (BS.length rest >= body_len) InvalidLength
  let !(bc, remainder) = BS.splitAt body_len rest
  pt <- note InvalidMAC (decrypt_with_ad k1 n1 BS.empty bc)
  let !(n2, ck2, k2) = step_nonce n1 ck1 k1
      !sess' = sess {
          sess_rk  = Key32 k2
        , sess_rn  = SessionNonce n2
        , sess_rck = Key32 ck2
        }
  pure (pt, remainder, sess')

-- | Decrypt a frame from a possibly incomplete buffer, reporting
-- when more data is needed.
--
-- Unlike 'decrypt_frame', this function handles incomplete
-- buffers gracefully by returning 'NeedMore' with the
-- number of additional bytes required to make progress.
--
-- * If the buffer has fewer than 18 bytes (encrypted length
--   + MAC), returns @'NeedMore' n@, where @n@ is the number of
--   bytes still needed.
-- * If the length header is complete but the body is
--   incomplete, returns @'NeedMore' n@ with the bytes needed
--   for the full frame.
-- * MAC or length-decoding failures return 'FrameError'.
-- * A complete, valid frame returns 'FrameOk' with the
--   plaintext, remainder, and updated session.
--
-- The length header is re-authenticated on every call. The nonce
-- is advanced (possibly rotating the key) only once a complete
-- frame is buffered, so a 'NeedMore' result does no key-derivation
-- work.
--
-- This is useful for non-blocking I/O where data arrives
-- incrementally.
decrypt_frame_partial
  :: Session
  -> BS.ByteString
  -> FrameResult
decrypt_frame_partial sess buf
  | buflen < 18 = NeedMore (18 - buflen)
  | otherwise =
      case decrypt_with_ad rk rn BS.empty lc of
        Nothing -> FrameError InvalidMAC
        Just len_bytes -> case decode_be16 len_bytes of
          Nothing -> FrameError InvalidLength
          Just len -> with_body (fi len + 16)
  where
    !buflen = BS.length buf
    (lc, rest) = BS.splitAt 18 buf
    rk  = unKey32 (sess_rk sess)
    rn  = unSessionNonce (sess_rn sess)
    rck = unKey32 (sess_rck sess)

    -- Decrypt the body once all of it is buffered. The nonce step
    -- (and any key rotation it triggers) happens only here.
    with_body :: Int -> FrameResult
    with_body body_len
      | BS.length rest < body_len =
          NeedMore (body_len - BS.length rest)
      | otherwise =
          let !(rn1, rck1, rk1) = step_nonce rn rck rk
              !(bc, remainder) = BS.splitAt body_len rest
          in  case decrypt_with_ad rk1 rn1 BS.empty bc of
                Nothing -> FrameError InvalidMAC
                Just pt ->
                  let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
                      !sess' = sess {
                          sess_rk  = Key32 rk2
                        , sess_rn  = SessionNonce rn2
                        , sess_rck = Key32 rck2
                        }
                  in  FrameOk pt remainder sess'

-- key rotation ----------------------------------------------------

-- Advance a nonce after one encryption or decryption.
--
-- When the incremented nonce would reach 1000, the key is rotated,
-- (ck', k') = HKDF(ck, k), and the nonce resets to 0, so a key is
-- never used with nonce 1000. Returns (nonce, chaining key, key).
step_nonce
  :: Word64
  -> BS.ByteString
  -> BS.ByteString
  -> (Word64, BS.ByteString, BS.ByteString)
step_nonce n ck k
  | n + 1 == 1000 =
      let !(ck', k') = mix_key ck k
      in (0, ck', k')
  | otherwise = (n + 1, ck, k)

-- utilities -------------------------------------------------------

-- Lift Maybe to Either, using the given error for 'Nothing'.
note :: e -> Maybe a -> Either e a
note e = maybe (Left e) Right
{-# INLINE note #-}

-- Require a condition, or fail with the given error.
require :: Bool -> e -> Either e ()
require cond e = unless cond (Left e)
{-# INLINE require #-}

-- Shorthand for 'fromIntegral'.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}