Internal.hs (26617B)
{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module: Lightning.Protocol.BOLT8.Internal
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Internal module exporting all constructors for testing and
-- benchmarking. Prefer "Lightning.Protocol.BOLT8" for general use.
--
-- The raw constructors exported here (e.g. 'Key32',
-- 'MessagePayload') bypass the checks done by the smart
-- constructors 'key32' and 'mkMessagePayload'.

module Lightning.Protocol.BOLT8.Internal (
  -- * Keys
    Sec(..)
  , Pub(..)
  , keypair
  , parse_pub
  , serialize_pub

  -- * Newtypes
  , Key32(..)
  , key32
  , unsafeKey32
  , SessionNonce(..)
  , MessagePayload(..)
  , mkMessagePayload

  -- * Handshake roles
  , Initiator
  , Responder
  , HandshakeFor(..)

  -- * Handshake (initiator)
  , act1
  , act3

  -- * Handshake (responder)
  , act2
  , finalize

  -- * Session
  , Session(..)
  , HandshakeState(..)
  , Handshake(..)
  , encrypt
  , decrypt
  , decrypt_frame
  , decrypt_frame_partial
  , FrameResult(..)

  -- * Errors
  , Error(..)
  ) where

import Control.Monad (guard, unless)
import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
import qualified Crypto.Curve.Secp256k1 as Secp256k1
import qualified Crypto.Hash.SHA256 as SHA256
import qualified Crypto.KDF.HMAC as HKDF
import Data.Bits (unsafeShiftR, xor, (.&.), (.|.))
import qualified Data.ByteString as BS
import qualified Data.ByteString.Unsafe as BU
import Data.Word (Word16, Word64)
import GHC.Generics (Generic)

-- types -----------------------------------------------------------

-- | Secret key (32 bytes), stored as raw bytes.
--
-- Equality is decided in constant time by 'ct_eq_bs'. No 'Show'
-- instance is defined for this type.
newtype Sec = Sec BS.ByteString
  deriving Generic

instance Eq Sec where
  (==) (Sec x) (Sec y) = ct_eq_bs x y

-- | Public key: a secp256k1 curve point. Use 'serialize_pub' and
-- 'parse_pub' to convert to and from the 33-byte compressed form.
newtype Pub = Pub Secp256k1.Projective

-- Compare via the compressed serialization rather than the raw
-- projective coordinates, which need not be unique for a given
-- point.
instance Eq Pub where
  (Pub a) == (Pub b) =
    Secp256k1.serialize_point a
      == Secp256k1.serialize_point b

instance Show Pub where
  show (Pub p) =
    "Pub " ++ show (Secp256k1.serialize_point p)

-- | A 32-byte key, validated at construction.
--
-- The 'Eq' instance compares in constant time. No 'Show' instance
-- is defined for this type.
newtype Key32 = Key32 { unKey32 :: BS.ByteString }
  deriving Generic

instance Eq Key32 where
  Key32 a == Key32 b = ct_eq_bs a b

-- | Constant-time bytestring equality via XOR-accumulate.
--
-- Returns 'False' immediately on length mismatch (length is
-- not considered secret); otherwise inspects every byte
-- regardless of where (or whether) inputs differ.
ct_eq_bs :: BS.ByteString -> BS.ByteString -> Bool
ct_eq_bs a b
  | la /= lb = False
  | otherwise = go 0 0
  where
    !la = BS.length a
    !lb = BS.length b
    -- OR together the XOR of every byte pair; the accumulator is
    -- zero iff all pairs matched. Indices stay below la (== lb),
    -- so unsafeIndex is in bounds.
    go !i !acc
      | i >= la = acc == 0
      | otherwise =
          let !x = BU.unsafeIndex a i
              !y = BU.unsafeIndex b i
          in go (i + 1) (acc .|. (x `xor` y))
-- NOTE(review): NOINLINE presumably keeps the loop from being
-- inlined and re-optimised at call sites; confirm intent.
{-# NOINLINE ct_eq_bs #-}

-- | Construct a 'Key32' from a 32-byte 'BS.ByteString'.
--
-- Returns 'Nothing' if the input is not exactly 32 bytes.
--
-- 'Key32' has no 'Show' instance, so the examples inspect the
-- length of the wrapped bytes instead.
--
-- >>> fmap (BS.length . unKey32) (key32 (BS.replicate 32 0x00))
-- Just 32
-- >>> fmap (BS.length . unKey32) (key32 (BS.replicate 31 0x00))
-- Nothing
key32 :: BS.ByteString -> Maybe Key32
key32 bs
  | BS.length bs == 32 = Just (Key32 bs)
  | otherwise = Nothing

-- | Construct a 'Key32' without validation.
--
-- No length check is performed, so the result may violate the
-- 32-byte invariant. For test and benchmark use only; prefer
-- 'key32'.
unsafeKey32 :: BS.ByteString -> Key32
unsafeKey32 = Key32

-- | A message nonce counter. 'Session' keeps one for the sending
-- direction and one for the receiving direction.
newtype SessionNonce =
  SessionNonce { unSessionNonce :: Word64 }
  deriving (Eq, Generic)

-- | Message payload (max 65535 bytes), validated at construction
-- by 'mkMessagePayload'.
newtype MessagePayload =
  MessagePayload { unMessagePayload :: BS.ByteString }
  deriving (Eq, Generic)

-- | Validate a 'BS.ByteString' as a 'MessagePayload'.
--
-- Yields @'Left' 'InvalidLength'@ when the input is longer than
-- 65535 bytes, the largest length a 16-bit frame header can
-- express; otherwise wraps the input unchanged.
mkMessagePayload
  :: BS.ByteString -> Either Error MessagePayload
mkMessagePayload bs =
  if BS.length bs <= 65535
    then Right (MessagePayload bs)
    else Left InvalidLength

-- | Errors reported by the handshake and transport functions.
data Error =
    InvalidKey
    -- ^ entropy or secret key rejected, or an ECDH step failed
  | InvalidPub
    -- ^ a public key failed to parse (or was unavailable)
  | InvalidMAC
    -- ^ ChaCha20-Poly1305 authentication or encryption failed
  | InvalidVersion
    -- ^ a handshake message carried a nonzero version byte
  | InvalidLength
    -- ^ input had the wrong length, was truncated, or had
    -- trailing bytes
  | DecryptionFailed
    -- ^ NOTE(review): not constructed anywhere in this module
  deriving (Eq, Show, Generic)

-- | Outcome of 'decrypt_frame_partial' on a possibly-incomplete
-- buffer.
data FrameResult =
    NeedMore {-# UNPACK #-} !Int
    -- ^ More bytes needed; the 'Int' is the minimum
    -- additional bytes required.
  | FrameOk !BS.ByteString !BS.ByteString !Session
    -- ^ Successfully decrypted: plaintext, remainder,
    -- updated session.
  | FrameError !Error
    -- ^ Decryption failed with the given error.
  deriving Generic

-- | Transport state after a completed handshake: a key, a nonce
-- and a chaining key for each direction.
data Session = Session {
    sess_sk  :: !Key32
    -- ^ sending key (32 bytes)
  , sess_sn  :: !SessionNonce
    -- ^ sending nonce
  , sess_sck :: !Key32
    -- ^ sending chaining key, used when rotating the sending key
  , sess_rk  :: !Key32
    -- ^ receiving key (32 bytes)
  , sess_rn  :: !SessionNonce
    -- ^ receiving nonce
  , sess_rck :: !Key32
    -- ^ receiving chaining key, used when rotating the
    -- receiving key
  }
  deriving Generic

-- | Output of a completed handshake.
data Handshake = Handshake {
    session       :: !Session
    -- ^ transport state for encrypting and decrypting messages
  , remote_static :: !Pub
    -- ^ the peer's static public key, authenticated by the
    -- handshake
  }
  deriving Generic

-- | Handshake state threaded from one act to the next
-- (exported for benchmarking).
data HandshakeState = HandshakeState {
    hs_h :: {-# UNPACK #-} !BS.ByteString
    -- ^ handshake hash (32 bytes)
  , hs_ck :: {-# UNPACK #-} !BS.ByteString
    -- ^ chaining key (32 bytes)
  , hs_temp_k :: {-# UNPACK #-} !BS.ByteString
    -- ^ temp key (32 bytes)
  , hs_e_sec :: !Sec
    -- ^ ephemeral secret
  , hs_e_pub :: !Pub
    -- ^ ephemeral public
  , hs_s_sec :: !Sec
    -- ^ static secret
  , hs_s_pub :: !Pub
    -- ^ static public
  , hs_re :: !(Maybe Pub)
    -- ^ remote ephemeral
  , hs_rs :: !(Maybe Pub)
    -- ^ remote static
  }
  deriving Generic

-- handshake roles -------------------------------------------------

-- | Phantom type for initiator role.
data Initiator

-- | Phantom type for responder role.
data Responder

-- | Role-indexed handshake state.
--
-- The phantom type parameter prevents passing an initiator's
-- state to a responder function and vice versa.
data HandshakeFor a =
  HandshakeFor { unHandshakeFor :: !HandshakeState }

-- protocol constants ----------------------------------------------

_PROTOCOL_NAME :: BS.ByteString
_PROTOCOL_NAME =
  "Noise_XK_secp256k1_ChaChaPoly_SHA256"

_PROLOGUE :: BS.ByteString
_PROLOGUE = "lightning"

-- key operations --------------------------------------------------

-- | Derive a keypair from 32 bytes of entropy.
--
-- Returns 'Nothing' if the input is not exactly 32 bytes, or if
-- it is not a valid secret key (zero or >= curve order).
--
-- 'Sec' has no 'Show' instance, so the examples never print the
-- secret half of the result.
--
-- >>> let ent = BS.replicate 32 0x11
-- >>> case keypair ent of { Just _ -> "ok"; Nothing -> "fail" }
-- "ok"
-- >>> fmap snd (keypair (BS.replicate 31 0x11)) -- wrong length
-- Nothing
keypair :: BS.ByteString -> Maybe (Sec, Pub)
keypair ent = do
  guard (BS.length ent == 32)
  k <- Secp256k1.parse_int256 ent
  p <- Secp256k1.derive_pub k
  pure (Sec ent, Pub p)

-- | Parse a 33-byte compressed public key.
--
-- Returns 'Nothing' if the input is not exactly 33 bytes or if
-- 'Secp256k1.parse_point' rejects it.
--
-- >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
-- >>> let bytes = serialize_pub pub
-- >>> case parse_pub bytes of { Just _ -> "ok"; Nothing -> "fail" }
-- "ok"
-- >>> parse_pub (BS.replicate 32 0x00) -- wrong length
-- Nothing
parse_pub :: BS.ByteString -> Maybe Pub
parse_pub bs = do
  guard (BS.length bs == 33)
  p <- Secp256k1.parse_point bs
  pure (Pub p)

-- | Serialize a public key to 33-byte compressed form.
--
-- >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
-- >>> BS.length (serialize_pub pub)
-- 33
serialize_pub :: Pub -> BS.ByteString
serialize_pub (Pub p) = Secp256k1.serialize_point p

-- cryptographic primitives ----------------------------------------

-- BOLT8-style ECDH: SHA256 of the compressed serialization of
-- (sec * pub). 'Nothing' if the secret fails to parse as a
-- 256-bit integer or the scalar multiplication fails.
ecdh :: Sec -> Pub -> Maybe BS.ByteString
ecdh (Sec sec) (Pub pub) = do
  k <- Secp256k1.parse_int256 sec
  pt <- Secp256k1.mul pub k
  let compressed = Secp256k1.serialize_point pt
  pure (SHA256.hash compressed)

-- h' = SHA256(h || data)
mix_hash
  :: BS.ByteString -> BS.ByteString -> BS.ByteString
mix_hash h dat = SHA256.hash (h <> dat)

-- Mix key: (ck', k) = HKDF(ck, input_key_material)
--
-- 64 bytes are derived (HMAC-SHA256, keyed by ck, empty info
-- 'mempty' -- NOTE(review): argument roles as per ppad-hkdf's
-- 'derive'; confirm) and split: the first 32 bytes are the new
-- chaining key, the last 32 the derived key.
--
-- NB HKDF limits output to 255 * hashlen bytes. For SHA256
-- that's 8160, well above the 64 bytes requested here, so
-- 'Nothing' is impossible.
mix_key
  :: BS.ByteString
  -> BS.ByteString
  -> (BS.ByteString, BS.ByteString)
mix_key ck ikm =
  case HKDF.derive hmac ck mempty 64 ikm of
    Nothing ->
      error
        "ppad-bolt8: internal error, please report a bug!"
    Just output -> BS.splitAt 32 output
  where
    hmac k b = case SHA256.hmac k b of
      SHA256.MAC mac -> mac

-- Encrypt with associated data using ChaCha20-Poly1305.
-- 'Nothing' if the underlying AEAD call reports an error.
encrypt_with_ad
  :: BS.ByteString        -- ^ key (32 bytes)
  -> Word64               -- ^ nonce
  -> BS.ByteString        -- ^ associated data
  -> BS.ByteString        -- ^ plaintext
  -> Maybe BS.ByteString  -- ^ ciphertext || mac (16 bytes)
encrypt_with_ad key n ad pt =
  case AEAD.encrypt ad key (encode_nonce n) pt of
    Left _ -> Nothing
    Right (ct, mac) -> Just (ct <> mac)

-- Decrypt with associated data using ChaCha20-Poly1305.
-- The trailing 16 bytes of the input are taken as the MAC;
-- inputs shorter than 16 bytes, and MAC failures, yield 'Nothing'.
decrypt_with_ad
  :: BS.ByteString        -- ^ key (32 bytes)
  -> Word64               -- ^ nonce
  -> BS.ByteString        -- ^ associated data
  -> BS.ByteString        -- ^ ciphertext || mac
  -> Maybe BS.ByteString  -- ^ plaintext
decrypt_with_ad key n ad ctmac
  | BS.length ctmac < 16 = Nothing
  | otherwise =
      let (ct, mac) =
            BS.splitAt (BS.length ctmac - 16) ctmac
      in case AEAD.decrypt ad key (encode_nonce n)
                (ct, mac) of
           Left _ -> Nothing
           Right pt -> Just pt

-- Encode nonce as 96-bit value: 4 zero bytes + 8-byte LE
encode_nonce :: Word64 -> BS.ByteString
encode_nonce n = BS.replicate 4 0x00 <> encode_le64 n

-- Little-endian 64-bit encoding
encode_le64 :: Word64 -> BS.ByteString
encode_le64 n = BS.pack [
    fi (n .&. 0xff)
  , fi (unsafeShiftR n 8 .&. 0xff)
  , fi (unsafeShiftR n 16 .&. 0xff)
  , fi (unsafeShiftR n 24 .&. 0xff)
  , fi (unsafeShiftR n 32 .&. 0xff)
  , fi (unsafeShiftR n 40 .&. 0xff)
  , fi (unsafeShiftR n 48 .&. 0xff)
  , fi (unsafeShiftR n 56 .&. 0xff)
  ]

-- Big-endian 16-bit encoding
encode_be16 :: Word16 -> BS.ByteString
encode_be16 n =
  BS.pack [fi (unsafeShiftR n 8), fi (n .&. 0xff)]

-- Big-endian 16-bit decoding; 'Nothing' unless given exactly
-- 2 bytes (so the 'BS.index' calls below are in bounds).
decode_be16 :: BS.ByteString -> Maybe Word16
decode_be16 bs
  | BS.length bs /= 2 = Nothing
  | otherwise =
      let !b0 = BS.index bs 0
          !b1 = BS.index bs 1
      in Just (fi b0 * 0x100 + fi b1)

-- handshake -------------------------------------------------------

-- Initialize handshake state
--
-- h = SHA256(protocol_name)
-- ck = h
-- h = SHA256(h || prologue)
-- h = SHA256(h || responder_static_pubkey)
--
-- The responder's static key is the remote static key for the
-- initiator (is_init = True, m_rs = Just rs) and the local static
-- key for the responder (is_init = False, m_rs = Nothing). Any
-- other combination silently skips the final mix; 'act1' and
-- 'act2' only ever pass the two supported combinations.
--
-- The temp key starts as 32 zero bytes; 'act1' and 'act2'
-- overwrite it with the output of 'mix_key'.
init_handshake
  :: Sec        -- ^ local static secret
  -> Pub        -- ^ local static public
  -> Sec        -- ^ ephemeral secret
  -> Pub        -- ^ ephemeral public
  -> Maybe Pub  -- ^ remote static
  -> Bool       -- ^ True if initiator
  -> HandshakeState
init_handshake s_sec s_pub e_sec e_pub m_rs is_init =
  let !h0 = SHA256.hash _PROTOCOL_NAME
      !ck = h0
      !h1 = mix_hash h0 _PROLOGUE
      -- Mix in responder's static pubkey
      !h2 = case (is_init, m_rs) of
        (True, Just rs) ->
          mix_hash h1 (serialize_pub rs)
        (False, Nothing) ->
          mix_hash h1 (serialize_pub s_pub)
        _ -> h1 -- shouldn't happen
  in HandshakeState {
         hs_h = h2
       , hs_ck = ck
       , hs_temp_k = BS.replicate 32 0x00
       , hs_e_sec = e_sec
       , hs_e_pub = e_pub
       , hs_s_sec = s_sec
       , hs_s_pub = s_pub
       , hs_re = Nothing
       , hs_rs = m_rs
       }

-- | Initiator: generate Act 1 message (50 bytes).
--
-- Takes local static key, remote static pubkey, and 32
-- bytes of entropy for ephemeral key generation.
--
-- Returns the 50-byte Act 1 message and handshake state
-- for Act 3.
--
-- Fails with 'InvalidKey' if the entropy does not yield a valid
-- ephemeral keypair or the ECDH step fails, and with 'InvalidMAC'
-- if the AEAD encryption step reports an error.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let eph_ent = BS.replicate 32 0x12
-- >>> case act1 i_sec i_pub r_pub eph_ent of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
-- 50
act1
  :: Sec
  -- ^ local static secret
  -> Pub
  -- ^ local static public
  -> Pub
  -- ^ remote (responder) static public
  -> BS.ByteString
  -- ^ 32 bytes of entropy for the ephemeral key
  -> Either Error
       (BS.ByteString, HandshakeFor Initiator)
act1 s_sec s_pub rs ent = do
  -- ephemeral keypair derived from the caller's entropy
  (e_sec, e_pub) <- note InvalidKey (keypair ent)
  -- initial state (responder static rs mixed in), then
  -- h = SHA256(h || e.pub)
  let !hs0 = init_handshake
        s_sec s_pub e_sec e_pub (Just rs) True
      !e_pub_bytes = serialize_pub e_pub
      !h1 = mix_hash (hs_h hs0) e_pub_bytes
  -- es = ECDH(e.priv, rs); (ck, temp_k1) = HKDF(ck, es)
  es <- note InvalidKey (ecdh e_sec rs)
  let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
  -- c = encryptWithAD(temp_k1, 0, h, ""): the plaintext is empty,
  -- so c is just the 16-byte MAC
  c <- note InvalidMAC
    (encrypt_with_ad temp_k1 0 h1 BS.empty)
  -- h = SHA256(h || c); message = 0x00 || e.pub (33) || c (16)
  let !h2 = mix_hash h1 c
      !msg = BS.singleton 0x00 <> e_pub_bytes <> c
      !hs1 = hs0 {
          hs_h = h2
        , hs_ck = ck1
        , hs_temp_k = temp_k1
        }
  pure (msg, HandshakeFor hs1)

-- | Responder: process Act 1 and generate Act 2 message
-- (50 bytes).
--
-- Takes local static key and 32 bytes of entropy for
-- ephemeral key, plus the 50-byte Act 1 message from
-- initiator.
--
-- Returns the 50-byte Act 2 message and handshake state
-- for finalize.
--
-- Act 1 is checked in this order: length ('InvalidLength'),
-- version byte ('InvalidVersion'), the initiator's ephemeral key
-- ('InvalidPub'); the responder's own entropy is checked only
-- afterwards ('InvalidKey'). An authentication failure on Act 1
-- yields 'InvalidMAC'.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let Right (msg1, _) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
-- >>> case act2 r_sec r_pub (BS.replicate 32 0x22) msg1 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
-- 50
act2
  :: Sec
  -- ^ local static secret
  -> Pub
  -- ^ local static public
  -> BS.ByteString
  -- ^ 32 bytes of entropy for the ephemeral key
  -> BS.ByteString
  -- ^ Act 1 message (50 bytes)
  -> Either Error
       (BS.ByteString, HandshakeFor Responder)
act2 s_sec s_pub ent msg1 = do
  require (BS.length msg1 == 50) InvalidLength
  -- Act 1 layout: version (1) || e.pub (33) || c (16)
  let !version = BS.index msg1 0
      !re_bytes = BS.take 33 (BS.drop 1 msg1)
      !c = BS.drop 34 msg1
  require (version == 0x00) InvalidVersion
  re <- note InvalidPub (parse_pub re_bytes)
  (e_sec, e_pub) <- note InvalidKey (keypair ent)
  -- initial state (own static mixed in), then h = SHA256(h || re)
  let !hs0 = init_handshake
        s_sec s_pub e_sec e_pub Nothing False
      !h1 = mix_hash (hs_h hs0) re_bytes
  -- es = ECDH(s.priv, re); (ck, temp_k1) = HKDF(ck, es)
  es <- note InvalidKey (ecdh s_sec re)
  let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
  -- authenticate c; its plaintext is empty and discarded
  _ <- note InvalidMAC
    (decrypt_with_ad temp_k1 0 h1 c)
  -- h = SHA256(h || c); h = SHA256(h || e.pub)
  let !h2 = mix_hash h1 c
      !e_pub_bytes = serialize_pub e_pub
      !h3 = mix_hash h2 e_pub_bytes
  -- ee = ECDH(e.priv, re); (ck, temp_k2) = HKDF(ck, ee)
  ee <- note InvalidKey (ecdh e_sec re)
  let !(ck2, temp_k2) = mix_key ck1 ee
  -- c = encryptWithAD(temp_k2, 0, h, ""), a bare 16-byte MAC
  c2 <- note InvalidMAC
    (encrypt_with_ad temp_k2 0 h3 BS.empty)
  -- h = SHA256(h || c); message = 0x00 || e.pub (33) || c (16).
  -- temp_k2 and re are kept for 'finalize'.
  let !h4 = mix_hash h3 c2
      !msg = BS.singleton 0x00 <> e_pub_bytes <> c2
      !hs1 = hs0 {
          hs_h = h4
        , hs_ck = ck2
        , hs_temp_k = temp_k2
        , hs_re = Just re
        }
  pure (msg, HandshakeFor hs1)

-- | Initiator: process Act 2 and generate Act 3 (66 bytes),
-- completing the handshake.
--
-- Returns the 66-byte Act 3 message and the handshake
-- result.
--
-- The returned 'Session' starts with both nonces at 0 and uses
-- the final chaining key for key rotation in both directions.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
-- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
-- >>> case act3 i_hs msg2 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
-- 66
act3
  :: HandshakeFor Initiator
  -- ^ state returned by 'act1'
  -> BS.ByteString
  -- ^ Act 2 message (50 bytes)
  -> Either Error (BS.ByteString, Handshake)
act3 (HandshakeFor hs) msg2 = do
  require (BS.length msg2 == 50) InvalidLength
  -- Act 2 layout: version (1) || e.pub (33) || c (16)
  let !version = BS.index msg2 0
      !re_bytes = BS.take 33 (BS.drop 1 msg2)
      !c = BS.drop 34 msg2
  require (version == 0x00) InvalidVersion
  re <- note InvalidPub (parse_pub re_bytes)
  -- h = SHA256(h || re)
  let !h1 = mix_hash (hs_h hs) re_bytes
  -- ee = ECDH(e.priv, re); (ck, temp_k2) = HKDF(ck, ee)
  ee <- note InvalidKey (ecdh (hs_e_sec hs) re)
  let !(ck1, temp_k2) = mix_key (hs_ck hs) ee
  -- authenticate c; its plaintext is empty and discarded
  _ <- note InvalidMAC
    (decrypt_with_ad temp_k2 0 h1 c)
  -- h = SHA256(h || c); c3 = encryptWithAD(temp_k2, 1, h, s.pub)
  let !h2 = mix_hash h1 c
      !s_pub_bytes = serialize_pub (hs_s_pub hs)
  c3 <- note InvalidMAC
    (encrypt_with_ad temp_k2 1 h2 s_pub_bytes)
  -- h = SHA256(h || c3)
  let !h3 = mix_hash h2 c3
  -- se = ECDH(s.priv, re); (ck, temp_k3) = HKDF(ck, se)
  se <- note InvalidKey (ecdh (hs_s_sec hs) re)
  let !(ck2, temp_k3) = mix_key ck1 se
  -- t = encryptWithAD(temp_k3, 0, h, ""), a bare 16-byte MAC
  t <- note InvalidMAC
    (encrypt_with_ad temp_k3 0 h3 BS.empty)
  -- (sk, rk) = HKDF(ck, ""); message = 0x00 || c3 (49) || t (16)
  let !(sk, rk) = mix_key ck2 BS.empty
      !msg = BS.singleton 0x00 <> c3 <> t
      !sess = Session {
          sess_sk = Key32 sk
        , sess_sn = SessionNonce 0
        , sess_sck = Key32 ck2
        , sess_rk = Key32 rk
        , sess_rn = SessionNonce 0
        , sess_rck = Key32 ck2
        }
  -- always 'Just' for state built by 'act1', which passes
  -- 'Just rs' to 'init_handshake'
  rs <- note InvalidPub (hs_rs hs)
  let !result = Handshake {
          session = sess
        , remote_static = rs
        }
  pure (msg, result)

-- | Responder: process Act 3 (66 bytes) and complete the
-- handshake.
--
-- Returns the handshake result with authenticated remote
-- static pubkey.
--
-- The session keys come out of the final HKDF in the opposite
-- order from 'act3', so this side's receive key is the
-- initiator's send key and vice versa.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
-- >>> let Right (msg3, _) = act3 i_hs msg2
-- >>> case finalize r_hs msg3 of { Right _ -> "ok"; Left e -> show e }
-- "ok"
finalize
  :: HandshakeFor Responder
  -- ^ state returned by 'act2'
  -> BS.ByteString
  -- ^ Act 3 message (66 bytes)
  -> Either Error Handshake
finalize (HandshakeFor hs) msg3 = do
  require (BS.length msg3 == 66) InvalidLength
  -- Act 3 layout: version (1) || c (49) || t (16)
  let !version = BS.index msg3 0
      !c = BS.take 49 (BS.drop 1 msg3)
      !t = BS.drop 50 msg3
  require (version == 0x00) InvalidVersion
  -- rs = decryptWithAD(temp_k2, 1, h, c): the initiator's static
  -- pubkey (hs_temp_k holds temp_k2 from 'act2')
  rs_bytes <- note InvalidMAC
    (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c)
  rs <- note InvalidPub (parse_pub rs_bytes)
  -- h = SHA256(h || c)
  let !h1 = mix_hash (hs_h hs) c
  -- se = ECDH(e.priv, rs); (ck, temp_k3) = HKDF(ck, se)
  se <- note InvalidKey (ecdh (hs_e_sec hs) rs)
  let !(ck1, temp_k3) = mix_key (hs_ck hs) se
  -- authenticate t; its plaintext is empty and discarded
  _ <- note InvalidMAC
    (decrypt_with_ad temp_k3 0 h1 t)
  -- responder swaps order (receives what initiator sends)
  let !(rk, sk) = mix_key ck1 BS.empty
      !sess = Session {
          sess_sk = Key32 sk
        , sess_sn = SessionNonce 0
        , sess_sck = Key32 ck1
        , sess_rk = Key32 rk
        , sess_rn = SessionNonce 0
        , sess_rck = Key32 ck1
        }
      !result = Handshake {
          session = sess
        , remote_static = rs
        }
  pure result

-- message encryption ----------------------------------------------

-- | Encrypt a message (max 65535 bytes).
--
-- Returns the encrypted packet and updated session. Key
-- rotation is handled automatically: a key is rotated (and its
-- nonce reset to 0) once its nonce would reach 1000.
--
-- Each call consumes two send nonces: one for the length header
-- and one for the body. Plaintexts longer than 65535 bytes are
-- rejected with 'InvalidLength'.
--
-- Wire format:
--
-- > encrypted_length (2) || MAC (16) || encrypted_body || MAC (16)
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
-- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
-- >>> let Right (_, i_result) = act3 i_hs msg2
-- >>> let sess = session i_result
-- >>> case encrypt sess "hello" of { Right (ct, _) -> BS.length ct; Left _ -> 0 }
-- 39
encrypt
  :: Session
  -> BS.ByteString
  -> Either Error (BS.ByteString, Session)
encrypt sess pt
  | plen > 65535 = Left InvalidLength
  | otherwise = do
      let !key0   = unKey32 (sess_sk sess)
          !nonce0 = unSessionNonce (sess_sn sess)
          !chain0 = unKey32 (sess_sck sess)
      -- header: 2-byte big-endian length, sealed under nonce0
      hdr_ct <- note InvalidMAC
        (encrypt_with_ad key0 nonce0 BS.empty (encode_be16 (fi plen)))
      let !(nonce1, chain1, key1) = step_nonce nonce0 chain0 key0
      -- body: sealed under the next (possibly rotated) nonce/key
      body_ct <- note InvalidMAC
        (encrypt_with_ad key1 nonce1 BS.empty pt)
      let !(nonce2, chain2, key2) = step_nonce nonce1 chain1 key1
          !packet = hdr_ct <> body_ct
          !sess'  = sess {
              sess_sk  = Key32 key2
            , sess_sn  = SessionNonce nonce2
            , sess_sck = Key32 chain2
            }
      pure (packet, sess')
  where
    !plen = BS.length pt

-- | Decrypt a message, requiring an exact packet with no
-- trailing bytes.
--
-- Returns the plaintext and updated session. Key rotation
-- is handled automatically, as in 'encrypt'.
--
-- This is a strict variant that rejects any trailing data.
-- For streaming use cases where you need to handle multiple
-- frames in a buffer, use 'decrypt_frame' instead.
--
-- Trailing bytes after the frame are reported as
-- 'InvalidLength'.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
-- >>> let Right (msg3, i_result) = act3 i_hs msg2
-- >>> let Right r_result = finalize r_hs msg3
-- >>> let Right (ct, _) = encrypt (session i_result) "hello"
-- >>> case decrypt (session r_result) ct of { Right (pt, _) -> pt; Left _ -> "fail" }
-- "hello"
decrypt
  :: Session
  -> BS.ByteString
  -> Either Error (BS.ByteString, Session)
decrypt sess packet = do
  (pt, remainder, sess') <- decrypt_frame sess packet
  -- on trailing data the advanced session is discarded and only
  -- the error is returned
  require (BS.null remainder) InvalidLength
  pure (pt, sess')

-- | Decrypt a single frame from a buffer, returning the
-- remainder.
--
-- Returns the plaintext, any unconsumed bytes, and the
-- updated session. Each frame consumes two receive nonces (one
-- for the length header, one for the body), and the receive key
-- is rotated whenever its nonce would reach 1000, i.e. after
-- every 1000 nonces, or every 500 frames.
--
-- This is useful for streaming scenarios where multiple
-- messages may be buffered together. The remainder can be
-- passed to the next call to 'decrypt_frame'.
--
-- Wire format consumed:
--
-- > encrypted_length (18) || encrypted_body (len + 16)
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
-- >>> let Right (msg3, i_result) = act3 i_hs msg2
-- >>> let Right r_result = finalize r_hs msg3
-- >>> let Right (ct, _) = encrypt (session i_result) "hello"
-- >>> case decrypt_frame (session r_result) ct of { Right (pt, rem, _) -> (pt, BS.null rem); Left _ -> ("fail", False) }
-- ("hello",True)
decrypt_frame
  :: Session
  -> BS.ByteString
  -> Either Error
       (BS.ByteString, BS.ByteString, Session)
decrypt_frame sess packet = do
  -- smallest possible frame: 18-byte sealed header plus the
  -- 16-byte MAC of an empty body
  require (BS.length packet >= 34) InvalidLength
  let !(hdr_ct, after_hdr) = BS.splitAt 18 packet
      !key0   = unKey32 (sess_rk sess)
      !nonce0 = unSessionNonce (sess_rn sess)
      !chain0 = unKey32 (sess_rck sess)
  len <- note InvalidMAC (decrypt_with_ad key0 nonce0 BS.empty hdr_ct)
           >>= note InvalidLength . decode_be16
  let !(nonce1, chain1, key1) = step_nonce nonce0 chain0 key0
      !frame_len = fi len + 16
  require (BS.length after_hdr >= frame_len) InvalidLength
  let !(body_ct, leftover) = BS.splitAt frame_len after_hdr
  pt <- note InvalidMAC (decrypt_with_ad key1 nonce1 BS.empty body_ct)
  let !(nonce2, chain2, key2) = step_nonce nonce1 chain1 key1
      !sess' = sess {
          sess_rk  = Key32 key2
        , sess_rn  = SessionNonce nonce2
        , sess_rck = Key32 chain2
        }
  pure (pt, leftover, sess')

-- | Decrypt a frame from a possibly-incomplete buffer,
-- signalling when more data is needed.
--
-- Unlike 'decrypt_frame', this function handles incomplete
-- buffers gracefully by returning 'NeedMore' with the
-- number of additional bytes required to make progress.
--
-- * If the buffer has fewer than 18 bytes (encrypted
--   length + MAC), returns @'NeedMore' n@ where @n@ is
--   the bytes still needed.
-- * If the length header is complete but the body is
--   incomplete, returns @'NeedMore' n@ with bytes needed
--   for the full frame.
-- * MAC or decryption failures return 'FrameError'.
-- * A complete, valid frame returns 'FrameOk' with
--   plaintext, remainder, and updated session.
--
-- Only 'FrameOk' carries a new session; after 'NeedMore' the
-- caller still holds its original session and can retry with a
-- longer buffer.
--
-- This is useful for non-blocking I/O where data arrives
-- incrementally.
decrypt_frame_partial
  :: Session
  -> BS.ByteString
  -> FrameResult
decrypt_frame_partial sess buf
  | have < 18 = NeedMore (18 - have)
  | otherwise = either FrameError id $ do
      len_bytes <- note InvalidMAC
        (decrypt_with_ad key0 nonce0 BS.empty hdr_ct)
      len <- note InvalidLength (decode_be16 len_bytes)
      let !want = fi len + 16
          !(nonce1, chain1, key1) = step_nonce nonce0 chain0 key0
      if BS.length after_hdr < want
        then pure (NeedMore (want - BS.length after_hdr))
        else do
          let !(body_ct, leftover) = BS.splitAt want after_hdr
          pt <- note InvalidMAC
            (decrypt_with_ad key1 nonce1 BS.empty body_ct)
          let !(nonce2, chain2, key2) =
                step_nonce nonce1 chain1 key1
              !sess' = sess {
                  sess_rk  = Key32 key2
                , sess_rn  = SessionNonce nonce2
                , sess_rck = Key32 chain2
                }
          pure (FrameOk pt leftover sess')
  where
    !have = BS.length buf
    (hdr_ct, after_hdr) = BS.splitAt 18 buf
    key0   = unKey32 (sess_rk sess)
    nonce0 = unSessionNonce (sess_rn sess)
    chain0 = unKey32 (sess_rck sess)

-- key rotation ----------------------------------------------------

-- Advance a nonce by one. When the incremented nonce would equal
-- 1000 the key is rotated instead:
--
--   (ck', k') = HKDF(ck, k), nonce reset to 0
--
-- Returns (nonce', chaining_key', key').
step_nonce
  :: Word64
  -> BS.ByteString
  -> BS.ByteString
  -> (Word64, BS.ByteString, BS.ByteString)
step_nonce n ck k =
  let next = n + 1
  in if next == 1000
       then let !(ck', k') = mix_key ck k in (0, ck', k')
       else (next, ck, k)

-- utilities -------------------------------------------------------

-- Lift Maybe to Either, using the given error for 'Nothing'.
note :: e -> Maybe a -> Either e a
note e Nothing = Left e
note _ (Just a) = Right a
{-# INLINE note #-}

-- Fail with the given error unless the condition holds.
require :: Bool -> e -> Either e ()
require cond = unless cond . Left
{-# INLINE require #-}

-- Shorthand for 'fromIntegral'.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}