bolt8

Encrypted and authenticated transport, per BOLT #8 (docs.ppad.tech/bolt8).
git clone git://git.ppad.tech/bolt8.git
Log | Files | Refs | README | LICENSE

Internal.hs (26617B)


      1 {-# OPTIONS_HADDOCK prune #-}
      2 {-# LANGUAGE BangPatterns #-}
      3 {-# LANGUAGE DeriveGeneric #-}
      4 {-# LANGUAGE LambdaCase #-}
      5 {-# LANGUAGE OverloadedStrings #-}
      6 {-# LANGUAGE RecordWildCards #-}
      7 {-# LANGUAGE ViewPatterns #-}
      8 
      9 -- |
     10 -- Module: Lightning.Protocol.BOLT8.Internal
     11 -- Copyright: (c) 2025 Jared Tobin
     12 -- License: MIT
     13 -- Maintainer: Jared Tobin <jared@ppad.tech>
     14 --
     15 -- Internal module exporting all constructors for testing and
     16 -- benchmarking. Prefer "Lightning.Protocol.BOLT8" for general use.
     17 
     18 module Lightning.Protocol.BOLT8.Internal (
     19     -- * Keys
     20     Sec(..)
     21   , Pub(..)
     22   , keypair
     23   , parse_pub
     24   , serialize_pub
     25 
     26     -- * Newtypes
     27   , Key32(..)
     28   , key32
     29   , unsafeKey32
     30   , SessionNonce(..)
     31   , MessagePayload(..)
     32   , mkMessagePayload
     33 
     34     -- * Handshake roles
     35   , Initiator
     36   , Responder
     37   , HandshakeFor(..)
     38 
     39     -- * Handshake (initiator)
     40   , act1
     41   , act3
     42 
     43     -- * Handshake (responder)
     44   , act2
     45   , finalize
     46 
     47     -- * Session
     48   , Session(..)
     49   , HandshakeState(..)
     50   , Handshake(..)
     51   , encrypt
     52   , decrypt
     53   , decrypt_frame
     54   , decrypt_frame_partial
     55   , FrameResult(..)
     56 
     57     -- * Errors
     58   , Error(..)
     59   ) where
     60 
     61 import Control.Monad (guard, unless)
     62 import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
     63 import qualified Crypto.Curve.Secp256k1 as Secp256k1
     64 import qualified Crypto.Hash.SHA256 as SHA256
     65 import qualified Crypto.KDF.HMAC as HKDF
     66 import Data.Bits (unsafeShiftR, xor, (.&.), (.|.))
     67 import qualified Data.ByteString as BS
     68 import qualified Data.ByteString.Unsafe as BU
     69 import Data.Word (Word16, Word64)
     70 import GHC.Generics (Generic)
     71 
     72 -- types -----------------------------------------------------------
     73 
     74 -- | Secret key (32 bytes).
     75 --
     76 --   The 'Eq' instance compares in constant time.
     77 newtype Sec = Sec BS.ByteString
     78   deriving Generic
     79 
     80 instance Eq Sec where
     81   Sec a == Sec b = ct_eq_bs a b
     82 
     83 -- | Compressed public key.
     84 newtype Pub = Pub Secp256k1.Projective
     85 
     86 instance Eq Pub where
     87   (Pub a) == (Pub b) =
     88     Secp256k1.serialize_point a
     89       == Secp256k1.serialize_point b
     90 
     91 instance Show Pub where
     92   show (Pub p) =
     93     "Pub " ++ show (Secp256k1.serialize_point p)
     94 
     95 -- | A 32-byte key, validated at construction.
     96 --
     97 --   The 'Eq' instance compares in constant time.
     98 newtype Key32 = Key32 { unKey32 :: BS.ByteString }
     99   deriving Generic
    100 
    101 instance Eq Key32 where
    102   Key32 a == Key32 b = ct_eq_bs a b
    103 
    104 -- | Constant-time bytestring equality via XOR-accumulate.
    105 --
    106 --   Returns 'False' immediately on length mismatch (length is
    107 --   not considered secret); otherwise inspects every byte
    108 --   regardless of where (or whether) inputs differ.
    109 ct_eq_bs :: BS.ByteString -> BS.ByteString -> Bool
    110 ct_eq_bs a b
    111   | la /= lb  = False
    112   | otherwise = go 0 0
    113   where
    114     !la = BS.length a
    115     !lb = BS.length b
    116     go !i !acc
    117       | i >= la   = acc == 0
    118       | otherwise =
    119           let !x = BU.unsafeIndex a i
    120               !y = BU.unsafeIndex b i
    121           in  go (i + 1) (acc .|. (x `xor` y))
    122 {-# NOINLINE ct_eq_bs #-}
    123 
    124 -- | Construct a 'Key32' from a 32-byte 'BS.ByteString'.
    125 --
    126 --   Returns 'Nothing' if the input is not exactly 32 bytes.
    127 --
    128 --   >>> key32 (BS.replicate 32 0x00)
    129 --   Just (Key32 {unKey32 = ...})
    130 --   >>> key32 (BS.replicate 31 0x00)
    131 --   Nothing
    132 key32 :: BS.ByteString -> Maybe Key32
    133 key32 bs
    134   | BS.length bs == 32 = Just (Key32 bs)
    135   | otherwise = Nothing
    136 
    137 -- | Construct a 'Key32' without validation.
    138 --
    139 --   For test and benchmark use only; prefer 'key32'.
    140 unsafeKey32 :: BS.ByteString -> Key32
    141 unsafeKey32 = Key32
    142 
    143 -- | Session nonce, distinguishing send from receive direction.
    144 newtype SessionNonce =
    145   SessionNonce { unSessionNonce :: Word64 }
    146   deriving (Eq, Generic)
    147 
    148 -- | Message payload (max 65535 bytes), validated at construction.
    149 newtype MessagePayload =
    150   MessagePayload { unMessagePayload :: BS.ByteString }
    151   deriving (Eq, Generic)
    152 
    153 -- | Construct a 'MessagePayload' from a 'BS.ByteString'.
    154 --
    155 --   Returns 'Left' if the payload exceeds 65535 bytes.
    156 mkMessagePayload
    157   :: BS.ByteString -> Either Error MessagePayload
    158 mkMessagePayload bs
    159   | BS.length bs > 65535 = Left InvalidLength
    160   | otherwise = Right (MessagePayload bs)
    161 
    162 -- | Handshake errors.
    163 data Error =
    164     InvalidKey
    165   | InvalidPub
    166   | InvalidMAC
    167   | InvalidVersion
    168   | InvalidLength
    169   | DecryptionFailed
    170   deriving (Eq, Show, Generic)
    171 
    172 -- | Result of attempting to decrypt a frame from a partial
    173 --   buffer.
    174 data FrameResult =
    175     NeedMore {-# UNPACK #-} !Int
    176     -- ^ More bytes needed; the 'Int' is the minimum
    177     --   additional bytes required.
    178   | FrameOk !BS.ByteString !BS.ByteString !Session
    179     -- ^ Successfully decrypted: plaintext, remainder,
    180     --   updated session.
    181   | FrameError !Error
    182     -- ^ Decryption failed with the given error.
    183   deriving Generic
    184 
    185 -- | Post-handshake session state.
    186 data Session = Session {
    187     sess_sk  :: !Key32
    188     -- ^ send key (32 bytes)
    189   , sess_sn  :: !SessionNonce
    190     -- ^ send nonce
    191   , sess_sck :: !Key32
    192     -- ^ send chaining key
    193   , sess_rk  :: !Key32
    194     -- ^ receive key (32 bytes)
    195   , sess_rn  :: !SessionNonce
    196     -- ^ receive nonce
    197   , sess_rck :: !Key32
    198     -- ^ receive chaining key
    199   }
    200   deriving Generic
    201 
    202 -- | Result of a successful handshake.
    203 data Handshake = Handshake {
    204     session       :: !Session
    205     -- ^ session state
    206   , remote_static :: !Pub
    207     -- ^ authenticated remote static pubkey
    208   }
    209   deriving Generic
    210 
    211 -- | Internal handshake state (exported for benchmarking).
    212 data HandshakeState = HandshakeState {
    213     hs_h      :: {-# UNPACK #-} !BS.ByteString
    214     -- ^ handshake hash (32 bytes)
    215   , hs_ck     :: {-# UNPACK #-} !BS.ByteString
    216     -- ^ chaining key (32 bytes)
    217   , hs_temp_k :: {-# UNPACK #-} !BS.ByteString
    218     -- ^ temp key (32 bytes)
    219   , hs_e_sec  :: !Sec
    220     -- ^ ephemeral secret
    221   , hs_e_pub  :: !Pub
    222     -- ^ ephemeral public
    223   , hs_s_sec  :: !Sec
    224     -- ^ static secret
    225   , hs_s_pub  :: !Pub
    226     -- ^ static public
    227   , hs_re     :: !(Maybe Pub)
    228     -- ^ remote ephemeral
    229   , hs_rs     :: !(Maybe Pub)
    230     -- ^ remote static
    231   }
    232   deriving Generic
    233 
    234 -- handshake roles -------------------------------------------------
    235 
    236 -- | Phantom type for initiator role.
    237 data Initiator
    238 
    239 -- | Phantom type for responder role.
    240 data Responder
    241 
    242 -- | Role-indexed handshake state.
    243 --
    244 --   The phantom type parameter prevents passing an initiator's
    245 --   state to a responder function and vice versa.
    246 data HandshakeFor a =
    247   HandshakeFor { unHandshakeFor :: !HandshakeState }
    248 
    249 -- protocol constants ----------------------------------------------
    250 
    251 _PROTOCOL_NAME :: BS.ByteString
    252 _PROTOCOL_NAME =
    253   "Noise_XK_secp256k1_ChaChaPoly_SHA256"
    254 
    255 _PROLOGUE :: BS.ByteString
    256 _PROLOGUE = "lightning"
    257 
    258 -- key operations --------------------------------------------------
    259 
    260 -- | Derive a keypair from 32 bytes of entropy.
    261 --
    262 --   Returns Nothing if the entropy is invalid
    263 --   (zero or >= curve order).
    264 --
    265 --   >>> let ent = BS.replicate 32 0x11
    266 --   >>> case keypair ent of
    267 --   ...   Just _ -> "ok"
    268 --   ...   Nothing -> "fail"
    269 --   "ok"
    270 --   >>> keypair (BS.replicate 31 0x11) -- wrong length
    271 --   Nothing
    272 keypair :: BS.ByteString -> Maybe (Sec, Pub)
    273 keypair ent = do
    274   guard (BS.length ent == 32)
    275   k <- Secp256k1.parse_int256 ent
    276   p <- Secp256k1.derive_pub k
    277   pure (Sec ent, Pub p)
    278 
    279 -- | Parse a 33-byte compressed public key.
    280 --
    281 --   >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
    282 --   >>> let bytes = serialize_pub pub
    283 --   >>> case parse_pub bytes of
    284 --   ...   Just _ -> "ok"
    285 --   ...   Nothing -> "fail"
    286 --   "ok"
    287 --   >>> parse_pub (BS.replicate 32 0x00) -- wrong length
    288 --   Nothing
    289 parse_pub :: BS.ByteString -> Maybe Pub
    290 parse_pub bs = do
    291   guard (BS.length bs == 33)
    292   p <- Secp256k1.parse_point bs
    293   pure (Pub p)
    294 
    295 -- | Serialize a public key to 33-byte compressed form.
    296 --
    297 --   >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
    298 --   >>> BS.length (serialize_pub pub)
    299 --   33
    300 serialize_pub :: Pub -> BS.ByteString
    301 serialize_pub (Pub p) = Secp256k1.serialize_point p
    302 
    303 -- cryptographic primitives ----------------------------------------
    304 
    305 -- bolt8-style ECDH
    306 ecdh :: Sec -> Pub -> Maybe BS.ByteString
    307 ecdh (Sec sec) (Pub pub) = do
    308   k <- Secp256k1.parse_int256 sec
    309   pt <- Secp256k1.mul pub k
    310   let compressed = Secp256k1.serialize_point pt
    311   pure (SHA256.hash compressed)
    312 
    313 -- h' = SHA256(h || data)
    314 mix_hash
    315   :: BS.ByteString -> BS.ByteString -> BS.ByteString
    316 mix_hash h dat = SHA256.hash (h <> dat)
    317 
    318 -- Mix key: (ck', k) = HKDF(ck, input_key_material)
    319 --
    320 -- NB HKDF limits output to 255 * hashlen bytes. For SHA256
    321 -- that's 8160, well above the 64 bytes requested here, so
    322 -- 'Nothing' is impossible.
    323 mix_key
    324   :: BS.ByteString
    325   -> BS.ByteString
    326   -> (BS.ByteString, BS.ByteString)
    327 mix_key ck ikm =
    328   case HKDF.derive hmac ck mempty 64 ikm of
    329     Nothing ->
    330       error
    331         "ppad-bolt8: internal error, please report a bug!"
    332     Just output -> BS.splitAt 32 output
    333   where
    334     hmac k b = case SHA256.hmac k b of
    335       SHA256.MAC mac -> mac
    336 
    337 -- Encrypt with associated data using ChaCha20-Poly1305
    338 encrypt_with_ad
    339   :: BS.ByteString       -- ^ key (32 bytes)
    340   -> Word64              -- ^ nonce
    341   -> BS.ByteString       -- ^ associated data
    342   -> BS.ByteString       -- ^ plaintext
    343   -> Maybe BS.ByteString -- ^ ciphertext || mac (16 bytes)
    344 encrypt_with_ad key n ad pt =
    345   case AEAD.encrypt ad key (encode_nonce n) pt of
    346     Left _ -> Nothing
    347     Right (ct, mac) -> Just (ct <> mac)
    348 
    349 -- Decrypt with associated data using ChaCha20-Poly1305
    350 decrypt_with_ad
    351   :: BS.ByteString       -- ^ key (32 bytes)
    352   -> Word64              -- ^ nonce
    353   -> BS.ByteString       -- ^ associated data
    354   -> BS.ByteString       -- ^ ciphertext || mac
    355   -> Maybe BS.ByteString -- ^ plaintext
    356 decrypt_with_ad key n ad ctmac
    357   | BS.length ctmac < 16 = Nothing
    358   | otherwise =
    359       let (ct, mac) =
    360             BS.splitAt (BS.length ctmac - 16) ctmac
    361       in case AEAD.decrypt ad key (encode_nonce n)
    362                 (ct, mac) of
    363            Left _ -> Nothing
    364            Right pt -> Just pt
    365 
    366 -- Encode nonce as 96-bit value: 4 zero bytes + 8-byte LE
    367 encode_nonce :: Word64 -> BS.ByteString
    368 encode_nonce n = BS.replicate 4 0x00 <> encode_le64 n
    369 
    370 -- Little-endian 64-bit encoding
    371 encode_le64 :: Word64 -> BS.ByteString
    372 encode_le64 n = BS.pack [
    373     fi (n .&. 0xff)
    374   , fi (unsafeShiftR n 8  .&. 0xff)
    375   , fi (unsafeShiftR n 16 .&. 0xff)
    376   , fi (unsafeShiftR n 24 .&. 0xff)
    377   , fi (unsafeShiftR n 32 .&. 0xff)
    378   , fi (unsafeShiftR n 40 .&. 0xff)
    379   , fi (unsafeShiftR n 48 .&. 0xff)
    380   , fi (unsafeShiftR n 56 .&. 0xff)
    381   ]
    382 
    383 -- Big-endian 16-bit encoding
    384 encode_be16 :: Word16 -> BS.ByteString
    385 encode_be16 n =
    386   BS.pack [fi (unsafeShiftR n 8), fi (n .&. 0xff)]
    387 
    388 -- Big-endian 16-bit decoding
    389 decode_be16 :: BS.ByteString -> Maybe Word16
    390 decode_be16 bs
    391   | BS.length bs /= 2 = Nothing
    392   | otherwise =
    393       let !b0 = BS.index bs 0
    394           !b1 = BS.index bs 1
    395       in Just (fi b0 * 0x100 + fi b1)
    396 
    397 -- handshake -------------------------------------------------------
    398 
    399 -- Initialize handshake state
    400 --
    401 -- h = SHA256(protocol_name)
    402 -- ck = h
    403 -- h = SHA256(h || prologue)
    404 -- h = SHA256(h || responder_static_pubkey)
    405 init_handshake
    406   :: Sec           -- ^ local static secret
    407   -> Pub           -- ^ local static public
    408   -> Sec           -- ^ ephemeral secret
    409   -> Pub           -- ^ ephemeral public
    410   -> Maybe Pub     -- ^ remote static
    411   -> Bool          -- ^ True if initiator
    412   -> HandshakeState
    413 init_handshake s_sec s_pub e_sec e_pub m_rs is_init =
    414   let !h0 = SHA256.hash _PROTOCOL_NAME
    415       !ck = h0
    416       !h1 = mix_hash h0 _PROLOGUE
    417       -- Mix in responder's static pubkey
    418       !h2 = case (is_init, m_rs) of
    419         (True, Just rs) ->
    420           mix_hash h1 (serialize_pub rs)
    421         (False, Nothing) ->
    422           mix_hash h1 (serialize_pub s_pub)
    423         _ -> h1  -- shouldn't happen
    424   in HandshakeState {
    425        hs_h      = h2
    426      , hs_ck     = ck
    427      , hs_temp_k = BS.replicate 32 0x00
    428      , hs_e_sec  = e_sec
    429      , hs_e_pub  = e_pub
    430      , hs_s_sec  = s_sec
    431      , hs_s_pub  = s_pub
    432      , hs_re     = Nothing
    433      , hs_rs     = m_rs
    434      }
    435 
-- | Initiator: generate Act 1 message (50 bytes).
--
--   Takes local static key, remote static pubkey, and 32
--   bytes of entropy for ephemeral key generation.
--
--   Returns the 50-byte Act 1 message and handshake state
--   for Act 3.
--
--   Fails with 'InvalidKey' if the entropy does not yield a
--   keypair or the ECDH step fails, and with 'InvalidMAC' if the
--   AEAD encryption step fails.
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let eph_ent = BS.replicate 32 0x12
--   >>> case act1 i_sec i_pub r_pub eph_ent of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
--   50
act1
  :: Sec            -- ^ local static secret key
  -> Pub            -- ^ local static public key
  -> Pub            -- ^ responder's static public key
  -> BS.ByteString  -- ^ 32 bytes of entropy for the ephemeral key
  -> Either Error
       (BS.ByteString, HandshakeFor Initiator)
act1 s_sec s_pub rs ent = do
  -- ephemeral keypair from the caller's entropy
  (e_sec, e_pub) <- note InvalidKey (keypair ent)
  -- h = SHA256(h || e.pub)
  let !hs0 = init_handshake
               s_sec s_pub e_sec e_pub (Just rs) True
      !e_pub_bytes = serialize_pub e_pub
      !h1 = mix_hash (hs_h hs0) e_pub_bytes
  -- es = ECDH(e.priv, rs); (ck, temp_k1) = HKDF(ck, es)
  es <- note InvalidKey (ecdh e_sec rs)
  let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
  -- c = encrypt_with_ad(temp_k1, 0, h, empty): a bare 16-byte tag
  c <- note InvalidMAC
         (encrypt_with_ad temp_k1 0 h1 BS.empty)
  -- h = SHA256(h || c); message = 0x00 || e.pub || c
  let !h2 = mix_hash h1 c
      !msg = BS.singleton 0x00 <> e_pub_bytes <> c
      !hs1 = hs0 {
        hs_h      = h2
      , hs_ck     = ck1
      , hs_temp_k = temp_k1
      }
  pure (msg, HandshakeFor hs1)
    474 
-- | Responder: process Act 1 and generate Act 2 message
--   (50 bytes).
--
--   Takes local static key and 32 bytes of entropy for
--   ephemeral key, plus the 50-byte Act 1 message from
--   initiator.
--
--   Returns the 50-byte Act 2 message and handshake state
--   for finalize.
--
--   Failures, in the order they are checked:
--   'InvalidLength' unless the message is 50 bytes;
--   'InvalidVersion' unless its first byte is 0x00;
--   'InvalidPub' if the initiator's ephemeral key does not parse;
--   'InvalidKey' if the entropy or an ECDH step is rejected;
--   'InvalidMAC' if the Act 1 tag fails to verify or an AEAD
--   encryption step fails.
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let Right (msg1, _) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--   >>> case act2 r_sec r_pub (BS.replicate 32 0x22) msg1 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
--   50
act2
  :: Sec            -- ^ local static secret key
  -> Pub            -- ^ local static public key
  -> BS.ByteString  -- ^ 32 bytes of entropy for the ephemeral key
  -> BS.ByteString  -- ^ Act 1 message (50 bytes)
  -> Either Error
       (BS.ByteString, HandshakeFor Responder)
act2 s_sec s_pub ent msg1 = do
  -- split Act 1: version (1) || initiator ephemeral (33) || tag (16)
  require (BS.length msg1 == 50) InvalidLength
  let !version = BS.index msg1 0
      !re_bytes = BS.take 33 (BS.drop 1 msg1)
      !c = BS.drop 34 msg1
  require (version == 0x00) InvalidVersion
  re <- note InvalidPub (parse_pub re_bytes)
  (e_sec, e_pub) <- note InvalidKey (keypair ent)
  -- h = SHA256(h || re)
  let !hs0 = init_handshake
               s_sec s_pub e_sec e_pub Nothing False
      !h1 = mix_hash (hs_h hs0) re_bytes
  -- es = ECDH(s.priv, re); (ck, temp_k1) = HKDF(ck, es)
  es <- note InvalidKey (ecdh s_sec re)
  let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
  -- authenticate the Act 1 tag; the empty plaintext is discarded
  _ <- note InvalidMAC
         (decrypt_with_ad temp_k1 0 h1 c)
  -- h = SHA256(h || c); then h = SHA256(h || e.pub)
  let !h2 = mix_hash h1 c
      !e_pub_bytes = serialize_pub e_pub
      !h3 = mix_hash h2 e_pub_bytes
  -- ee = ECDH(e.priv, re); (ck, temp_k2) = HKDF(ck, ee)
  ee <- note InvalidKey (ecdh e_sec re)
  let !(ck2, temp_k2) = mix_key ck1 ee
  c2 <- note InvalidMAC
          (encrypt_with_ad temp_k2 0 h3 BS.empty)
  -- h = SHA256(h || c2); message = 0x00 || e.pub || c2
  let !h4 = mix_hash h3 c2
      !msg = BS.singleton 0x00 <> e_pub_bytes <> c2
      !hs1 = hs0 {
        hs_h      = h4
      , hs_ck     = ck2
      , hs_temp_k = temp_k2
      , hs_re     = Just re
      }
  pure (msg, HandshakeFor hs1)
    528 
-- | Initiator: process Act 2 and generate Act 3 (66 bytes),
--   completing the handshake.
--
--   Returns the 66-byte Act 3 message and the handshake
--   result.
--
--   Failures, in the order they are checked:
--   'InvalidLength' unless the message is 50 bytes;
--   'InvalidVersion' unless its first byte is 0x00;
--   'InvalidPub' if the responder's ephemeral key does not parse;
--   'InvalidKey' if an ECDH step fails;
--   'InvalidMAC' if the Act 2 tag fails to verify or an AEAD
--   encryption step fails;
--   'InvalidPub' if the state carries no remote static key.
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--   >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--   >>> case act3 i_hs msg2 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
--   66
act3
  :: HandshakeFor Initiator  -- ^ state returned by 'act1'
  -> BS.ByteString           -- ^ Act 2 message (50 bytes)
  -> Either Error (BS.ByteString, Handshake)
act3 (HandshakeFor hs) msg2 = do
  -- split Act 2: version (1) || responder ephemeral (33) || tag (16)
  require (BS.length msg2 == 50) InvalidLength
  let !version = BS.index msg2 0
      !re_bytes = BS.take 33 (BS.drop 1 msg2)
      !c = BS.drop 34 msg2
  require (version == 0x00) InvalidVersion
  re <- note InvalidPub (parse_pub re_bytes)
  -- h = SHA256(h || re)
  let !h1 = mix_hash (hs_h hs) re_bytes
  -- ee = ECDH(e.priv, re); (ck, temp_k2) = HKDF(ck, ee)
  ee <- note InvalidKey (ecdh (hs_e_sec hs) re)
  let !(ck1, temp_k2) = mix_key (hs_ck hs) ee
  -- authenticate the Act 2 tag; the empty plaintext is discarded
  _ <- note InvalidMAC
         (decrypt_with_ad temp_k2 0 h1 c)
  let !h2 = mix_hash h1 c
      !s_pub_bytes = serialize_pub (hs_s_pub hs)
  -- encrypt our static key under temp_k2 with nonce 1 (nonce 0 was
  -- used for the Act 2 tag above)
  c3 <- note InvalidMAC
          (encrypt_with_ad temp_k2 1 h2 s_pub_bytes)
  let !h3 = mix_hash h2 c3
  -- se = ECDH(s.priv, re); (ck, temp_k3) = HKDF(ck, se)
  se <- note InvalidKey (ecdh (hs_s_sec hs) re)
  let !(ck2, temp_k3) = mix_key ck1 se
  t <- note InvalidMAC
         (encrypt_with_ad temp_k3 0 h3 BS.empty)
  -- (sk, rk) = HKDF(ck, empty); both chaining keys start at ck and
  -- both nonces at 0. message = 0x00 || c3 || t
  let !(sk, rk) = mix_key ck2 BS.empty
      !msg = BS.singleton 0x00 <> c3 <> t
      !sess = Session {
        sess_sk  = Key32 sk
      , sess_sn  = SessionNonce 0
      , sess_sck = Key32 ck2
      , sess_rk  = Key32 rk
      , sess_rn  = SessionNonce 0
      , sess_rck = Key32 ck2
      }
  rs <- note InvalidPub (hs_rs hs)
  let !result = Handshake {
        session       = sess
      , remote_static = rs
      }
  pure (msg, result)
    582 
-- | Responder: process Act 3 (66 bytes) and complete the
--   handshake.
--
--   Returns the handshake result with authenticated remote
--   static pubkey.
--
--   Failures, in the order they are checked:
--   'InvalidLength' unless the message is 66 bytes;
--   'InvalidVersion' unless its first byte is 0x00;
--   'InvalidMAC' if the encrypted static key fails to decrypt;
--   'InvalidPub' if the decrypted static key does not parse;
--   'InvalidKey' if the ECDH step fails;
--   'InvalidMAC' if the final tag fails to verify.
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--   >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--   >>> let Right (msg3, _) = act3 i_hs msg2
--   >>> case finalize r_hs msg3 of { Right _ -> "ok"; Left e -> show e }
--   "ok"
finalize
  :: HandshakeFor Responder  -- ^ state returned by 'act2'
  -> BS.ByteString           -- ^ Act 3 message (66 bytes)
  -> Either Error Handshake
finalize (HandshakeFor hs) msg3 = do
  -- split Act 3: version (1) || encrypted static key (33 + 16)
  -- || tag (16)
  require (BS.length msg3 == 66) InvalidLength
  let !version = BS.index msg3 0
      !c = BS.take 49 (BS.drop 1 msg3)
      !t = BS.drop 50 msg3
  require (version == 0x00) InvalidVersion
  -- recover the initiator's static key: temp_k2 with nonce 1
  rs_bytes <- note InvalidMAC
    (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c)
  rs <- note InvalidPub (parse_pub rs_bytes)
  -- h = SHA256(h || c)
  let !h1 = mix_hash (hs_h hs) c
  -- se = ECDH(e.priv, rs); (ck, temp_k3) = HKDF(ck, se)
  se <- note InvalidKey (ecdh (hs_e_sec hs) rs)
  let !(ck1, temp_k3) = mix_key (hs_ck hs) se
  -- authenticate the final tag; the empty plaintext is discarded
  _ <- note InvalidMAC
         (decrypt_with_ad temp_k3 0 h1 t)
  -- responder swaps order (receives what initiator sends)
  let !(rk, sk) = mix_key ck1 BS.empty
      !sess = Session {
        sess_sk  = Key32 sk
      , sess_sn  = SessionNonce 0
      , sess_sck = Key32 ck1
      , sess_rk  = Key32 rk
      , sess_rn  = SessionNonce 0
      , sess_rck = Key32 ck1
      }
      !result = Handshake {
        session       = sess
      , remote_static = rs
      }
  pure result
    629 
    630 -- message encryption ----------------------------------------------
    631 
    632 -- | Encrypt a message (max 65535 bytes).
    633 --
    634 --   Returns the encrypted packet and updated session. Key
    635 --   rotation is handled automatically at nonce 1000.
    636 --
    637 --   Wire format:
    638 --     encrypted_length (2) || MAC (16)
    639 --     || encrypted_body || MAC (16)
    640 --
    641 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    642 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    643 --   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
    644 --   >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
    645 --   >>> let Right (_, i_result) = act3 i_hs msg2
    646 --   >>> let sess = session i_result
    647 --   >>> case encrypt sess "hello" of { Right (ct, _) -> BS.length ct; Left _ -> 0 }
    648 --   39
    649 encrypt
    650   :: Session
    651   -> BS.ByteString
    652   -> Either Error (BS.ByteString, Session)
    653 encrypt sess pt = do
    654   let !len = BS.length pt
    655   require (len <= 65535) InvalidLength
    656   let !len_bytes = encode_be16 (fi len)
    657       !sk = unKey32 (sess_sk sess)
    658       !sn = unSessionNonce (sess_sn sess)
    659       !sck = unKey32 (sess_sck sess)
    660   lc <- note InvalidMAC
    661           (encrypt_with_ad sk sn BS.empty len_bytes)
    662   let !(sn1, sck1, sk1) = step_nonce sn sck sk
    663   bc <- note InvalidMAC
    664           (encrypt_with_ad sk1 sn1 BS.empty pt)
    665   let !(sn2, sck2, sk2) = step_nonce sn1 sck1 sk1
    666       !packet = lc <> bc
    667       !sess' = sess {
    668         sess_sk  = Key32 sk2
    669       , sess_sn  = SessionNonce sn2
    670       , sess_sck = Key32 sck2
    671       }
    672   pure (packet, sess')
    673 
    674 -- | Decrypt a message, requiring an exact packet with no
    675 --   trailing bytes.
    676 --
    677 --   Returns the plaintext and updated session. Key rotation
    678 --   is handled automatically at nonce 1000.
    679 --
    680 --   This is a strict variant that rejects any trailing data.
    681 --   For streaming use cases where you need to handle multiple
    682 --   frames in a buffer, use 'decrypt_frame' instead.
    683 --
    684 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    685 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    686 --   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
    687 --   >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
    688 --   >>> let Right (msg3, i_result) = act3 i_hs msg2
    689 --   >>> let Right r_result = finalize r_hs msg3
    690 --   >>> let Right (ct, _) = encrypt (session i_result) "hello"
    691 --   >>> case decrypt (session r_result) ct of { Right (pt, _) -> pt; Left _ -> "fail" }
    692 --   "hello"
    693 decrypt
    694   :: Session
    695   -> BS.ByteString
    696   -> Either Error (BS.ByteString, Session)
    697 decrypt sess packet = do
    698   (pt, remainder, sess') <- decrypt_frame sess packet
    699   require (BS.null remainder) InvalidLength
    700   pure (pt, sess')
    701 
    702 -- | Decrypt a single frame from a buffer, returning the
    703 --   remainder.
    704 --
    705 --   Returns the plaintext, any unconsumed bytes, and the
    706 --   updated session. Key rotation is handled automatically
    707 --   every 1000 messages.
    708 --
    709 --   This is useful for streaming scenarios where multiple
    710 --   messages may be buffered together. The remainder can be
    711 --   passed to the next call to 'decrypt_frame'.
    712 --
    713 --   Wire format consumed:
    714 --     encrypted_length (18) || encrypted_body (len + 16)
    715 --
    716 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    717 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    718 --   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
    719 --   >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
    720 --   >>> let Right (msg3, i_result) = act3 i_hs msg2
    721 --   >>> let Right r_result = finalize r_hs msg3
    722 --   >>> let Right (ct, _) = encrypt (session i_result) "hello"
    723 --   >>> case decrypt_frame (session r_result) ct of { Right (pt, rem, _) -> (pt, BS.null rem); Left _ -> ("fail", False) }
    724 --   ("hello",True)
    725 decrypt_frame
    726   :: Session
    727   -> BS.ByteString
    728   -> Either Error
    729        (BS.ByteString, BS.ByteString, Session)
    730 decrypt_frame sess packet = do
    731   require (BS.length packet >= 34) InvalidLength
    732   let !lc = BS.take 18 packet
    733       !rest = BS.drop 18 packet
    734       !rk = unKey32 (sess_rk sess)
    735       !rn = unSessionNonce (sess_rn sess)
    736       !rck = unKey32 (sess_rck sess)
    737   len_bytes <- note InvalidMAC
    738     (decrypt_with_ad rk rn BS.empty lc)
    739   len <- note InvalidLength (decode_be16 len_bytes)
    740   let !(rn1, rck1, rk1) = step_nonce rn rck rk
    741       !body_len = fi len + 16
    742   require (BS.length rest >= body_len) InvalidLength
    743   let !bc = BS.take body_len rest
    744       !remainder = BS.drop body_len rest
    745   pt <- note InvalidMAC
    746           (decrypt_with_ad rk1 rn1 BS.empty bc)
    747   let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
    748       !sess' = sess {
    749         sess_rk  = Key32 rk2
    750       , sess_rn  = SessionNonce rn2
    751       , sess_rck = Key32 rck2
    752       }
    753   pure (pt, remainder, sess')
    754 
-- | Decrypt a frame from a partial buffer, indicating when
--   more data is needed.
--
--   Unlike 'decrypt_frame', this function handles incomplete
--   buffers gracefully by returning 'NeedMore' with the
--   number of additional bytes required to make progress.
--
--   * If the buffer has fewer than 18 bytes (encrypted
--     length + MAC), returns @'NeedMore' n@ where @n@ is
--     the bytes still needed.
--   * If the length header is complete but the body is
--     incomplete, returns @'NeedMore' n@ with bytes needed
--     for the full frame.
--   * MAC or decryption failures return 'FrameError'.
--   * A complete, valid frame returns 'FrameOk' with
--     plaintext, remainder, and updated session.
--
--   The input session is never modified. A 'NeedMore' result
--   carries no session, so the caller should retry with the same
--   session once more bytes have arrived. The length header is
--   decrypted again on each attempt.
--
--   This is useful for non-blocking I/O where data arrives
--   incrementally.
decrypt_frame_partial
  :: Session
  -> BS.ByteString
  -> FrameResult
decrypt_frame_partial sess buf
  | buflen < 18 = NeedMore (18 - buflen)
  | otherwise =
      -- header: 2-byte sealed length + 16-byte tag
      let !lc = BS.take 18 buf
          !rest = BS.drop 18 buf
          !rk = unKey32 (sess_rk sess)
          !rn = unSessionNonce (sess_rn sess)
          !rck = unKey32 (sess_rck sess)
      in case decrypt_with_ad rk rn BS.empty lc of
           Nothing -> FrameError InvalidMAC
           Just len_bytes ->
             case decode_be16 len_bytes of
               Nothing -> FrameError InvalidLength
               Just len ->
                 -- body is len plaintext bytes plus a 16-byte tag
                 let !body_len = fi len + 16
                     !(rn1, rck1, rk1) =
                       step_nonce rn rck rk
                 in if BS.length rest < body_len
                   then NeedMore
                     (body_len - BS.length rest)
                   else
                     let !bc = BS.take body_len rest
                         !remainder =
                           BS.drop body_len rest
                     in case decrypt_with_ad
                              rk1 rn1 BS.empty bc of
                       Nothing ->
                         FrameError InvalidMAC
                       Just pt ->
                         -- second nonce step: the body is consumed
                         let !(rn2, rck2, rk2) =
                               step_nonce rn1 rck1 rk1
                             !sess' = sess {
                               sess_rk  = Key32 rk2
                             , sess_rn  =
                                 SessionNonce rn2
                             , sess_rck = Key32 rck2
                             }
                         in FrameOk pt remainder sess'
  where
    !buflen = BS.length buf
    818 
    819 -- key rotation ----------------------------------------------------
    820 
    821 -- Key rotation occurs after nonce reaches 1000 (i.e., before
    822 -- using 1000)
    823 -- (ck', k') = HKDF(ck, k), reset nonce to 0
    824 step_nonce
    825   :: Word64
    826   -> BS.ByteString
    827   -> BS.ByteString
    828   -> (Word64, BS.ByteString, BS.ByteString)
    829 step_nonce n ck k
    830   | n + 1 == 1000 =
    831       let !(ck', k') = mix_key ck k
    832       in (0, ck', k')
    833   | otherwise = (n + 1, ck, k)
    834 
    835 -- utilities -------------------------------------------------------
    836 
    837 -- Lift Maybe to Either
    838 note :: e -> Maybe a -> Either e a
    839 note e = maybe (Left e) Right
    840 {-# INLINE note #-}
    841 
    842 -- Require condition or fail
    843 require :: Bool -> e -> Either e ()
    844 require cond e = unless cond (Left e)
    845 {-# INLINE require #-}
    846 
    847 fi :: (Integral a, Num b) => a -> b
    848 fi = fromIntegral
    849 {-# INLINE fi #-}