bolt8

Encrypted and authenticated transport, per BOLT #8 (docs.ppad.tech/bolt8).
git clone git://git.ppad.tech/bolt8.git
Log | Files | Refs | README | LICENSE

Internal.hs (25737B)


      1 {-# OPTIONS_HADDOCK prune #-}
      2 {-# LANGUAGE BangPatterns #-}
      3 {-# LANGUAGE DeriveGeneric #-}
      4 {-# LANGUAGE LambdaCase #-}
      5 {-# LANGUAGE OverloadedStrings #-}
      6 {-# LANGUAGE RecordWildCards #-}
      7 {-# LANGUAGE ViewPatterns #-}
      8 
      9 -- |
     10 -- Module: Lightning.Protocol.BOLT8.Internal
     11 -- Copyright: (c) 2025 Jared Tobin
     12 -- License: MIT
     13 -- Maintainer: Jared Tobin <jared@ppad.tech>
     14 --
     15 -- Internal module exporting all constructors for testing and
     16 -- benchmarking. Prefer "Lightning.Protocol.BOLT8" for general use.
     17 
     18 module Lightning.Protocol.BOLT8.Internal (
     19     -- * Keys
     20     Sec(..)
     21   , Pub(..)
     22   , keypair
     23   , parse_pub
     24   , serialize_pub
     25 
     26     -- * Newtypes
     27   , Key32(..)
     28   , key32
     29   , unsafeKey32
     30   , SessionNonce(..)
     31   , MessagePayload(..)
     32   , mkMessagePayload
     33 
     34     -- * Handshake roles
     35   , Initiator
     36   , Responder
     37   , HandshakeFor(..)
     38 
     39     -- * Handshake (initiator)
     40   , act1
     41   , act3
     42 
     43     -- * Handshake (responder)
     44   , act2
     45   , finalize
     46 
     47     -- * Session
     48   , Session(..)
     49   , HandshakeState(..)
     50   , Handshake(..)
     51   , encrypt
     52   , decrypt
     53   , decrypt_frame
     54   , decrypt_frame_partial
     55   , FrameResult(..)
     56 
     57     -- * Errors
     58   , Error(..)
     59   ) where
     60 
     61 import Control.Monad (guard, unless)
     62 import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
     63 import qualified Crypto.Curve.Secp256k1 as Secp256k1
     64 import qualified Crypto.Hash.SHA256 as SHA256
     65 import qualified Crypto.KDF.HMAC as HKDF
     66 import Data.Bits (unsafeShiftR, (.&.))
     67 import qualified Data.ByteString as BS
     68 import Data.Word (Word16, Word64)
     69 import GHC.Generics (Generic)
     70 
     71 -- types -----------------------------------------------------------
     72 
     73 -- | Secret key (32 bytes).
     74 newtype Sec = Sec BS.ByteString
     75   deriving (Eq, Generic)
     76 
     77 -- | Compressed public key.
     78 newtype Pub = Pub Secp256k1.Projective
     79 
     80 instance Eq Pub where
     81   (Pub a) == (Pub b) =
     82     Secp256k1.serialize_point a
     83       == Secp256k1.serialize_point b
     84 
     85 instance Show Pub where
     86   show (Pub p) =
     87     "Pub " ++ show (Secp256k1.serialize_point p)
     88 
     89 -- | A 32-byte key, validated at construction.
     90 newtype Key32 = Key32 { unKey32 :: BS.ByteString }
     91   deriving (Eq, Generic)
     92 
     93 -- | Construct a 'Key32' from a 32-byte 'BS.ByteString'.
     94 --
     95 --   Returns 'Nothing' if the input is not exactly 32 bytes.
     96 --
     97 --   >>> key32 (BS.replicate 32 0x00)
     98 --   Just (Key32 {unKey32 = ...})
     99 --   >>> key32 (BS.replicate 31 0x00)
    100 --   Nothing
    101 key32 :: BS.ByteString -> Maybe Key32
    102 key32 bs
    103   | BS.length bs == 32 = Just (Key32 bs)
    104   | otherwise = Nothing
    105 
    106 -- | Construct a 'Key32' without validation.
    107 --
    108 --   For test and benchmark use only; prefer 'key32'.
    109 unsafeKey32 :: BS.ByteString -> Key32
    110 unsafeKey32 = Key32
    111 
    112 -- | Session nonce, distinguishing send from receive direction.
    113 newtype SessionNonce =
    114   SessionNonce { unSessionNonce :: Word64 }
    115   deriving (Eq, Generic)
    116 
    117 -- | Message payload (max 65535 bytes), validated at construction.
    118 newtype MessagePayload =
    119   MessagePayload { unMessagePayload :: BS.ByteString }
    120   deriving (Eq, Generic)
    121 
    122 -- | Construct a 'MessagePayload' from a 'BS.ByteString'.
    123 --
    124 --   Returns 'Left' if the payload exceeds 65535 bytes.
    125 mkMessagePayload
    126   :: BS.ByteString -> Either Error MessagePayload
    127 mkMessagePayload bs
    128   | BS.length bs > 65535 = Left InvalidLength
    129   | otherwise = Right (MessagePayload bs)
    130 
    131 -- | Handshake errors.
    132 data Error =
    133     InvalidKey
    134   | InvalidPub
    135   | InvalidMAC
    136   | InvalidVersion
    137   | InvalidLength
    138   | DecryptionFailed
    139   deriving (Eq, Show, Generic)
    140 
    141 -- | Result of attempting to decrypt a frame from a partial
    142 --   buffer.
    143 data FrameResult =
    144     NeedMore {-# UNPACK #-} !Int
    145     -- ^ More bytes needed; the 'Int' is the minimum
    146     --   additional bytes required.
    147   | FrameOk !BS.ByteString !BS.ByteString !Session
    148     -- ^ Successfully decrypted: plaintext, remainder,
    149     --   updated session.
    150   | FrameError !Error
    151     -- ^ Decryption failed with the given error.
    152   deriving Generic
    153 
    154 -- | Post-handshake session state.
    155 data Session = Session {
    156     sess_sk  :: !Key32
    157     -- ^ send key (32 bytes)
    158   , sess_sn  :: !SessionNonce
    159     -- ^ send nonce
    160   , sess_sck :: !Key32
    161     -- ^ send chaining key
    162   , sess_rk  :: !Key32
    163     -- ^ receive key (32 bytes)
    164   , sess_rn  :: !SessionNonce
    165     -- ^ receive nonce
    166   , sess_rck :: !Key32
    167     -- ^ receive chaining key
    168   }
    169   deriving Generic
    170 
    171 -- | Result of a successful handshake.
    172 data Handshake = Handshake {
    173     session       :: !Session
    174     -- ^ session state
    175   , remote_static :: !Pub
    176     -- ^ authenticated remote static pubkey
    177   }
    178   deriving Generic
    179 
    180 -- | Internal handshake state (exported for benchmarking).
    181 data HandshakeState = HandshakeState {
    182     hs_h      :: {-# UNPACK #-} !BS.ByteString
    183     -- ^ handshake hash (32 bytes)
    184   , hs_ck     :: {-# UNPACK #-} !BS.ByteString
    185     -- ^ chaining key (32 bytes)
    186   , hs_temp_k :: {-# UNPACK #-} !BS.ByteString
    187     -- ^ temp key (32 bytes)
    188   , hs_e_sec  :: !Sec
    189     -- ^ ephemeral secret
    190   , hs_e_pub  :: !Pub
    191     -- ^ ephemeral public
    192   , hs_s_sec  :: !Sec
    193     -- ^ static secret
    194   , hs_s_pub  :: !Pub
    195     -- ^ static public
    196   , hs_re     :: !(Maybe Pub)
    197     -- ^ remote ephemeral
    198   , hs_rs     :: !(Maybe Pub)
    199     -- ^ remote static
    200   }
    201   deriving Generic
    202 
    203 -- handshake roles -------------------------------------------------
    204 
    205 -- | Phantom type for initiator role.
    206 data Initiator
    207 
    208 -- | Phantom type for responder role.
    209 data Responder
    210 
    211 -- | Role-indexed handshake state.
    212 --
    213 --   The phantom type parameter prevents passing an initiator's
    214 --   state to a responder function and vice versa.
    215 data HandshakeFor a =
    216   HandshakeFor { unHandshakeFor :: !HandshakeState }
    217 
    218 -- protocol constants ----------------------------------------------
    219 
    220 _PROTOCOL_NAME :: BS.ByteString
    221 _PROTOCOL_NAME =
    222   "Noise_XK_secp256k1_ChaChaPoly_SHA256"
    223 
    224 _PROLOGUE :: BS.ByteString
    225 _PROLOGUE = "lightning"
    226 
    227 -- key operations --------------------------------------------------
    228 
    229 -- | Derive a keypair from 32 bytes of entropy.
    230 --
    231 --   Returns Nothing if the entropy is invalid
    232 --   (zero or >= curve order).
    233 --
    234 --   >>> let ent = BS.replicate 32 0x11
    235 --   >>> case keypair ent of
    236 --   ...   Just _ -> "ok"
    237 --   ...   Nothing -> "fail"
    238 --   "ok"
    239 --   >>> keypair (BS.replicate 31 0x11) -- wrong length
    240 --   Nothing
    241 keypair :: BS.ByteString -> Maybe (Sec, Pub)
    242 keypair ent = do
    243   guard (BS.length ent == 32)
    244   k <- Secp256k1.parse_int256 ent
    245   p <- Secp256k1.derive_pub k
    246   pure (Sec ent, Pub p)
    247 
    248 -- | Parse a 33-byte compressed public key.
    249 --
    250 --   >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
    251 --   >>> let bytes = serialize_pub pub
    252 --   >>> case parse_pub bytes of
    253 --   ...   Just _ -> "ok"
    254 --   ...   Nothing -> "fail"
    255 --   "ok"
    256 --   >>> parse_pub (BS.replicate 32 0x00) -- wrong length
    257 --   Nothing
    258 parse_pub :: BS.ByteString -> Maybe Pub
    259 parse_pub bs = do
    260   guard (BS.length bs == 33)
    261   p <- Secp256k1.parse_point bs
    262   pure (Pub p)
    263 
    264 -- | Serialize a public key to 33-byte compressed form.
    265 --
    266 --   >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
    267 --   >>> BS.length (serialize_pub pub)
    268 --   33
    269 serialize_pub :: Pub -> BS.ByteString
    270 serialize_pub (Pub p) = Secp256k1.serialize_point p
    271 
    272 -- cryptographic primitives ----------------------------------------
    273 
    274 -- bolt8-style ECDH
    275 ecdh :: Sec -> Pub -> Maybe BS.ByteString
    276 ecdh (Sec sec) (Pub pub) = do
    277   k <- Secp256k1.parse_int256 sec
    278   pt <- Secp256k1.mul pub k
    279   let compressed = Secp256k1.serialize_point pt
    280   pure (SHA256.hash compressed)
    281 
    282 -- h' = SHA256(h || data)
    283 mix_hash
    284   :: BS.ByteString -> BS.ByteString -> BS.ByteString
    285 mix_hash h dat = SHA256.hash (h <> dat)
    286 
    287 -- Mix key: (ck', k) = HKDF(ck, input_key_material)
    288 --
    289 -- NB HKDF limits output to 255 * hashlen bytes. For SHA256
    290 -- that's 8160, well above the 64 bytes requested here, so
    291 -- 'Nothing' is impossible.
    292 mix_key
    293   :: BS.ByteString
    294   -> BS.ByteString
    295   -> (BS.ByteString, BS.ByteString)
    296 mix_key ck ikm =
    297   case HKDF.derive hmac ck mempty 64 ikm of
    298     Nothing ->
    299       error
    300         "ppad-bolt8: internal error, please report a bug!"
    301     Just output -> BS.splitAt 32 output
    302   where
    303     hmac k b = case SHA256.hmac k b of
    304       SHA256.MAC mac -> mac
    305 
    306 -- Encrypt with associated data using ChaCha20-Poly1305
    307 encrypt_with_ad
    308   :: BS.ByteString       -- ^ key (32 bytes)
    309   -> Word64              -- ^ nonce
    310   -> BS.ByteString       -- ^ associated data
    311   -> BS.ByteString       -- ^ plaintext
    312   -> Maybe BS.ByteString -- ^ ciphertext || mac (16 bytes)
    313 encrypt_with_ad key n ad pt =
    314   case AEAD.encrypt ad key (encode_nonce n) pt of
    315     Left _ -> Nothing
    316     Right (ct, mac) -> Just (ct <> mac)
    317 
    318 -- Decrypt with associated data using ChaCha20-Poly1305
    319 decrypt_with_ad
    320   :: BS.ByteString       -- ^ key (32 bytes)
    321   -> Word64              -- ^ nonce
    322   -> BS.ByteString       -- ^ associated data
    323   -> BS.ByteString       -- ^ ciphertext || mac
    324   -> Maybe BS.ByteString -- ^ plaintext
    325 decrypt_with_ad key n ad ctmac
    326   | BS.length ctmac < 16 = Nothing
    327   | otherwise =
    328       let (ct, mac) =
    329             BS.splitAt (BS.length ctmac - 16) ctmac
    330       in case AEAD.decrypt ad key (encode_nonce n)
    331                 (ct, mac) of
    332            Left _ -> Nothing
    333            Right pt -> Just pt
    334 
    335 -- Encode nonce as 96-bit value: 4 zero bytes + 8-byte LE
    336 encode_nonce :: Word64 -> BS.ByteString
    337 encode_nonce n = BS.replicate 4 0x00 <> encode_le64 n
    338 
    339 -- Little-endian 64-bit encoding
    340 encode_le64 :: Word64 -> BS.ByteString
    341 encode_le64 n = BS.pack [
    342     fi (n .&. 0xff)
    343   , fi (unsafeShiftR n 8  .&. 0xff)
    344   , fi (unsafeShiftR n 16 .&. 0xff)
    345   , fi (unsafeShiftR n 24 .&. 0xff)
    346   , fi (unsafeShiftR n 32 .&. 0xff)
    347   , fi (unsafeShiftR n 40 .&. 0xff)
    348   , fi (unsafeShiftR n 48 .&. 0xff)
    349   , fi (unsafeShiftR n 56 .&. 0xff)
    350   ]
    351 
    352 -- Big-endian 16-bit encoding
    353 encode_be16 :: Word16 -> BS.ByteString
    354 encode_be16 n =
    355   BS.pack [fi (unsafeShiftR n 8), fi (n .&. 0xff)]
    356 
    357 -- Big-endian 16-bit decoding
    358 decode_be16 :: BS.ByteString -> Maybe Word16
    359 decode_be16 bs
    360   | BS.length bs /= 2 = Nothing
    361   | otherwise =
    362       let !b0 = BS.index bs 0
    363           !b1 = BS.index bs 1
    364       in Just (fi b0 * 0x100 + fi b1)
    365 
    366 -- handshake -------------------------------------------------------
    367 
    368 -- Initialize handshake state
    369 --
    370 -- h = SHA256(protocol_name)
    371 -- ck = h
    372 -- h = SHA256(h || prologue)
    373 -- h = SHA256(h || responder_static_pubkey)
    374 init_handshake
    375   :: Sec           -- ^ local static secret
    376   -> Pub           -- ^ local static public
    377   -> Sec           -- ^ ephemeral secret
    378   -> Pub           -- ^ ephemeral public
    379   -> Maybe Pub     -- ^ remote static
    380   -> Bool          -- ^ True if initiator
    381   -> HandshakeState
    382 init_handshake s_sec s_pub e_sec e_pub m_rs is_init =
    383   let !h0 = SHA256.hash _PROTOCOL_NAME
    384       !ck = h0
    385       !h1 = mix_hash h0 _PROLOGUE
    386       -- Mix in responder's static pubkey
    387       !h2 = case (is_init, m_rs) of
    388         (True, Just rs) ->
    389           mix_hash h1 (serialize_pub rs)
    390         (False, Nothing) ->
    391           mix_hash h1 (serialize_pub s_pub)
    392         _ -> h1  -- shouldn't happen
    393   in HandshakeState {
    394        hs_h      = h2
    395      , hs_ck     = ck
    396      , hs_temp_k = BS.replicate 32 0x00
    397      , hs_e_sec  = e_sec
    398      , hs_e_pub  = e_pub
    399      , hs_s_sec  = s_sec
    400      , hs_s_pub  = s_pub
    401      , hs_re     = Nothing
    402      , hs_rs     = m_rs
    403      }
    404 
    405 -- | Initiator: generate Act 1 message (50 bytes).
    406 --
    407 --   Takes local static key, remote static pubkey, and 32
    408 --   bytes of entropy for ephemeral key generation.
    409 --
    410 --   Returns the 50-byte Act 1 message and handshake state
    411 --   for Act 3.
    412 --
    413 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    414 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    415 --   >>> let eph_ent = BS.replicate 32 0x12
    416 --   >>> case act1 i_sec i_pub r_pub eph_ent of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
    417 --   50
    418 act1
    419   :: Sec
    420   -> Pub
    421   -> Pub
    422   -> BS.ByteString
    423   -> Either Error
    424        (BS.ByteString, HandshakeFor Initiator)
    425 act1 s_sec s_pub rs ent = do
    426   (e_sec, e_pub) <- note InvalidKey (keypair ent)
    427   let !hs0 = init_handshake
    428                s_sec s_pub e_sec e_pub (Just rs) True
    429       !e_pub_bytes = serialize_pub e_pub
    430       !h1 = mix_hash (hs_h hs0) e_pub_bytes
    431   es <- note InvalidKey (ecdh e_sec rs)
    432   let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
    433   c <- note InvalidMAC
    434          (encrypt_with_ad temp_k1 0 h1 BS.empty)
    435   let !h2 = mix_hash h1 c
    436       !msg = BS.singleton 0x00 <> e_pub_bytes <> c
    437       !hs1 = hs0 {
    438         hs_h      = h2
    439       , hs_ck     = ck1
    440       , hs_temp_k = temp_k1
    441       }
    442   pure (msg, HandshakeFor hs1)
    443 
    444 -- | Responder: process Act 1 and generate Act 2 message
    445 --   (50 bytes).
    446 --
    447 --   Takes local static key and 32 bytes of entropy for
    448 --   ephemeral key, plus the 50-byte Act 1 message from
    449 --   initiator.
    450 --
    451 --   Returns the 50-byte Act 2 message and handshake state
    452 --   for finalize.
    453 --
    454 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    455 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    456 --   >>> let Right (msg1, _) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
    457 --   >>> case act2 r_sec r_pub (BS.replicate 32 0x22) msg1 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
    458 --   50
    459 act2
    460   :: Sec
    461   -> Pub
    462   -> BS.ByteString
    463   -> BS.ByteString
    464   -> Either Error
    465        (BS.ByteString, HandshakeFor Responder)
    466 act2 s_sec s_pub ent msg1 = do
    467   require (BS.length msg1 == 50) InvalidLength
    468   let !version = BS.index msg1 0
    469       !re_bytes = BS.take 33 (BS.drop 1 msg1)
    470       !c = BS.drop 34 msg1
    471   require (version == 0x00) InvalidVersion
    472   re <- note InvalidPub (parse_pub re_bytes)
    473   (e_sec, e_pub) <- note InvalidKey (keypair ent)
    474   let !hs0 = init_handshake
    475                s_sec s_pub e_sec e_pub Nothing False
    476       !h1 = mix_hash (hs_h hs0) re_bytes
    477   es <- note InvalidKey (ecdh s_sec re)
    478   let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
    479   _ <- note InvalidMAC
    480          (decrypt_with_ad temp_k1 0 h1 c)
    481   let !h2 = mix_hash h1 c
    482       !e_pub_bytes = serialize_pub e_pub
    483       !h3 = mix_hash h2 e_pub_bytes
    484   ee <- note InvalidKey (ecdh e_sec re)
    485   let !(ck2, temp_k2) = mix_key ck1 ee
    486   c2 <- note InvalidMAC
    487           (encrypt_with_ad temp_k2 0 h3 BS.empty)
    488   let !h4 = mix_hash h3 c2
    489       !msg = BS.singleton 0x00 <> e_pub_bytes <> c2
    490       !hs1 = hs0 {
    491         hs_h      = h4
    492       , hs_ck     = ck2
    493       , hs_temp_k = temp_k2
    494       , hs_re     = Just re
    495       }
    496   pure (msg, HandshakeFor hs1)
    497 
    498 -- | Initiator: process Act 2 and generate Act 3 (66 bytes),
    499 --   completing the handshake.
    500 --
    501 --   Returns the 66-byte Act 3 message and the handshake
    502 --   result.
    503 --
    504 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    505 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    506 --   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
    507 --   >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
    508 --   >>> case act3 i_hs msg2 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
    509 --   66
    510 act3
    511   :: HandshakeFor Initiator
    512   -> BS.ByteString
    513   -> Either Error (BS.ByteString, Handshake)
    514 act3 (HandshakeFor hs) msg2 = do
    515   require (BS.length msg2 == 50) InvalidLength
    516   let !version = BS.index msg2 0
    517       !re_bytes = BS.take 33 (BS.drop 1 msg2)
    518       !c = BS.drop 34 msg2
    519   require (version == 0x00) InvalidVersion
    520   re <- note InvalidPub (parse_pub re_bytes)
    521   let !h1 = mix_hash (hs_h hs) re_bytes
    522   ee <- note InvalidKey (ecdh (hs_e_sec hs) re)
    523   let !(ck1, temp_k2) = mix_key (hs_ck hs) ee
    524   _ <- note InvalidMAC
    525          (decrypt_with_ad temp_k2 0 h1 c)
    526   let !h2 = mix_hash h1 c
    527       !s_pub_bytes = serialize_pub (hs_s_pub hs)
    528   c3 <- note InvalidMAC
    529           (encrypt_with_ad temp_k2 1 h2 s_pub_bytes)
    530   let !h3 = mix_hash h2 c3
    531   se <- note InvalidKey (ecdh (hs_s_sec hs) re)
    532   let !(ck2, temp_k3) = mix_key ck1 se
    533   t <- note InvalidMAC
    534          (encrypt_with_ad temp_k3 0 h3 BS.empty)
    535   let !(sk, rk) = mix_key ck2 BS.empty
    536       !msg = BS.singleton 0x00 <> c3 <> t
    537       !sess = Session {
    538         sess_sk  = Key32 sk
    539       , sess_sn  = SessionNonce 0
    540       , sess_sck = Key32 ck2
    541       , sess_rk  = Key32 rk
    542       , sess_rn  = SessionNonce 0
    543       , sess_rck = Key32 ck2
    544       }
    545   rs <- note InvalidPub (hs_rs hs)
    546   let !result = Handshake {
    547         session       = sess
    548       , remote_static = rs
    549       }
    550   pure (msg, result)
    551 
    552 -- | Responder: process Act 3 (66 bytes) and complete the
    553 --   handshake.
    554 --
    555 --   Returns the handshake result with authenticated remote
    556 --   static pubkey.
    557 --
    558 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    559 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    560 --   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
    561 --   >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
    562 --   >>> let Right (msg3, _) = act3 i_hs msg2
    563 --   >>> case finalize r_hs msg3 of { Right _ -> "ok"; Left e -> show e }
    564 --   "ok"
    565 finalize
    566   :: HandshakeFor Responder
    567   -> BS.ByteString
    568   -> Either Error Handshake
    569 finalize (HandshakeFor hs) msg3 = do
    570   require (BS.length msg3 == 66) InvalidLength
    571   let !version = BS.index msg3 0
    572       !c = BS.take 49 (BS.drop 1 msg3)
    573       !t = BS.drop 50 msg3
    574   require (version == 0x00) InvalidVersion
    575   rs_bytes <- note InvalidMAC
    576     (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c)
    577   rs <- note InvalidPub (parse_pub rs_bytes)
    578   let !h1 = mix_hash (hs_h hs) c
    579   se <- note InvalidKey (ecdh (hs_e_sec hs) rs)
    580   let !(ck1, temp_k3) = mix_key (hs_ck hs) se
    581   _ <- note InvalidMAC
    582          (decrypt_with_ad temp_k3 0 h1 t)
    583   -- responder swaps order (receives what initiator sends)
    584   let !(rk, sk) = mix_key ck1 BS.empty
    585       !sess = Session {
    586         sess_sk  = Key32 sk
    587       , sess_sn  = SessionNonce 0
    588       , sess_sck = Key32 ck1
    589       , sess_rk  = Key32 rk
    590       , sess_rn  = SessionNonce 0
    591       , sess_rck = Key32 ck1
    592       }
    593       !result = Handshake {
    594         session       = sess
    595       , remote_static = rs
    596       }
    597   pure result
    598 
    599 -- message encryption ----------------------------------------------
    600 
    601 -- | Encrypt a message (max 65535 bytes).
    602 --
    603 --   Returns the encrypted packet and updated session. Key
    604 --   rotation is handled automatically at nonce 1000.
    605 --
    606 --   Wire format:
    607 --     encrypted_length (2) || MAC (16)
    608 --     || encrypted_body || MAC (16)
    609 --
    610 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    611 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    612 --   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
    613 --   >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
    614 --   >>> let Right (_, i_result) = act3 i_hs msg2
    615 --   >>> let sess = session i_result
    616 --   >>> case encrypt sess "hello" of { Right (ct, _) -> BS.length ct; Left _ -> 0 }
    617 --   39
    618 encrypt
    619   :: Session
    620   -> BS.ByteString
    621   -> Either Error (BS.ByteString, Session)
    622 encrypt sess pt = do
    623   let !len = BS.length pt
    624   require (len <= 65535) InvalidLength
    625   let !len_bytes = encode_be16 (fi len)
    626       !sk = unKey32 (sess_sk sess)
    627       !sn = unSessionNonce (sess_sn sess)
    628       !sck = unKey32 (sess_sck sess)
    629   lc <- note InvalidMAC
    630           (encrypt_with_ad sk sn BS.empty len_bytes)
    631   let !(sn1, sck1, sk1) = step_nonce sn sck sk
    632   bc <- note InvalidMAC
    633           (encrypt_with_ad sk1 sn1 BS.empty pt)
    634   let !(sn2, sck2, sk2) = step_nonce sn1 sck1 sk1
    635       !packet = lc <> bc
    636       !sess' = sess {
    637         sess_sk  = Key32 sk2
    638       , sess_sn  = SessionNonce sn2
    639       , sess_sck = Key32 sck2
    640       }
    641   pure (packet, sess')
    642 
    643 -- | Decrypt a message, requiring an exact packet with no
    644 --   trailing bytes.
    645 --
    646 --   Returns the plaintext and updated session. Key rotation
    647 --   is handled automatically at nonce 1000.
    648 --
    649 --   This is a strict variant that rejects any trailing data.
    650 --   For streaming use cases where you need to handle multiple
    651 --   frames in a buffer, use 'decrypt_frame' instead.
    652 --
    653 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    654 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    655 --   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
    656 --   >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
    657 --   >>> let Right (msg3, i_result) = act3 i_hs msg2
    658 --   >>> let Right r_result = finalize r_hs msg3
    659 --   >>> let Right (ct, _) = encrypt (session i_result) "hello"
    660 --   >>> case decrypt (session r_result) ct of { Right (pt, _) -> pt; Left _ -> "fail" }
    661 --   "hello"
    662 decrypt
    663   :: Session
    664   -> BS.ByteString
    665   -> Either Error (BS.ByteString, Session)
    666 decrypt sess packet = do
    667   (pt, remainder, sess') <- decrypt_frame sess packet
    668   require (BS.null remainder) InvalidLength
    669   pure (pt, sess')
    670 
    671 -- | Decrypt a single frame from a buffer, returning the
    672 --   remainder.
    673 --
    674 --   Returns the plaintext, any unconsumed bytes, and the
    675 --   updated session. Key rotation is handled automatically
    676 --   every 1000 messages.
    677 --
    678 --   This is useful for streaming scenarios where multiple
    679 --   messages may be buffered together. The remainder can be
    680 --   passed to the next call to 'decrypt_frame'.
    681 --
    682 --   Wire format consumed:
    683 --     encrypted_length (18) || encrypted_body (len + 16)
    684 --
    685 --   >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
    686 --   >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
    687 --   >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
    688 --   >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
    689 --   >>> let Right (msg3, i_result) = act3 i_hs msg2
    690 --   >>> let Right r_result = finalize r_hs msg3
    691 --   >>> let Right (ct, _) = encrypt (session i_result) "hello"
    692 --   >>> case decrypt_frame (session r_result) ct of { Right (pt, rem, _) -> (pt, BS.null rem); Left _ -> ("fail", False) }
    693 --   ("hello",True)
    694 decrypt_frame
    695   :: Session
    696   -> BS.ByteString
    697   -> Either Error
    698        (BS.ByteString, BS.ByteString, Session)
    699 decrypt_frame sess packet = do
    700   require (BS.length packet >= 34) InvalidLength
    701   let !lc = BS.take 18 packet
    702       !rest = BS.drop 18 packet
    703       !rk = unKey32 (sess_rk sess)
    704       !rn = unSessionNonce (sess_rn sess)
    705       !rck = unKey32 (sess_rck sess)
    706   len_bytes <- note InvalidMAC
    707     (decrypt_with_ad rk rn BS.empty lc)
    708   len <- note InvalidLength (decode_be16 len_bytes)
    709   let !(rn1, rck1, rk1) = step_nonce rn rck rk
    710       !body_len = fi len + 16
    711   require (BS.length rest >= body_len) InvalidLength
    712   let !bc = BS.take body_len rest
    713       !remainder = BS.drop body_len rest
    714   pt <- note InvalidMAC
    715           (decrypt_with_ad rk1 rn1 BS.empty bc)
    716   let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
    717       !sess' = sess {
    718         sess_rk  = Key32 rk2
    719       , sess_rn  = SessionNonce rn2
    720       , sess_rck = Key32 rck2
    721       }
    722   pure (pt, remainder, sess')
    723 
    724 -- | Decrypt a frame from a partial buffer, indicating when
    725 --   more data needed.
    726 --
    727 --   Unlike 'decrypt_frame', this function handles incomplete
    728 --   buffers gracefully by returning 'NeedMore' with the
    729 --   number of additional bytes required to make progress.
    730 --
    731 --   * If the buffer has fewer than 18 bytes (encrypted
    732 --     length + MAC), returns @'NeedMore' n@ where @n@ is
    733 --     the bytes still needed.
    734 --   * If the length header is complete but the body is
    735 --     incomplete, returns @'NeedMore' n@ with bytes needed
    736 --     for the full frame.
    737 --   * MAC or decryption failures return 'FrameError'.
    738 --   * A complete, valid frame returns 'FrameOk' with
    739 --     plaintext, remainder, and updated session.
    740 --
    741 --   This is useful for non-blocking I/O where data arrives
    742 --   incrementally.
    743 decrypt_frame_partial
    744   :: Session
    745   -> BS.ByteString
    746   -> FrameResult
    747 decrypt_frame_partial sess buf
    748   | buflen < 18 = NeedMore (18 - buflen)
    749   | otherwise =
    750       let !lc = BS.take 18 buf
    751           !rest = BS.drop 18 buf
    752           !rk = unKey32 (sess_rk sess)
    753           !rn = unSessionNonce (sess_rn sess)
    754           !rck = unKey32 (sess_rck sess)
    755       in case decrypt_with_ad rk rn BS.empty lc of
    756            Nothing -> FrameError InvalidMAC
    757            Just len_bytes ->
    758              case decode_be16 len_bytes of
    759                Nothing -> FrameError InvalidLength
    760                Just len ->
    761                  let !body_len = fi len + 16
    762                      !(rn1, rck1, rk1) =
    763                        step_nonce rn rck rk
    764                  in if BS.length rest < body_len
    765                    then NeedMore
    766                      (body_len - BS.length rest)
    767                    else
    768                      let !bc = BS.take body_len rest
    769                          !remainder =
    770                            BS.drop body_len rest
    771                      in case decrypt_with_ad
    772                               rk1 rn1 BS.empty bc of
    773                        Nothing ->
    774                          FrameError InvalidMAC
    775                        Just pt ->
    776                          let !(rn2, rck2, rk2) =
    777                                step_nonce rn1 rck1 rk1
    778                              !sess' = sess {
    779                                sess_rk  = Key32 rk2
    780                              , sess_rn  =
    781                                  SessionNonce rn2
    782                              , sess_rck = Key32 rck2
    783                              }
    784                          in FrameOk pt remainder sess'
    785   where
    786     !buflen = BS.length buf
    787 
    788 -- key rotation ----------------------------------------------------
    789 
    790 -- Key rotation occurs after nonce reaches 1000 (i.e., before
    791 -- using 1000)
    792 -- (ck', k') = HKDF(ck, k), reset nonce to 0
    793 step_nonce
    794   :: Word64
    795   -> BS.ByteString
    796   -> BS.ByteString
    797   -> (Word64, BS.ByteString, BS.ByteString)
    798 step_nonce n ck k
    799   | n + 1 == 1000 =
    800       let !(ck', k') = mix_key ck k
    801       in (0, ck', k')
    802   | otherwise = (n + 1, ck, k)
    803 
    804 -- utilities -------------------------------------------------------
    805 
    806 -- Lift Maybe to Either
    807 note :: e -> Maybe a -> Either e a
    808 note e = maybe (Left e) Right
    809 {-# INLINE note #-}
    810 
    811 -- Require condition or fail
    812 require :: Bool -> e -> Either e ()
    813 require cond e = unless cond (Left e)
    814 {-# INLINE require #-}
    815 
    816 fi :: (Integral a, Num b) => a -> b
    817 fi = fromIntegral
    818 {-# INLINE fi #-}