bolt8

Encrypted and authenticated transport, per BOLT #8 (docs.ppad.tech/bolt8).
git clone git://git.ppad.tech/bolt8.git
Log | Files | Refs | README | LICENSE

BOLT8.hs (25550B)


      1 {-# OPTIONS_HADDOCK prune #-}
      2 {-# LANGUAGE BangPatterns #-}
      3 {-# LANGUAGE DeriveGeneric #-}
      4 {-# LANGUAGE LambdaCase #-}
      5 {-# LANGUAGE OverloadedStrings #-}
      6 {-# LANGUAGE RecordWildCards #-}
      7 {-# LANGUAGE ViewPatterns #-}
      8 
      9 -- |
     10 -- Module: Lightning.Protocol.BOLT8
     11 -- Copyright: (c) 2025 Jared Tobin
     12 -- License: MIT
     13 -- Maintainer: Jared Tobin <jared@ppad.tech>
     14 --
     15 -- Encrypted and authenticated transport for the Lightning Network, per
     16 -- [BOLT #8](https://github.com/lightning/bolts/blob/master/08-transport.md).
     17 --
     18 -- This module implements the Noise_XK_secp256k1_ChaChaPoly_SHA256
     19 -- handshake and subsequent encrypted message transport.
     20 --
     21 -- = Handshake
     22 --
     23 -- A BOLT #8 handshake consists of three acts. The /initiator/ knows the
     24 -- responder's static public key in advance and initiates the connection:
     25 --
     26 -- @
     27 -- (msg1, state) <- act1 i_sec i_pub r_pub entropy
     28 -- -- send msg1 (50 bytes) to responder
     29 -- -- receive msg2 (50 bytes) from responder
     30 -- (msg3, result) <- act3 state msg2
     31 -- -- send msg3 (66 bytes) to responder
     32 -- let session = 'session' result
     33 -- @
     34 --
     35 -- The /responder/ receives the connection and authenticates the initiator:
     36 --
     37 -- @
     38 -- -- receive msg1 (50 bytes) from initiator
     39 -- (msg2, state) <- act2 r_sec r_pub entropy msg1
     40 -- -- send msg2 (50 bytes) to initiator
     41 -- -- receive msg3 (66 bytes) from initiator
     42 -- result <- finalize state msg3
     43 -- let session = 'session' result
     44 -- @
     45 --
     46 -- = Message Transport
     47 --
     48 -- After a successful handshake, use 'encrypt' and 'decrypt' to exchange
-- messages. Each returns an updated 'Session' that must be used for the
-- next operation (each key rotates after 1000 uses, i.e. every 500 messages):
     51 --
     52 -- @
     53 -- -- sender
     54 -- (ciphertext, session') <- 'encrypt' session plaintext
     55 --
     56 -- -- receiver
     57 -- (plaintext, session') <- 'decrypt' session ciphertext
     58 -- @
     59 --
     60 -- = Message Framing
     61 --
-- BOLT #8 runs over a byte stream, so callers often need to deal with
-- partial buffers. Use 'decrypt_frame' when the buffer holds at least one
-- complete frame (any trailing bytes are returned as a remainder), or
-- 'decrypt_frame_partial' to handle incremental reads and learn how many
-- more bytes are still needed.
     66 --
     67 -- Maximum plaintext size is 65535 bytes.
     68 
     69 module Lightning.Protocol.BOLT8 (
     70     -- * Keys
     71     Sec
     72   , Pub
     73   , keypair
     74   , parse_pub
     75   , serialize_pub
     76 
     77     -- * Handshake (initiator)
     78   , act1
     79   , act3
     80 
     81     -- * Handshake (responder)
     82   , act2
     83   , finalize
     84 
     85     -- * Session
     86   , Session
     87   , HandshakeState
     88   , Handshake(..)
     89   , encrypt
     90   , decrypt
     91   , decrypt_frame
     92   , decrypt_frame_partial
     93   , FrameResult(..)
     94 
     95     -- * Errors
     96   , Error(..)
     97   ) where
     98 
     99 import Control.Monad (guard, unless)
    100 import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
    101 import qualified Crypto.Curve.Secp256k1 as Secp256k1
    102 import qualified Crypto.Hash.SHA256 as SHA256
    103 import qualified Crypto.KDF.HMAC as HKDF
    104 import Data.Bits (unsafeShiftR, (.&.))
    105 import qualified Data.ByteString as BS
    106 import Data.Word (Word16, Word64)
    107 import GHC.Generics (Generic)
    108 
    109 -- types ---------------------------------------------------------------------
    110 
-- | Secret key (32 bytes).
--
--   Values built by 'keypair' hold the 32 entropy bytes verbatim. No
--   'Show' instance is derived, which keeps secret material out of debug
--   output.
--
--   NB the derived 'Eq' compares the underlying bytes directly and is
--   presumably not constant-time; avoid relying on it where timing could
--   leak information.
newtype Sec = Sec BS.ByteString
  deriving (Eq, Generic)
    114 
    115 -- | Compressed public key.
    116 newtype Pub = Pub Secp256k1.Projective
    117 
    118 instance Eq Pub where
    119   (Pub a) == (Pub b) =
    120     Secp256k1.serialize_point a == Secp256k1.serialize_point b
    121 
    122 instance Show Pub where
    123   show (Pub p) = "Pub " ++ show (Secp256k1.serialize_point p)
    124 
-- | Errors returned by the handshake and message-transport functions.
data Error =
    InvalidKey
    -- ^ Entropy did not yield a valid secret key, or an ECDH step failed.
  | InvalidPub
    -- ^ A public key could not be parsed (or, in 'act3', the handshake
    --   state carries no remote static key).
  | InvalidMAC
    -- ^ AEAD authentication failed; also reported if an AEAD encryption
    --   step fails.
  | InvalidVersion
    -- ^ A handshake message had a version byte other than @0x00@.
  | InvalidLength
    -- ^ A handshake message or frame had the wrong length, a plaintext
    --   exceeded 65535 bytes, or 'decrypt' was given trailing bytes.
  | DecryptionFailed
    -- ^ Not currently returned by any function in this module.
  deriving (Eq, Show, Generic)
    134 
-- | Result of attempting to decrypt a frame from a partial buffer, as
--   returned by 'decrypt_frame_partial'.
data FrameResult =
    NeedMore {-# UNPACK #-} !Int
    -- ^ More bytes needed; the 'Int' is the minimum additional bytes required.
    --   While the 18-byte header is incomplete this counts bytes up to the
    --   end of the header; afterwards it counts bytes up to the end of the
    --   whole frame. No session is returned: retry with the same 'Session'.
  | FrameOk !BS.ByteString !BS.ByteString !Session
    -- ^ Successfully decrypted: plaintext, remainder, updated session.
    --   The updated session must be used for the next frame.
  | FrameError !Error
    -- ^ Decryption failed with the given error.
  deriving Generic
    144 
-- | Post-handshake session state.
--
--   Immutable: 'encrypt', 'decrypt' and the frame variants return a new
--   'Session'. Encrypting again with an old 'Session' would reuse a
--   (key, nonce) pair, so always continue with the returned one.
--
--   After the handshake both chaining keys equal the final handshake
--   chaining key and both nonces are 0. Each direction then rotates its
--   own key independently whenever its nonce reaches 1000.
data Session = Session {
    sess_sk  :: {-# UNPACK #-} !BS.ByteString  -- ^ send key (32 bytes)
  , sess_sn  :: {-# UNPACK #-} !Word64         -- ^ next send nonce (0..999)
  , sess_sck :: {-# UNPACK #-} !BS.ByteString  -- ^ send chaining key
  , sess_rk  :: {-# UNPACK #-} !BS.ByteString  -- ^ receive key (32 bytes)
  , sess_rn  :: {-# UNPACK #-} !Word64         -- ^ next receive nonce (0..999)
  , sess_rck :: {-# UNPACK #-} !BS.ByteString  -- ^ receive chaining key
  }
  deriving Generic
    155 
-- | Result of a successful handshake, as returned by 'act3' (initiator)
--   and 'finalize' (responder).
--
--   For the initiator, 'remote_static' is the responder key passed to
--   'act1'. For the responder, it is the initiator key decrypted and
--   authenticated from Act 3.
data Handshake = Handshake {
    session       :: !Session  -- ^ session state (both nonces start at 0)
  , remote_static :: !Pub      -- ^ authenticated remote static pubkey
  }
  deriving Generic
    162 
-- | Internal handshake state (exported for benchmarking).
--
--   Produced by 'act1' (initiator) and 'act2' (responder), and consumed by
--   'act3' and 'finalize' respectively. Only the type is exported, not its
--   constructor or fields.
data HandshakeState = HandshakeState {
    hs_h      :: {-# UNPACK #-} !BS.ByteString  -- handshake hash (32 bytes)
  , hs_ck     :: {-# UNPACK #-} !BS.ByteString  -- chaining key (32 bytes)
  , hs_temp_k :: {-# UNPACK #-} !BS.ByteString  -- temp key from the latest HKDF (32 zero bytes initially)
  , hs_e_sec  :: !Sec                           -- ephemeral secret
  , hs_e_pub  :: !Pub                           -- ephemeral public
  , hs_s_sec  :: !Sec                           -- static secret
  , hs_s_pub  :: !Pub                           -- static public
  , hs_re     :: !(Maybe Pub)                   -- remote ephemeral (set by act2; act1 leaves Nothing)
  , hs_rs     :: !(Maybe Pub)                   -- remote static (Just on the initiator, Nothing on the responder)
  }
  deriving Generic
    176 
    177 -- protocol constants --------------------------------------------------------
    178 
-- Noise protocol name. Its SHA256 digest seeds both the initial handshake
-- hash and the initial chaining key.
_PROTOCOL_NAME :: BS.ByteString
_PROTOCOL_NAME = "Noise_XK_secp256k1_ChaChaPoly_SHA256"

-- Prologue mixed into the handshake hash right after the protocol name.
_PROLOGUE :: BS.ByteString
_PROLOGUE = "lightning"
    184 
    185 -- key operations ------------------------------------------------------------
    186 
    187 -- | Derive a keypair from 32 bytes of entropy.
    188 --
    189 --   Returns Nothing if the entropy is invalid (zero or >= curve order).
    190 --
    191 --   >>> let ent = BS.replicate 32 0x11
    192 --   >>> case keypair ent of { Just _ -> "ok"; Nothing -> "fail" }
    193 --   "ok"
    194 --   >>> keypair (BS.replicate 31 0x11) -- wrong length
    195 --   Nothing
    196 keypair :: BS.ByteString -> Maybe (Sec, Pub)
    197 keypair ent = do
    198   guard (BS.length ent == 32)
    199   k <- Secp256k1.parse_int256 ent
    200   p <- Secp256k1.derive_pub k
    201   pure (Sec ent, Pub p)
    202 
    203 -- | Parse a 33-byte compressed public key.
    204 --
    205 --   >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
    206 --   >>> let bytes = serialize_pub pub
    207 --   >>> case parse_pub bytes of { Just _ -> "ok"; Nothing -> "fail" }
    208 --   "ok"
    209 --   >>> parse_pub (BS.replicate 32 0x00) -- wrong length
    210 --   Nothing
    211 parse_pub :: BS.ByteString -> Maybe Pub
    212 parse_pub bs = do
    213   guard (BS.length bs == 33)
    214   p <- Secp256k1.parse_point bs
    215   pure (Pub p)
    216 
    217 -- | Serialize a public key to 33-byte compressed form.
    218 --
    219 --   >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
    220 --   >>> BS.length (serialize_pub pub)
    221 --   33
    222 serialize_pub :: Pub -> BS.ByteString
    223 serialize_pub (Pub p) = Secp256k1.serialize_point p
    224 
    225 -- cryptographic primitives --------------------------------------------------
    226 
    227 -- bolt8-style ECDH
    228 ecdh :: Sec -> Pub -> Maybe BS.ByteString
    229 ecdh (Sec sec) (Pub pub) = do
    230   k <- Secp256k1.parse_int256 sec
    231   pt <- Secp256k1.mul pub k
    232   let compressed = Secp256k1.serialize_point pt
    233   pure (SHA256.hash compressed)
    234 
    235 -- h' = SHA256(h || data)
    236 mix_hash :: BS.ByteString -> BS.ByteString -> BS.ByteString
    237 mix_hash h dat = SHA256.hash (h <> dat)
    238 
    239 -- Mix key: (ck', k) = HKDF(ck, input_key_material)
    240 --
    241 -- NB HKDF limits output to 255 * hashlen bytes. For SHA256 that's 8160,
    242 -- well above the 64 bytes requested here, so 'Nothing' is impossible.
    243 mix_key :: BS.ByteString -> BS.ByteString -> (BS.ByteString, BS.ByteString)
    244 mix_key ck ikm = case HKDF.derive hmac ck mempty 64 ikm of
    245     Nothing -> error "ppad-bolt8: internal error, please report a bug!"
    246     Just output -> BS.splitAt 32 output
    247   where
    248     hmac k b = case SHA256.hmac k b of
    249       SHA256.MAC mac -> mac
    250 
    251 -- Encrypt with associated data using ChaCha20-Poly1305
    252 encrypt_with_ad
    253   :: BS.ByteString       -- ^ key (32 bytes)
    254   -> Word64              -- ^ nonce
    255   -> BS.ByteString       -- ^ associated data
    256   -> BS.ByteString       -- ^ plaintext
    257   -> Maybe BS.ByteString -- ^ ciphertext || mac (16 bytes)
    258 encrypt_with_ad key n ad pt =
    259   case AEAD.encrypt ad key (encode_nonce n) pt of
    260     Left _ -> Nothing
    261     Right (ct, mac) -> Just (ct <> mac)
    262 
    263 -- Decrypt with associated data using ChaCha20-Poly1305
    264 decrypt_with_ad
    265   :: BS.ByteString       -- ^ key (32 bytes)
    266   -> Word64              -- ^ nonce
    267   -> BS.ByteString       -- ^ associated data
    268   -> BS.ByteString       -- ^ ciphertext || mac
    269   -> Maybe BS.ByteString -- ^ plaintext
    270 decrypt_with_ad key n ad ctmac
    271   | BS.length ctmac < 16 = Nothing
    272   | otherwise =
    273       let (ct, mac) = BS.splitAt (BS.length ctmac - 16) ctmac
    274       in case AEAD.decrypt ad key (encode_nonce n) (ct, mac) of
    275            Left _ -> Nothing
    276            Right pt -> Just pt
    277 
    278 -- Encode nonce as 96-bit value: 4 zero bytes + 8-byte little-endian
    279 encode_nonce :: Word64 -> BS.ByteString
    280 encode_nonce n = BS.replicate 4 0x00 <> encode_le64 n
    281 
    282 -- Little-endian 64-bit encoding
    283 encode_le64 :: Word64 -> BS.ByteString
    284 encode_le64 n = BS.pack [
    285     fi (n .&. 0xff)
    286   , fi (unsafeShiftR n 8  .&. 0xff)
    287   , fi (unsafeShiftR n 16 .&. 0xff)
    288   , fi (unsafeShiftR n 24 .&. 0xff)
    289   , fi (unsafeShiftR n 32 .&. 0xff)
    290   , fi (unsafeShiftR n 40 .&. 0xff)
    291   , fi (unsafeShiftR n 48 .&. 0xff)
    292   , fi (unsafeShiftR n 56 .&. 0xff)
    293   ]
    294 
    295 -- Big-endian 16-bit encoding
    296 encode_be16 :: Word16 -> BS.ByteString
    297 encode_be16 n = BS.pack [fi (unsafeShiftR n 8), fi (n .&. 0xff)]
    298 
    299 -- Big-endian 16-bit decoding
    300 decode_be16 :: BS.ByteString -> Maybe Word16
    301 decode_be16 bs
    302   | BS.length bs /= 2 = Nothing
    303   | otherwise =
    304       let !b0 = BS.index bs 0
    305           !b1 = BS.index bs 1
    306       in Just (fi b0 * 0x100 + fi b1)
    307 
    308 -- handshake -----------------------------------------------------------------
    309 
    310 -- Initialize handshake state
    311 --
    312 -- h = SHA256(protocol_name)
    313 -- ck = h
    314 -- h = SHA256(h || prologue)
    315 -- h = SHA256(h || responder_static_pubkey)
    316 init_handshake
    317   :: Sec                -- ^ local static secret
    318   -> Pub                -- ^ local static public
    319   -> Sec                -- ^ ephemeral secret
    320   -> Pub                -- ^ ephemeral public
    321   -> Maybe Pub          -- ^ remote static (initiator knows, responder doesn't)
    322   -> Bool               -- ^ True if initiator
    323   -> HandshakeState
    324 init_handshake s_sec s_pub e_sec e_pub m_rs is_initiator =
    325   let !h0 = SHA256.hash _PROTOCOL_NAME
    326       !ck = h0
    327       !h1 = mix_hash h0 _PROLOGUE
    328       -- Mix in responder's static pubkey
    329       !h2 = case (is_initiator, m_rs) of
    330         (True, Just rs)  -> mix_hash h1 (serialize_pub rs)
    331         (False, Nothing) -> mix_hash h1 (serialize_pub s_pub)
    332         _ -> h1  -- shouldn't happen
    333   in HandshakeState {
    334        hs_h      = h2
    335      , hs_ck     = ck
    336      , hs_temp_k = BS.replicate 32 0x00
    337      , hs_e_sec  = e_sec
    338      , hs_e_pub  = e_pub
    339      , hs_s_sec  = s_sec
    340      , hs_s_pub  = s_pub
    341      , hs_re     = Nothing
    342      , hs_rs     = m_rs
    343      }
    344 
-- | Initiator: generate Act 1 message (50 bytes).
--
--   Takes local static key, remote static pubkey, and 32 bytes of
--   entropy for ephemeral key generation. The entropy becomes the
--   ephemeral secret key verbatim (via 'keypair'), so it should be fresh
--   and secret for every handshake.
--
--   Returns the 50-byte Act 1 message and handshake state for Act 3. The
--   message is @0x00 || e.pub (33 bytes) || c (16-byte MAC)@.
--
--   Fails with 'InvalidKey' if the entropy is not a valid key or the ECDH
--   step fails, and with 'InvalidMAC' if AEAD encryption fails.
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let eph_ent = BS.replicate 32 0x12
--   >>> case act1 i_sec i_pub r_pub eph_ent of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
--   50
act1
  :: Sec                -- ^ local static secret
  -> Pub                -- ^ local static public
  -> Pub                -- ^ remote static public (responder's)
  -> BS.ByteString      -- ^ 32 bytes entropy for ephemeral
  -> Either Error (BS.ByteString, HandshakeState)
act1 s_sec s_pub rs ent = do
  -- e = generateKey(), from the caller's entropy
  (e_sec, e_pub) <- note InvalidKey (keypair ent)
  let !hs0 = init_handshake s_sec s_pub e_sec e_pub (Just rs) True
      !e_pub_bytes = serialize_pub e_pub
      -- h = SHA256(h || e.pub)
      !h1 = mix_hash (hs_h hs0) e_pub_bytes
  -- es = ECDH(e.priv, rs)
  es <- note InvalidKey (ecdh e_sec rs)
  -- ck, temp_k1 = HKDF(ck, es)
  let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
  -- c = encryptWithAD(temp_k1, 0, h, ""): an empty plaintext, so c is a
  -- bare 16-byte MAC
  c <- note InvalidMAC (encrypt_with_ad temp_k1 0 h1 BS.empty)
  -- h = SHA256(h || c)
  let !h2 = mix_hash h1 c
      -- version 0x00 || e.pub || c
      !msg = BS.singleton 0x00 <> e_pub_bytes <> c
      !hs1 = hs0 {
        hs_h      = h2
      , hs_ck     = ck1
      , hs_temp_k = temp_k1
      }
  pure (msg, hs1)
    379 
-- | Responder: process Act 1 and generate Act 2 message (50 bytes).
--
--   Takes local static key and 32 bytes of entropy for ephemeral key,
--   plus the 50-byte Act 1 message from initiator. The entropy becomes the
--   ephemeral secret key verbatim (via 'keypair'), so it should be fresh
--   and secret for every handshake.
--
--   Returns the 50-byte Act 2 message and handshake state for finalize.
--
--   Checks are made in this order: length ('InvalidLength'), version byte
--   ('InvalidVersion'), initiator ephemeral key ('InvalidPub'), local
--   entropy and ECDH ('InvalidKey'), and the Act 1 MAC ('InvalidMAC').
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let Right (msg1, _) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--   >>> case act2 r_sec r_pub (BS.replicate 32 0x22) msg1 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
--   50
act2
  :: Sec                -- ^ local static secret
  -> Pub                -- ^ local static public
  -> BS.ByteString      -- ^ 32 bytes entropy for ephemeral
  -> BS.ByteString      -- ^ Act 1 message (50 bytes)
  -> Either Error (BS.ByteString, HandshakeState)
act2 s_sec s_pub ent msg1 = do
  -- Act 1 is exactly 0x00 || re (33 bytes) || c (16-byte MAC)
  require (BS.length msg1 == 50) InvalidLength
  let !version = BS.index msg1 0
      !re_bytes = BS.take 33 (BS.drop 1 msg1)
      !c = BS.drop 34 msg1
  require (version == 0x00) InvalidVersion
  re <- note InvalidPub (parse_pub re_bytes)
  (e_sec, e_pub) <- note InvalidKey (keypair ent)
  -- the responder does not know the initiator's static key yet
  let !hs0 = init_handshake s_sec s_pub e_sec e_pub Nothing False
      -- h = SHA256(h || re); the raw wire bytes are mixed, which parse_pub
      -- only accepted as a 33-byte (compressed) encoding
      !h1 = mix_hash (hs_h hs0) re_bytes
  -- es = ECDH(s.priv, re); ck, temp_k1 = HKDF(ck, es)
  es <- note InvalidKey (ecdh s_sec re)
  let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
  -- authenticate Act 1: the plaintext is empty, only the MAC matters
  _ <- note InvalidMAC (decrypt_with_ad temp_k1 0 h1 c)
  -- h = SHA256(h || c), then Act 2: h = SHA256(h || e.pub)
  let !h2 = mix_hash h1 c
      !e_pub_bytes = serialize_pub e_pub
      !h3 = mix_hash h2 e_pub_bytes
  -- ee = ECDH(e.priv, re); ck, temp_k2 = HKDF(ck, ee)
  ee <- note InvalidKey (ecdh e_sec re)
  let !(ck2, temp_k2) = mix_key ck1 ee
  -- c = encryptWithAD(temp_k2, 0, h, "")
  c2 <- note InvalidMAC (encrypt_with_ad temp_k2 0 h3 BS.empty)
  -- h = SHA256(h || c); message is 0x00 || e.pub || c
  let !h4 = mix_hash h3 c2
      !msg = BS.singleton 0x00 <> e_pub_bytes <> c2
      -- temp_k2 is kept: finalize decrypts Act 3's static key with it
      !hs1 = hs0 {
        hs_h      = h4
      , hs_ck     = ck2
      , hs_temp_k = temp_k2
      , hs_re     = Just re
      }
  pure (msg, hs1)
    426 
-- | Initiator: process Act 2 and generate Act 3 (66 bytes), completing
--   the handshake.
--
--   Returns the 66-byte Act 3 message and the handshake result. The
--   resulting 'Session' sends with the first half of the final HKDF output
--   and receives with the second half. 'finalize' assigns them the other
--   way round, so the two sides match.
--
--   Checks are made in this order: length ('InvalidLength'), version byte
--   ('InvalidVersion'), responder ephemeral key ('InvalidPub'), ECDH
--   ('InvalidKey') and the Act 2 MAC ('InvalidMAC'). Last, the state must
--   carry the remote static key from 'act1' ('InvalidPub').
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--   >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--   >>> case act3 i_hs msg2 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
--   66
act3
  :: HandshakeState     -- ^ state after Act 1
  -> BS.ByteString      -- ^ Act 2 message (50 bytes)
  -> Either Error (BS.ByteString, Handshake)
act3 hs msg2 = do
  -- Act 2 is exactly 0x00 || re (33 bytes) || c (16-byte MAC)
  require (BS.length msg2 == 50) InvalidLength
  let !version = BS.index msg2 0
      !re_bytes = BS.take 33 (BS.drop 1 msg2)
      !c = BS.drop 34 msg2
  require (version == 0x00) InvalidVersion
  re <- note InvalidPub (parse_pub re_bytes)
  -- h = SHA256(h || re)
  let !h1 = mix_hash (hs_h hs) re_bytes
  -- ee = ECDH(e.priv, re); ck, temp_k2 = HKDF(ck, ee)
  ee <- note InvalidKey (ecdh (hs_e_sec hs) re)
  let !(ck1, temp_k2) = mix_key (hs_ck hs) ee
  -- authenticate Act 2: the plaintext is empty, only the MAC matters
  _ <- note InvalidMAC (decrypt_with_ad temp_k2 0 h1 c)
  -- h = SHA256(h || c)
  let !h2 = mix_hash h1 c
      !s_pub_bytes = serialize_pub (hs_s_pub hs)
  -- Act 3: c = encryptWithAD(temp_k2, 1, h, s.pub). Nonce 1 is used
  -- because nonce 0 under temp_k2 already authenticated Act 2.
  c3 <- note InvalidMAC (encrypt_with_ad temp_k2 1 h2 s_pub_bytes)
  -- h = SHA256(h || c)
  let !h3 = mix_hash h2 c3
  -- se = ECDH(s.priv, re); ck, temp_k3 = HKDF(ck, se)
  se <- note InvalidKey (ecdh (hs_s_sec hs) re)
  let !(ck2, temp_k3) = mix_key ck1 se
  -- t = encryptWithAD(temp_k3, 0, h, "")
  t <- note InvalidMAC (encrypt_with_ad temp_k3 0 h3 BS.empty)
  -- sk, rk = HKDF(ck, ""); both chaining keys start at the final ck
  let !(sk, rk) = mix_key ck2 BS.empty
      -- 0x00 || c (33 + 16 bytes) || t (16 bytes) = 66 bytes
      !msg = BS.singleton 0x00 <> c3 <> t
      !sess = Session {
        sess_sk  = sk
      , sess_sn  = 0
      , sess_sck = ck2
      , sess_rk  = rk
      , sess_rn  = 0
      , sess_rck = ck2
      }
  -- the responder's static key, as supplied to act1
  rs <- note InvalidPub (hs_rs hs)
  let !result = Handshake {
        session       = sess
      , remote_static = rs
      }
  pure (msg, result)
    476 
-- | Responder: process Act 3 (66 bytes) and complete the handshake.
--
--   Returns the handshake result with authenticated remote static pubkey.
--   The resulting 'Session' receives with the first half of the final
--   HKDF output and sends with the second, mirroring 'act3'.
--
--   Checks are made in this order: length ('InvalidLength'), version byte
--   ('InvalidVersion'), the encrypted static key's MAC ('InvalidMAC'),
--   the decrypted key itself ('InvalidPub'), ECDH ('InvalidKey'), and the
--   final MAC @t@ ('InvalidMAC').
--
--   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--   >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--   >>> let Right (msg3, _) = act3 i_hs msg2
--   >>> case finalize r_hs msg3 of { Right _ -> "ok"; Left e -> show e }
--   "ok"
finalize
  :: HandshakeState     -- ^ state after Act 2
  -> BS.ByteString      -- ^ Act 3 message (66 bytes)
  -> Either Error Handshake
finalize hs msg3 = do
  -- Act 3 is exactly 0x00 || c (49 bytes) || t (16 bytes)
  require (BS.length msg3 == 66) InvalidLength
  let !version = BS.index msg3 0
      !c = BS.take 49 (BS.drop 1 msg3)
      !t = BS.drop 50 msg3
  require (version == 0x00) InvalidVersion
  -- rs = decryptWithAD(temp_k2, 1, h, c); temp_k2 was stored by act2
  rs_bytes <- note InvalidMAC (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c)
  rs <- note InvalidPub (parse_pub rs_bytes)
  -- h = SHA256(h || c)
  let !h1 = mix_hash (hs_h hs) c
  -- se = ECDH(e.priv, rs); ck, temp_k3 = HKDF(ck, se)
  se <- note InvalidKey (ecdh (hs_e_sec hs) rs)
  let !(ck1, temp_k3) = mix_key (hs_ck hs) se
  -- authenticate t = encryptWithAD(temp_k3, 0, h, "")
  _ <- note InvalidMAC (decrypt_with_ad temp_k3 0 h1 t)
  -- responder swaps order (receives what initiator sends): rk, sk = HKDF(ck, "")
  let !(rk, sk) = mix_key ck1 BS.empty
      !sess = Session {
        sess_sk  = sk
      , sess_sn  = 0
      , sess_sck = ck1
      , sess_rk  = rk
      , sess_rn  = 0
      , sess_rck = ck1
      }
      !result = Handshake {
        session       = sess
      , remote_static = rs
      }
  pure result
    519 
    520 -- message encryption --------------------------------------------------------
    521 
    522 -- | Encrypt a message (max 65535 bytes).
    523 --
    524 --   Returns the encrypted packet and updated session. Key rotation
    525 --   is handled automatically at nonce 1000.
    526 --
    527 --   Wire format: encrypted_length (2) || MAC (16) || encrypted_body || MAC (16)
    528 --
    529 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    530 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    531 --   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
    532 --   >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
    533 --   >>> let Right (_, i_result) = act3 i_hs msg2
    534 --   >>> let sess = session i_result
    535 --   >>> case encrypt sess "hello" of { Right (ct, _) -> BS.length ct; Left _ -> 0 }
    536 --   39
    537 encrypt
    538   :: Session
    539   -> BS.ByteString          -- ^ plaintext (max 65535 bytes)
    540   -> Either Error (BS.ByteString, Session)
    541 encrypt sess pt = do
    542   let !len = BS.length pt
    543   require (len <= 65535) InvalidLength
    544   let !len_bytes = encode_be16 (fi len)
    545   lc <- note InvalidMAC (encrypt_with_ad (sess_sk sess) (sess_sn sess)
    546                            BS.empty len_bytes)
    547   let !(sn1, sck1, sk1) = step_nonce (sess_sn sess) (sess_sck sess) (sess_sk sess)
    548   bc <- note InvalidMAC (encrypt_with_ad sk1 sn1 BS.empty pt)
    549   let !(sn2, sck2, sk2) = step_nonce sn1 sck1 sk1
    550       !packet = lc <> bc
    551       !sess' = sess {
    552         sess_sk  = sk2
    553       , sess_sn  = sn2
    554       , sess_sck = sck2
    555       }
    556   pure (packet, sess')
    557 
    558 -- | Decrypt a message, requiring an exact packet with no trailing bytes.
    559 --
    560 --   Returns the plaintext and updated session. Key rotation
    561 --   is handled automatically at nonce 1000.
    562 --
    563 --   This is a strict variant that rejects any trailing data. For
    564 --   streaming use cases where you need to handle multiple frames in a
    565 --   buffer, use 'decrypt_frame' instead.
    566 --
    567 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    568 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    569 --   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
    570 --   >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
    571 --   >>> let Right (msg3, i_result) = act3 i_hs msg2
    572 --   >>> let Right r_result = finalize r_hs msg3
    573 --   >>> let Right (ct, _) = encrypt (session i_result) "hello"
    574 --   >>> case decrypt (session r_result) ct of { Right (pt, _) -> pt; Left _ -> "fail" }
    575 --   "hello"
    576 decrypt
    577   :: Session
    578   -> BS.ByteString          -- ^ encrypted packet (exact length required)
    579   -> Either Error (BS.ByteString, Session)
    580 decrypt sess packet = do
    581   (pt, remainder, sess') <- decrypt_frame sess packet
    582   require (BS.null remainder) InvalidLength
    583   pure (pt, sess')
    584 
    585 -- | Decrypt a single frame from a buffer, returning the remainder.
    586 --
    587 --   Returns the plaintext, any unconsumed bytes, and the updated session.
    588 --   Key rotation is handled automatically every 1000 messages.
    589 --
    590 --   This is useful for streaming scenarios where multiple messages may
    591 --   be buffered together. The remainder can be passed to the next call
    592 --   to 'decrypt_frame'.
    593 --
    594 --   Wire format consumed: encrypted_length (18) || encrypted_body (len + 16)
    595 --
    596 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    597 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    598 --   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
    599 --   >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
    600 --   >>> let Right (msg3, i_result) = act3 i_hs msg2
    601 --   >>> let Right r_result = finalize r_hs msg3
    602 --   >>> let Right (ct, _) = encrypt (session i_result) "hello"
    603 --   >>> case decrypt_frame (session r_result) ct of { Right (pt, rem, _) -> (pt, BS.null rem); Left _ -> ("fail", False) }
    604 --   ("hello",True)
    605 decrypt_frame
    606   :: Session
    607   -> BS.ByteString          -- ^ buffer containing at least one encrypted frame
    608   -> Either Error (BS.ByteString, BS.ByteString, Session)
    609 decrypt_frame sess packet = do
    610   require (BS.length packet >= 34) InvalidLength
    611   let !lc = BS.take 18 packet
    612       !rest = BS.drop 18 packet
    613   len_bytes <- note InvalidMAC (decrypt_with_ad (sess_rk sess) (sess_rn sess)
    614                                   BS.empty lc)
    615   len <- note InvalidLength (decode_be16 len_bytes)
    616   let !(rn1, rck1, rk1) = step_nonce (sess_rn sess) (sess_rck sess) (sess_rk sess)
    617       !body_len = fi len + 16
    618   require (BS.length rest >= body_len) InvalidLength
    619   let !bc = BS.take body_len rest
    620       !remainder = BS.drop body_len rest
    621   pt <- note InvalidMAC (decrypt_with_ad rk1 rn1 BS.empty bc)
    622   let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
    623       !sess' = sess {
    624         sess_rk  = rk2
    625       , sess_rn  = rn2
    626       , sess_rck = rck2
    627       }
    628   pure (pt, remainder, sess')
    629 
    630 -- | Decrypt a frame from a partial buffer, indicating when more data needed.
    631 --
    632 --   Unlike 'decrypt_frame', this function handles incomplete buffers
    633 --   gracefully by returning 'NeedMore' with the number of additional
    634 --   bytes required to make progress.
    635 --
    636 --   * If the buffer has fewer than 18 bytes (encrypted length + MAC),
    637 --     returns @'NeedMore' n@ where @n@ is the bytes still needed.
    638 --   * If the length header is complete but the body is incomplete,
    639 --     returns @'NeedMore' n@ with bytes needed for the full frame.
    640 --   * MAC or decryption failures return 'FrameError'.
    641 --   * A complete, valid frame returns 'FrameOk' with plaintext,
    642 --     remainder, and updated session.
    643 --
    644 --   This is useful for non-blocking I/O where data arrives incrementally.
    645 decrypt_frame_partial
    646   :: Session
    647   -> BS.ByteString  -- ^ buffer (possibly incomplete)
    648   -> FrameResult
    649 decrypt_frame_partial sess buf
    650   | buflen < 18 = NeedMore (18 - buflen)
    651   | otherwise =
    652       let !lc = BS.take 18 buf
    653           !rest = BS.drop 18 buf
    654       in case decrypt_with_ad (sess_rk sess) (sess_rn sess) BS.empty lc of
    655            Nothing -> FrameError InvalidMAC
    656            Just len_bytes -> case decode_be16 len_bytes of
    657              Nothing -> FrameError InvalidLength
    658              Just len ->
    659                let !body_len = fi len + 16
    660                    !(rn1, rck1, rk1) = step_nonce (sess_rn sess)
    661                                         (sess_rck sess) (sess_rk sess)
    662                in if BS.length rest < body_len
    663                     then NeedMore (body_len - BS.length rest)
    664                     else
    665                       let !bc = BS.take body_len rest
    666                           !remainder = BS.drop body_len rest
    667                       in case decrypt_with_ad rk1 rn1 BS.empty bc of
    668                            Nothing -> FrameError InvalidMAC
    669                            Just pt ->
    670                              let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
    671                                  !sess' = sess {
    672                                    sess_rk  = rk2
    673                                  , sess_rn  = rn2
    674                                  , sess_rck = rck2
    675                                  }
    676                              in FrameOk pt remainder sess'
    677   where
    678     !buflen = BS.length buf
    679 
    680 -- key rotation --------------------------------------------------------------
    681 
    682 -- Key rotation occurs after nonce reaches 1000 (i.e., before using 1000)
    683 -- (ck', k') = HKDF(ck, k), reset nonce to 0
    684 step_nonce
    685   :: Word64
    686   -> BS.ByteString
    687   -> BS.ByteString
    688   -> (Word64, BS.ByteString, BS.ByteString)
    689 step_nonce n ck k
    690   | n + 1 == 1000 =
    691       let !(ck', k') = mix_key ck k
    692       in (0, ck', k')
    693   | otherwise = (n + 1, ck, k)
    694 
    695 -- utilities -----------------------------------------------------------------
    696 
    697 -- Lift Maybe to Either
    698 note :: e -> Maybe a -> Either e a
    699 note e = maybe (Left e) Right
    700 {-# INLINE note #-}
    701 
    702 -- Require condition or fail
    703 require :: Bool -> e -> Either e ()
    704 require cond e = unless cond (Left e)
    705 {-# INLINE require #-}
    706 
    707 fi :: (Integral a, Num b) => a -> b
    708 fi = fromIntegral
    709 {-# INLINE fi #-}