commit ffcab68d5a14498d37e52390d4fe0bbe48ecc020
parent ca737a5e3d1970d51078cee6cc10279351759870
Author: Jared Tobin <jared@jtobin.io>
Date: Mon, 20 Apr 2026 14:59:48 +0800
lib: add type safety newtypes and role-indexed handshake
Introduce four type-level improvements without changing
runtime behaviour:
- SessionNonce: newtype over Word64, distinguishes send/receive
nonces in Session
- Key32: newtype over ByteString with smart constructor (key32)
validating length == 32, used for all 32-byte keys in Session
- MessagePayload: newtype over ByteString with smart constructor
(mkMessagePayload) validating length <= 65535
- HandshakeFor: phantom-typed wrapper over HandshakeState,
indexed by Initiator/Responder to prevent role confusion at
the type level (act1 returns HandshakeFor Initiator, act2
returns HandshakeFor Responder, act3 requires Initiator,
finalize requires Responder)
Split into Internal module exporting all constructors (for
tests/benchmarks) and public module re-exporting safe API.
All 26 tests pass, benchmarks build.
Diffstat:
5 files changed, 951 insertions(+), 676 deletions(-)
diff --git a/bench/Main.hs b/bench/Main.hs
@@ -7,16 +7,20 @@ module Main where
import Control.DeepSeq
import Criterion.Main
import qualified Data.ByteString as BS
-import qualified Lightning.Protocol.BOLT8 as BOLT8
+import qualified Lightning.Protocol.BOLT8.Internal as BOLT8
instance NFData BOLT8.Pub where
rnf p = rnf (BOLT8.serialize_pub p)
instance NFData BOLT8.Sec
instance NFData BOLT8.Error
+instance NFData BOLT8.Key32
+instance NFData BOLT8.SessionNonce
instance NFData BOLT8.Session
instance NFData BOLT8.HandshakeState
instance NFData BOLT8.Handshake
+instance NFData (BOLT8.HandshakeFor a) where
+ rnf (BOLT8.HandshakeFor s) = rnf s
main :: IO ()
main = defaultMain [
@@ -36,50 +40,74 @@ keys :: Benchmark
keys = bgroup "keys" [
bench "keypair" $ nf BOLT8.keypair i_s_ent
, bench "parse_pub" $ nf BOLT8.parse_pub r_s_pub_bs
- , bench "serialize_pub" $ nf BOLT8.serialize_pub r_s_pub
+ , bench "serialize_pub" $
+ nf BOLT8.serialize_pub r_s_pub
]
where
Just (_, r_s_pub) = BOLT8.keypair r_s_ent
r_s_pub_bs = BOLT8.serialize_pub r_s_pub
handshake :: Benchmark
-handshake = env setup $ \ ~(i_s_sec, i_s_pub, r_s_sec, r_s_pub, msg1, i_hs,
- msg2, r_hs, msg3) ->
+handshake = env setup $
+ \ ~(i_s_sec, i_s_pub, r_s_sec, r_s_pub,
+ msg1, i_hs, msg2, r_hs, msg3) ->
bgroup "handshake" [
- bench "act1" $ nf (BOLT8.act1 i_s_sec i_s_pub r_s_pub) i_e_ent
- , bench "act2" $ nf (BOLT8.act2 r_s_sec r_s_pub r_e_ent) msg1
- , bench "act3" $ nf (BOLT8.act3 i_hs) msg2
- , bench "finalize" $ nf (BOLT8.finalize r_hs) msg3
+ bench "act1" $
+ nf (BOLT8.act1 i_s_sec i_s_pub r_s_pub) i_e_ent
+ , bench "act2" $
+ nf (BOLT8.act2 r_s_sec r_s_pub r_e_ent) msg1
+ , bench "act3" $
+ nf (BOLT8.act3 i_hs) msg2
+ , bench "finalize" $
+ nf (BOLT8.finalize r_hs) msg3
]
where
setup = do
- let Just (!i_s_sec, !i_s_pub) = BOLT8.keypair i_s_ent
- Just (!r_s_sec, !r_s_pub) = BOLT8.keypair r_s_ent
- Right (!msg1, !i_hs) = BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent
- Right (!msg2, !r_hs) = BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1
+ let Just (!i_s_sec, !i_s_pub) =
+ BOLT8.keypair i_s_ent
+ Just (!r_s_sec, !r_s_pub) =
+ BOLT8.keypair r_s_ent
+ Right (!msg1, !i_hs) =
+ BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent
+ Right (!msg2, !r_hs) =
+ BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1
Right (!msg3, _) = BOLT8.act3 i_hs msg2
- pure (i_s_sec, i_s_pub, r_s_sec, r_s_pub, msg1, i_hs, msg2, r_hs, msg3)
+ pure ( i_s_sec, i_s_pub, r_s_sec, r_s_pub
+ , msg1, i_hs, msg2, r_hs, msg3 )
messages :: Benchmark
-messages = env setup $ \ ~(i_sess, r_sess, ct_small, ct_large) ->
+messages = env setup $
+ \ ~(i_sess, r_sess, ct_small, ct_large) ->
bgroup "messages" [
- bench "encrypt (32B)" $ nf (BOLT8.encrypt i_sess) small_msg
- , bench "encrypt (1KB)" $ nf (BOLT8.encrypt i_sess) large_msg
- , bench "decrypt (32B)" $ nf (BOLT8.decrypt r_sess) ct_small
- , bench "decrypt (1KB)" $ nf (BOLT8.decrypt r_sess) ct_large
+ bench "encrypt (32B)" $
+ nf (BOLT8.encrypt i_sess) small_msg
+ , bench "encrypt (1KB)" $
+ nf (BOLT8.encrypt i_sess) large_msg
+ , bench "decrypt (32B)" $
+ nf (BOLT8.decrypt r_sess) ct_small
+ , bench "decrypt (1KB)" $
+ nf (BOLT8.decrypt r_sess) ct_large
]
where
small_msg = BS.replicate 32 0x00
large_msg = BS.replicate 1024 0x00
setup = do
- let Just (!i_s_sec, !i_s_pub) = BOLT8.keypair i_s_ent
- Just (!r_s_sec, !r_s_pub) = BOLT8.keypair r_s_ent
- Right (msg1, i_hs) = BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent
- Right (msg2, r_hs) = BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1
- Right (msg3, i_result) = BOLT8.act3 i_hs msg2
- Right r_result = BOLT8.finalize r_hs msg3
+ let Just (!i_s_sec, !i_s_pub) =
+ BOLT8.keypair i_s_ent
+ Just (!r_s_sec, !r_s_pub) =
+ BOLT8.keypair r_s_ent
+ Right (msg1, i_hs) =
+ BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent
+ Right (msg2, r_hs) =
+ BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1
+ Right (msg3, i_result) =
+ BOLT8.act3 i_hs msg2
+ Right r_result =
+ BOLT8.finalize r_hs msg3
!i_sess = BOLT8.session i_result
!r_sess = BOLT8.session r_result
- Right (!ct_small, _) = BOLT8.encrypt i_sess small_msg
- Right (!ct_large, _) = BOLT8.encrypt i_sess large_msg
+ Right (!ct_small, _) =
+ BOLT8.encrypt i_sess small_msg
+ Right (!ct_large, _) =
+ BOLT8.encrypt i_sess large_msg
pure (i_sess, r_sess, ct_small, ct_large)
diff --git a/bench/Weight.hs b/bench/Weight.hs
@@ -6,7 +6,7 @@ module Main where
import Control.DeepSeq
import qualified Data.ByteString as BS
-import qualified Lightning.Protocol.BOLT8 as BOLT8
+import qualified Lightning.Protocol.BOLT8.Internal as BOLT8
import Weigh
instance NFData BOLT8.Pub where
@@ -14,9 +14,13 @@ instance NFData BOLT8.Pub where
instance NFData BOLT8.Sec
instance NFData BOLT8.Error
+instance NFData BOLT8.Key32
+instance NFData BOLT8.SessionNonce
instance NFData BOLT8.Session
instance NFData BOLT8.HandshakeState
instance NFData BOLT8.Handshake
+instance NFData (BOLT8.HandshakeFor a) where
+ rnf (BOLT8.HandshakeFor s) = rnf s
-- note that 'weigh' doesn't work properly in a repl
main :: IO ()
@@ -39,37 +43,56 @@ keys =
in wgroup "keys" $ do
func "keypair" BOLT8.keypair i_s_ent
func "parse_pub" BOLT8.parse_pub r_s_pub_bs
- func "serialize_pub" BOLT8.serialize_pub r_s_pub
+ func "serialize_pub"
+ BOLT8.serialize_pub r_s_pub
handshake :: Weigh ()
handshake =
- let Just (!i_s_sec, !i_s_pub) = BOLT8.keypair i_s_ent
- Just (!r_s_sec, !r_s_pub) = BOLT8.keypair r_s_ent
- Right (!msg1, !i_hs) = BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent
- Right (!msg2, !r_hs) = BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1
+ let Just (!i_s_sec, !i_s_pub) =
+ BOLT8.keypair i_s_ent
+ Just (!r_s_sec, !r_s_pub) =
+ BOLT8.keypair r_s_ent
+ Right (!msg1, !i_hs) =
+ BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent
+ Right (!msg2, !r_hs) =
+ BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1
Right (!msg3, _) = BOLT8.act3 i_hs msg2
in wgroup "handshake" $ do
- func "act1" (BOLT8.act1 i_s_sec i_s_pub r_s_pub) i_e_ent
- func "act2" (BOLT8.act2 r_s_sec r_s_pub r_e_ent) msg1
+ func "act1"
+ (BOLT8.act1 i_s_sec i_s_pub r_s_pub) i_e_ent
+ func "act2"
+ (BOLT8.act2 r_s_sec r_s_pub r_e_ent) msg1
func "act3" (BOLT8.act3 i_hs) msg2
func "finalize" (BOLT8.finalize r_hs) msg3
messages :: Weigh ()
messages =
- let Just (!i_s_sec, !i_s_pub) = BOLT8.keypair i_s_ent
- Just (!r_s_sec, !r_s_pub) = BOLT8.keypair r_s_ent
- Right (msg1, i_hs) = BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent
- Right (msg2, r_hs) = BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1
- Right (msg3, i_result) = BOLT8.act3 i_hs msg2
- Right r_result = BOLT8.finalize r_hs msg3
+ let Just (!i_s_sec, !i_s_pub) =
+ BOLT8.keypair i_s_ent
+ Just (!r_s_sec, !r_s_pub) =
+ BOLT8.keypair r_s_ent
+ Right (msg1, i_hs) =
+ BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent
+ Right (msg2, r_hs) =
+ BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1
+ Right (msg3, i_result) =
+ BOLT8.act3 i_hs msg2
+ Right r_result =
+ BOLT8.finalize r_hs msg3
!i_sess = BOLT8.session i_result
!r_sess = BOLT8.session r_result
!small_msg = BS.replicate 32 0x00
!large_msg = BS.replicate 1024 0x00
- Right (!ct_small, _) = BOLT8.encrypt i_sess small_msg
- Right (!ct_large, _) = BOLT8.encrypt i_sess large_msg
+ Right (!ct_small, _) =
+ BOLT8.encrypt i_sess small_msg
+ Right (!ct_large, _) =
+ BOLT8.encrypt i_sess large_msg
in wgroup "messages" $ do
- func "encrypt (32B)" (BOLT8.encrypt i_sess) small_msg
- func "encrypt (1KB)" (BOLT8.encrypt i_sess) large_msg
- func "decrypt (32B)" (BOLT8.decrypt r_sess) ct_small
- func "decrypt (1KB)" (BOLT8.decrypt r_sess) ct_large
+ func "encrypt (32B)"
+ (BOLT8.encrypt i_sess) small_msg
+ func "encrypt (1KB)"
+ (BOLT8.encrypt i_sess) large_msg
+ func "decrypt (32B)"
+ (BOLT8.decrypt r_sess) ct_small
+ func "decrypt (1KB)"
+ (BOLT8.decrypt r_sess) ct_large
diff --git a/lib/Lightning/Protocol/BOLT8.hs b/lib/Lightning/Protocol/BOLT8.hs
@@ -1,10 +1,4 @@
{-# OPTIONS_HADDOCK prune #-}
-{-# LANGUAGE BangPatterns #-}
-{-# LANGUAGE DeriveGeneric #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE OverloadedStrings #-}
-{-# LANGUAGE RecordWildCards #-}
-{-# LANGUAGE ViewPatterns #-}
-- |
-- Module: Lightning.Protocol.BOLT8
@@ -12,16 +6,19 @@
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
--- Encrypted and authenticated transport for the Lightning Network, per
+-- Encrypted and authenticated transport for the Lightning
+-- Network, per
-- [BOLT #8](https://github.com/lightning/bolts/blob/master/08-transport.md).
--
--- This module implements the Noise_XK_secp256k1_ChaChaPoly_SHA256
--- handshake and subsequent encrypted message transport.
+-- This module implements the
+-- Noise_XK_secp256k1_ChaChaPoly_SHA256 handshake and
+-- subsequent encrypted message transport.
--
-- = Handshake
--
--- A BOLT #8 handshake consists of three acts. The /initiator/ knows the
--- responder's static public key in advance and initiates the connection:
+-- A BOLT #8 handshake consists of three acts. The
+-- /initiator/ knows the responder's static public key in
+-- advance and initiates the connection:
--
-- @
-- (msg1, state) <- act1 i_sec i_pub r_pub entropy
@@ -32,7 +29,8 @@
-- let session = 'session' result
-- @
--
--- The /responder/ receives the connection and authenticates the initiator:
+-- The /responder/ receives the connection and authenticates
+-- the initiator:
--
-- @
-- -- receive msg1 (50 bytes) from initiator
@@ -45,9 +43,10 @@
--
-- = Message Transport
--
--- After a successful handshake, use 'encrypt' and 'decrypt' to exchange
--- messages. Each returns an updated 'Session' that must be used for the
--- next operation (keys rotate every 1000 messages):
+-- After a successful handshake, use 'encrypt' and 'decrypt'
+-- to exchange messages. Each returns an updated 'Session'
+-- that must be used for the next operation (keys rotate
+-- every 1000 messages):
--
-- @
-- -- sender
@@ -59,10 +58,11 @@
--
-- = Message Framing
--
--- BOLT #8 runs over a byte stream, so callers often need to deal with
--- partial buffers. Use 'decrypt_frame' when you have exactly one frame,
--- or 'decrypt_frame_partial' to handle incremental reads and return how
--- many bytes are still needed.
+-- BOLT #8 runs over a byte stream, so callers often need to
+-- deal with partial buffers. Use 'decrypt_frame' when you
+-- have exactly one frame, or 'decrypt_frame_partial' to
+-- handle incremental reads and return how many bytes are
+-- still needed.
--
-- Maximum plaintext size is 65535 bytes.
@@ -74,6 +74,21 @@ module Lightning.Protocol.BOLT8 (
, parse_pub
, serialize_pub
+ -- * Newtypes
+ , Key32
+ , key32
+ , unKey32
+ , SessionNonce
+ , unSessionNonce
+ , MessagePayload
+ , unMessagePayload
+ , mkMessagePayload
+
+ -- * Handshake roles
+ , Initiator
+ , Responder
+ , HandshakeFor
+
-- * Handshake (initiator)
, act1
, act3
@@ -96,614 +111,4 @@ module Lightning.Protocol.BOLT8 (
, Error(..)
) where
-import Control.Monad (guard, unless)
-import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
-import qualified Crypto.Curve.Secp256k1 as Secp256k1
-import qualified Crypto.Hash.SHA256 as SHA256
-import qualified Crypto.KDF.HMAC as HKDF
-import Data.Bits (unsafeShiftR, (.&.))
-import qualified Data.ByteString as BS
-import Data.Word (Word16, Word64)
-import GHC.Generics (Generic)
-
--- types ---------------------------------------------------------------------
-
--- | Secret key (32 bytes).
-newtype Sec = Sec BS.ByteString
- deriving (Eq, Generic)
-
--- | Compressed public key.
-newtype Pub = Pub Secp256k1.Projective
-
-instance Eq Pub where
- (Pub a) == (Pub b) =
- Secp256k1.serialize_point a == Secp256k1.serialize_point b
-
-instance Show Pub where
- show (Pub p) = "Pub " ++ show (Secp256k1.serialize_point p)
-
--- | Handshake errors.
-data Error =
- InvalidKey
- | InvalidPub
- | InvalidMAC
- | InvalidVersion
- | InvalidLength
- | DecryptionFailed
- deriving (Eq, Show, Generic)
-
--- | Result of attempting to decrypt a frame from a partial buffer.
-data FrameResult =
- NeedMore {-# UNPACK #-} !Int
- -- ^ More bytes needed; the 'Int' is the minimum additional bytes required.
- | FrameOk !BS.ByteString !BS.ByteString !Session
- -- ^ Successfully decrypted: plaintext, remainder, updated session.
- | FrameError !Error
- -- ^ Decryption failed with the given error.
- deriving Generic
-
--- | Post-handshake session state.
-data Session = Session {
- sess_sk :: {-# UNPACK #-} !BS.ByteString -- ^ send key (32 bytes)
- , sess_sn :: {-# UNPACK #-} !Word64 -- ^ send nonce
- , sess_sck :: {-# UNPACK #-} !BS.ByteString -- ^ send chaining key
- , sess_rk :: {-# UNPACK #-} !BS.ByteString -- ^ receive key (32 bytes)
- , sess_rn :: {-# UNPACK #-} !Word64 -- ^ receive nonce
- , sess_rck :: {-# UNPACK #-} !BS.ByteString -- ^ receive chaining key
- }
- deriving Generic
-
--- | Result of a successful handshake.
-data Handshake = Handshake {
- session :: !Session -- ^ session state
- , remote_static :: !Pub -- ^ authenticated remote static pubkey
- }
- deriving Generic
-
--- | Internal handshake state (exported for benchmarking).
-data HandshakeState = HandshakeState {
- hs_h :: {-# UNPACK #-} !BS.ByteString -- handshake hash (32 bytes)
- , hs_ck :: {-# UNPACK #-} !BS.ByteString -- chaining key (32 bytes)
- , hs_temp_k :: {-# UNPACK #-} !BS.ByteString -- temp key (32 bytes)
- , hs_e_sec :: !Sec -- ephemeral secret
- , hs_e_pub :: !Pub -- ephemeral public
- , hs_s_sec :: !Sec -- static secret
- , hs_s_pub :: !Pub -- static public
- , hs_re :: !(Maybe Pub) -- remote ephemeral
- , hs_rs :: !(Maybe Pub) -- remote static
- }
- deriving Generic
-
--- protocol constants --------------------------------------------------------
-
-_PROTOCOL_NAME :: BS.ByteString
-_PROTOCOL_NAME = "Noise_XK_secp256k1_ChaChaPoly_SHA256"
-
-_PROLOGUE :: BS.ByteString
-_PROLOGUE = "lightning"
-
--- key operations ------------------------------------------------------------
-
--- | Derive a keypair from 32 bytes of entropy.
---
--- Returns Nothing if the entropy is invalid (zero or >= curve order).
---
--- >>> let ent = BS.replicate 32 0x11
--- >>> case keypair ent of { Just _ -> "ok"; Nothing -> "fail" }
--- "ok"
--- >>> keypair (BS.replicate 31 0x11) -- wrong length
--- Nothing
-keypair :: BS.ByteString -> Maybe (Sec, Pub)
-keypair ent = do
- guard (BS.length ent == 32)
- k <- Secp256k1.parse_int256 ent
- p <- Secp256k1.derive_pub k
- pure (Sec ent, Pub p)
-
--- | Parse a 33-byte compressed public key.
---
--- >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
--- >>> let bytes = serialize_pub pub
--- >>> case parse_pub bytes of { Just _ -> "ok"; Nothing -> "fail" }
--- "ok"
--- >>> parse_pub (BS.replicate 32 0x00) -- wrong length
--- Nothing
-parse_pub :: BS.ByteString -> Maybe Pub
-parse_pub bs = do
- guard (BS.length bs == 33)
- p <- Secp256k1.parse_point bs
- pure (Pub p)
-
--- | Serialize a public key to 33-byte compressed form.
---
--- >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
--- >>> BS.length (serialize_pub pub)
--- 33
-serialize_pub :: Pub -> BS.ByteString
-serialize_pub (Pub p) = Secp256k1.serialize_point p
-
--- cryptographic primitives --------------------------------------------------
-
--- bolt8-style ECDH
-ecdh :: Sec -> Pub -> Maybe BS.ByteString
-ecdh (Sec sec) (Pub pub) = do
- k <- Secp256k1.parse_int256 sec
- pt <- Secp256k1.mul pub k
- let compressed = Secp256k1.serialize_point pt
- pure (SHA256.hash compressed)
-
--- h' = SHA256(h || data)
-mix_hash :: BS.ByteString -> BS.ByteString -> BS.ByteString
-mix_hash h dat = SHA256.hash (h <> dat)
-
--- Mix key: (ck', k) = HKDF(ck, input_key_material)
---
--- NB HKDF limits output to 255 * hashlen bytes. For SHA256 that's 8160,
--- well above the 64 bytes requested here, so 'Nothing' is impossible.
-mix_key :: BS.ByteString -> BS.ByteString -> (BS.ByteString, BS.ByteString)
-mix_key ck ikm = case HKDF.derive hmac ck mempty 64 ikm of
- Nothing -> error "ppad-bolt8: internal error, please report a bug!"
- Just output -> BS.splitAt 32 output
- where
- hmac k b = case SHA256.hmac k b of
- SHA256.MAC mac -> mac
-
--- Encrypt with associated data using ChaCha20-Poly1305
-encrypt_with_ad
- :: BS.ByteString -- ^ key (32 bytes)
- -> Word64 -- ^ nonce
- -> BS.ByteString -- ^ associated data
- -> BS.ByteString -- ^ plaintext
- -> Maybe BS.ByteString -- ^ ciphertext || mac (16 bytes)
-encrypt_with_ad key n ad pt =
- case AEAD.encrypt ad key (encode_nonce n) pt of
- Left _ -> Nothing
- Right (ct, mac) -> Just (ct <> mac)
-
--- Decrypt with associated data using ChaCha20-Poly1305
-decrypt_with_ad
- :: BS.ByteString -- ^ key (32 bytes)
- -> Word64 -- ^ nonce
- -> BS.ByteString -- ^ associated data
- -> BS.ByteString -- ^ ciphertext || mac
- -> Maybe BS.ByteString -- ^ plaintext
-decrypt_with_ad key n ad ctmac
- | BS.length ctmac < 16 = Nothing
- | otherwise =
- let (ct, mac) = BS.splitAt (BS.length ctmac - 16) ctmac
- in case AEAD.decrypt ad key (encode_nonce n) (ct, mac) of
- Left _ -> Nothing
- Right pt -> Just pt
-
--- Encode nonce as 96-bit value: 4 zero bytes + 8-byte little-endian
-encode_nonce :: Word64 -> BS.ByteString
-encode_nonce n = BS.replicate 4 0x00 <> encode_le64 n
-
--- Little-endian 64-bit encoding
-encode_le64 :: Word64 -> BS.ByteString
-encode_le64 n = BS.pack [
- fi (n .&. 0xff)
- , fi (unsafeShiftR n 8 .&. 0xff)
- , fi (unsafeShiftR n 16 .&. 0xff)
- , fi (unsafeShiftR n 24 .&. 0xff)
- , fi (unsafeShiftR n 32 .&. 0xff)
- , fi (unsafeShiftR n 40 .&. 0xff)
- , fi (unsafeShiftR n 48 .&. 0xff)
- , fi (unsafeShiftR n 56 .&. 0xff)
- ]
-
--- Big-endian 16-bit encoding
-encode_be16 :: Word16 -> BS.ByteString
-encode_be16 n = BS.pack [fi (unsafeShiftR n 8), fi (n .&. 0xff)]
-
--- Big-endian 16-bit decoding
-decode_be16 :: BS.ByteString -> Maybe Word16
-decode_be16 bs
- | BS.length bs /= 2 = Nothing
- | otherwise =
- let !b0 = BS.index bs 0
- !b1 = BS.index bs 1
- in Just (fi b0 * 0x100 + fi b1)
-
--- handshake -----------------------------------------------------------------
-
--- Initialize handshake state
---
--- h = SHA256(protocol_name)
--- ck = h
--- h = SHA256(h || prologue)
--- h = SHA256(h || responder_static_pubkey)
-init_handshake
- :: Sec -- ^ local static secret
- -> Pub -- ^ local static public
- -> Sec -- ^ ephemeral secret
- -> Pub -- ^ ephemeral public
- -> Maybe Pub -- ^ remote static (initiator knows, responder doesn't)
- -> Bool -- ^ True if initiator
- -> HandshakeState
-init_handshake s_sec s_pub e_sec e_pub m_rs is_initiator =
- let !h0 = SHA256.hash _PROTOCOL_NAME
- !ck = h0
- !h1 = mix_hash h0 _PROLOGUE
- -- Mix in responder's static pubkey
- !h2 = case (is_initiator, m_rs) of
- (True, Just rs) -> mix_hash h1 (serialize_pub rs)
- (False, Nothing) -> mix_hash h1 (serialize_pub s_pub)
- _ -> h1 -- shouldn't happen
- in HandshakeState {
- hs_h = h2
- , hs_ck = ck
- , hs_temp_k = BS.replicate 32 0x00
- , hs_e_sec = e_sec
- , hs_e_pub = e_pub
- , hs_s_sec = s_sec
- , hs_s_pub = s_pub
- , hs_re = Nothing
- , hs_rs = m_rs
- }
-
--- | Initiator: generate Act 1 message (50 bytes).
---
--- Takes local static key, remote static pubkey, and 32 bytes of
--- entropy for ephemeral key generation.
---
--- Returns the 50-byte Act 1 message and handshake state for Act 3.
---
--- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--- >>> let eph_ent = BS.replicate 32 0x12
--- >>> case act1 i_sec i_pub r_pub eph_ent of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
--- 50
-act1
- :: Sec -- ^ local static secret
- -> Pub -- ^ local static public
- -> Pub -- ^ remote static public (responder's)
- -> BS.ByteString -- ^ 32 bytes entropy for ephemeral
- -> Either Error (BS.ByteString, HandshakeState)
-act1 s_sec s_pub rs ent = do
- (e_sec, e_pub) <- note InvalidKey (keypair ent)
- let !hs0 = init_handshake s_sec s_pub e_sec e_pub (Just rs) True
- !e_pub_bytes = serialize_pub e_pub
- !h1 = mix_hash (hs_h hs0) e_pub_bytes
- es <- note InvalidKey (ecdh e_sec rs)
- let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
- c <- note InvalidMAC (encrypt_with_ad temp_k1 0 h1 BS.empty)
- let !h2 = mix_hash h1 c
- !msg = BS.singleton 0x00 <> e_pub_bytes <> c
- !hs1 = hs0 {
- hs_h = h2
- , hs_ck = ck1
- , hs_temp_k = temp_k1
- }
- pure (msg, hs1)
-
--- | Responder: process Act 1 and generate Act 2 message (50 bytes).
---
--- Takes local static key and 32 bytes of entropy for ephemeral key,
--- plus the 50-byte Act 1 message from initiator.
---
--- Returns the 50-byte Act 2 message and handshake state for finalize.
---
--- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--- >>> let Right (msg1, _) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--- >>> case act2 r_sec r_pub (BS.replicate 32 0x22) msg1 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
--- 50
-act2
- :: Sec -- ^ local static secret
- -> Pub -- ^ local static public
- -> BS.ByteString -- ^ 32 bytes entropy for ephemeral
- -> BS.ByteString -- ^ Act 1 message (50 bytes)
- -> Either Error (BS.ByteString, HandshakeState)
-act2 s_sec s_pub ent msg1 = do
- require (BS.length msg1 == 50) InvalidLength
- let !version = BS.index msg1 0
- !re_bytes = BS.take 33 (BS.drop 1 msg1)
- !c = BS.drop 34 msg1
- require (version == 0x00) InvalidVersion
- re <- note InvalidPub (parse_pub re_bytes)
- (e_sec, e_pub) <- note InvalidKey (keypair ent)
- let !hs0 = init_handshake s_sec s_pub e_sec e_pub Nothing False
- !h1 = mix_hash (hs_h hs0) re_bytes
- es <- note InvalidKey (ecdh s_sec re)
- let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
- _ <- note InvalidMAC (decrypt_with_ad temp_k1 0 h1 c)
- let !h2 = mix_hash h1 c
- !e_pub_bytes = serialize_pub e_pub
- !h3 = mix_hash h2 e_pub_bytes
- ee <- note InvalidKey (ecdh e_sec re)
- let !(ck2, temp_k2) = mix_key ck1 ee
- c2 <- note InvalidMAC (encrypt_with_ad temp_k2 0 h3 BS.empty)
- let !h4 = mix_hash h3 c2
- !msg = BS.singleton 0x00 <> e_pub_bytes <> c2
- !hs1 = hs0 {
- hs_h = h4
- , hs_ck = ck2
- , hs_temp_k = temp_k2
- , hs_re = Just re
- }
- pure (msg, hs1)
-
--- | Initiator: process Act 2 and generate Act 3 (66 bytes), completing
--- the handshake.
---
--- Returns the 66-byte Act 3 message and the handshake result.
---
--- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--- >>> case act3 i_hs msg2 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
--- 66
-act3
- :: HandshakeState -- ^ state after Act 1
- -> BS.ByteString -- ^ Act 2 message (50 bytes)
- -> Either Error (BS.ByteString, Handshake)
-act3 hs msg2 = do
- require (BS.length msg2 == 50) InvalidLength
- let !version = BS.index msg2 0
- !re_bytes = BS.take 33 (BS.drop 1 msg2)
- !c = BS.drop 34 msg2
- require (version == 0x00) InvalidVersion
- re <- note InvalidPub (parse_pub re_bytes)
- let !h1 = mix_hash (hs_h hs) re_bytes
- ee <- note InvalidKey (ecdh (hs_e_sec hs) re)
- let !(ck1, temp_k2) = mix_key (hs_ck hs) ee
- _ <- note InvalidMAC (decrypt_with_ad temp_k2 0 h1 c)
- let !h2 = mix_hash h1 c
- !s_pub_bytes = serialize_pub (hs_s_pub hs)
- c3 <- note InvalidMAC (encrypt_with_ad temp_k2 1 h2 s_pub_bytes)
- let !h3 = mix_hash h2 c3
- se <- note InvalidKey (ecdh (hs_s_sec hs) re)
- let !(ck2, temp_k3) = mix_key ck1 se
- t <- note InvalidMAC (encrypt_with_ad temp_k3 0 h3 BS.empty)
- let !(sk, rk) = mix_key ck2 BS.empty
- !msg = BS.singleton 0x00 <> c3 <> t
- !sess = Session {
- sess_sk = sk
- , sess_sn = 0
- , sess_sck = ck2
- , sess_rk = rk
- , sess_rn = 0
- , sess_rck = ck2
- }
- rs <- note InvalidPub (hs_rs hs)
- let !result = Handshake {
- session = sess
- , remote_static = rs
- }
- pure (msg, result)
-
--- | Responder: process Act 3 (66 bytes) and complete the handshake.
---
--- Returns the handshake result with authenticated remote static pubkey.
---
--- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--- >>> let Right (msg3, _) = act3 i_hs msg2
--- >>> case finalize r_hs msg3 of { Right _ -> "ok"; Left e -> show e }
--- "ok"
-finalize
- :: HandshakeState -- ^ state after Act 2
- -> BS.ByteString -- ^ Act 3 message (66 bytes)
- -> Either Error Handshake
-finalize hs msg3 = do
- require (BS.length msg3 == 66) InvalidLength
- let !version = BS.index msg3 0
- !c = BS.take 49 (BS.drop 1 msg3)
- !t = BS.drop 50 msg3
- require (version == 0x00) InvalidVersion
- rs_bytes <- note InvalidMAC (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c)
- rs <- note InvalidPub (parse_pub rs_bytes)
- let !h1 = mix_hash (hs_h hs) c
- se <- note InvalidKey (ecdh (hs_e_sec hs) rs)
- let !(ck1, temp_k3) = mix_key (hs_ck hs) se
- _ <- note InvalidMAC (decrypt_with_ad temp_k3 0 h1 t)
- -- responder swaps order (receives what initiator sends)
- let !(rk, sk) = mix_key ck1 BS.empty
- !sess = Session {
- sess_sk = sk
- , sess_sn = 0
- , sess_sck = ck1
- , sess_rk = rk
- , sess_rn = 0
- , sess_rck = ck1
- }
- !result = Handshake {
- session = sess
- , remote_static = rs
- }
- pure result
-
--- message encryption --------------------------------------------------------
-
--- | Encrypt a message (max 65535 bytes).
---
--- Returns the encrypted packet and updated session. Key rotation
--- is handled automatically at nonce 1000.
---
--- Wire format: encrypted_length (2) || MAC (16) || encrypted_body || MAC (16)
---
--- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--- >>> let Right (_, i_result) = act3 i_hs msg2
--- >>> let sess = session i_result
--- >>> case encrypt sess "hello" of { Right (ct, _) -> BS.length ct; Left _ -> 0 }
--- 39
-encrypt
- :: Session
- -> BS.ByteString -- ^ plaintext (max 65535 bytes)
- -> Either Error (BS.ByteString, Session)
-encrypt sess pt = do
- let !len = BS.length pt
- require (len <= 65535) InvalidLength
- let !len_bytes = encode_be16 (fi len)
- lc <- note InvalidMAC (encrypt_with_ad (sess_sk sess) (sess_sn sess)
- BS.empty len_bytes)
- let !(sn1, sck1, sk1) = step_nonce (sess_sn sess) (sess_sck sess) (sess_sk sess)
- bc <- note InvalidMAC (encrypt_with_ad sk1 sn1 BS.empty pt)
- let !(sn2, sck2, sk2) = step_nonce sn1 sck1 sk1
- !packet = lc <> bc
- !sess' = sess {
- sess_sk = sk2
- , sess_sn = sn2
- , sess_sck = sck2
- }
- pure (packet, sess')
-
--- | Decrypt a message, requiring an exact packet with no trailing bytes.
---
--- Returns the plaintext and updated session. Key rotation
--- is handled automatically at nonce 1000.
---
--- This is a strict variant that rejects any trailing data. For
--- streaming use cases where you need to handle multiple frames in a
--- buffer, use 'decrypt_frame' instead.
---
--- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--- >>> let Right (msg3, i_result) = act3 i_hs msg2
--- >>> let Right r_result = finalize r_hs msg3
--- >>> let Right (ct, _) = encrypt (session i_result) "hello"
--- >>> case decrypt (session r_result) ct of { Right (pt, _) -> pt; Left _ -> "fail" }
--- "hello"
-decrypt
- :: Session
- -> BS.ByteString -- ^ encrypted packet (exact length required)
- -> Either Error (BS.ByteString, Session)
-decrypt sess packet = do
- (pt, remainder, sess') <- decrypt_frame sess packet
- require (BS.null remainder) InvalidLength
- pure (pt, sess')
-
--- | Decrypt a single frame from a buffer, returning the remainder.
---
--- Returns the plaintext, any unconsumed bytes, and the updated session.
--- Key rotation is handled automatically every 1000 messages.
---
--- This is useful for streaming scenarios where multiple messages may
--- be buffered together. The remainder can be passed to the next call
--- to 'decrypt_frame'.
---
--- Wire format consumed: encrypted_length (18) || encrypted_body (len + 16)
---
--- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
--- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
--- >>> let Right (msg3, i_result) = act3 i_hs msg2
--- >>> let Right r_result = finalize r_hs msg3
--- >>> let Right (ct, _) = encrypt (session i_result) "hello"
--- >>> case decrypt_frame (session r_result) ct of { Right (pt, rem, _) -> (pt, BS.null rem); Left _ -> ("fail", False) }
--- ("hello",True)
-decrypt_frame
- :: Session
- -> BS.ByteString -- ^ buffer containing at least one encrypted frame
- -> Either Error (BS.ByteString, BS.ByteString, Session)
-decrypt_frame sess packet = do
- require (BS.length packet >= 34) InvalidLength
- let !lc = BS.take 18 packet
- !rest = BS.drop 18 packet
- len_bytes <- note InvalidMAC (decrypt_with_ad (sess_rk sess) (sess_rn sess)
- BS.empty lc)
- len <- note InvalidLength (decode_be16 len_bytes)
- let !(rn1, rck1, rk1) = step_nonce (sess_rn sess) (sess_rck sess) (sess_rk sess)
- !body_len = fi len + 16
- require (BS.length rest >= body_len) InvalidLength
- let !bc = BS.take body_len rest
- !remainder = BS.drop body_len rest
- pt <- note InvalidMAC (decrypt_with_ad rk1 rn1 BS.empty bc)
- let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
- !sess' = sess {
- sess_rk = rk2
- , sess_rn = rn2
- , sess_rck = rck2
- }
- pure (pt, remainder, sess')
-
--- | Decrypt a frame from a partial buffer, indicating when more data needed.
---
--- Unlike 'decrypt_frame', this function handles incomplete buffers
--- gracefully by returning 'NeedMore' with the number of additional
--- bytes required to make progress.
---
--- * If the buffer has fewer than 18 bytes (encrypted length + MAC),
--- returns @'NeedMore' n@ where @n@ is the bytes still needed.
--- * If the length header is complete but the body is incomplete,
--- returns @'NeedMore' n@ with bytes needed for the full frame.
--- * MAC or decryption failures return 'FrameError'.
--- * A complete, valid frame returns 'FrameOk' with plaintext,
--- remainder, and updated session.
---
--- This is useful for non-blocking I/O where data arrives incrementally.
-decrypt_frame_partial
- :: Session
- -> BS.ByteString -- ^ buffer (possibly incomplete)
- -> FrameResult
-decrypt_frame_partial sess buf
- | buflen < 18 = NeedMore (18 - buflen)
- | otherwise =
- let !lc = BS.take 18 buf
- !rest = BS.drop 18 buf
- in case decrypt_with_ad (sess_rk sess) (sess_rn sess) BS.empty lc of
- Nothing -> FrameError InvalidMAC
- Just len_bytes -> case decode_be16 len_bytes of
- Nothing -> FrameError InvalidLength
- Just len ->
- let !body_len = fi len + 16
- !(rn1, rck1, rk1) = step_nonce (sess_rn sess)
- (sess_rck sess) (sess_rk sess)
- in if BS.length rest < body_len
- then NeedMore (body_len - BS.length rest)
- else
- let !bc = BS.take body_len rest
- !remainder = BS.drop body_len rest
- in case decrypt_with_ad rk1 rn1 BS.empty bc of
- Nothing -> FrameError InvalidMAC
- Just pt ->
- let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
- !sess' = sess {
- sess_rk = rk2
- , sess_rn = rn2
- , sess_rck = rck2
- }
- in FrameOk pt remainder sess'
- where
- !buflen = BS.length buf
-
--- key rotation --------------------------------------------------------------
-
--- Key rotation occurs after nonce reaches 1000 (i.e., before using 1000)
--- (ck', k') = HKDF(ck, k), reset nonce to 0
-step_nonce
- :: Word64
- -> BS.ByteString
- -> BS.ByteString
- -> (Word64, BS.ByteString, BS.ByteString)
-step_nonce n ck k
- | n + 1 == 1000 =
- let !(ck', k') = mix_key ck k
- in (0, ck', k')
- | otherwise = (n + 1, ck, k)
-
--- utilities -----------------------------------------------------------------
-
--- Lift Maybe to Either
-note :: e -> Maybe a -> Either e a
-note e = maybe (Left e) Right
-{-# INLINE note #-}
-
--- Require condition or fail
-require :: Bool -> e -> Either e ()
-require cond e = unless cond (Left e)
-{-# INLINE require #-}
-
-fi :: (Integral a, Num b) => a -> b
-fi = fromIntegral
-{-# INLINE fi #-}
+import Lightning.Protocol.BOLT8.Internal
diff --git a/lib/Lightning/Protocol/BOLT8/Internal.hs b/lib/Lightning/Protocol/BOLT8/Internal.hs
@@ -0,0 +1,818 @@
+{-# OPTIONS_HADDOCK prune #-}
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE ViewPatterns #-}
+
+-- |
+-- Module: Lightning.Protocol.BOLT8.Internal
+-- Copyright: (c) 2025 Jared Tobin
+-- License: MIT
+-- Maintainer: Jared Tobin <jared@ppad.tech>
+--
+-- Internal module exporting all constructors for testing and
+-- benchmarking. Prefer "Lightning.Protocol.BOLT8" for general use.
+
+module Lightning.Protocol.BOLT8.Internal (
+ -- * Keys
+ Sec(..)
+ , Pub(..)
+ , keypair
+ , parse_pub
+ , serialize_pub
+
+ -- * Newtypes
+ , Key32(..)
+ , key32
+ , unsafeKey32
+ , SessionNonce(..)
+ , MessagePayload(..)
+ , mkMessagePayload
+
+ -- * Handshake roles
+ , Initiator
+ , Responder
+ , HandshakeFor(..)
+
+ -- * Handshake (initiator)
+ , act1
+ , act3
+
+ -- * Handshake (responder)
+ , act2
+ , finalize
+
+ -- * Session
+ , Session(..)
+ , HandshakeState(..)
+ , Handshake(..)
+ , encrypt
+ , decrypt
+ , decrypt_frame
+ , decrypt_frame_partial
+ , FrameResult(..)
+
+ -- * Errors
+ , Error(..)
+ ) where
+
+import Control.Monad (guard, unless)
+import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
+import qualified Crypto.Curve.Secp256k1 as Secp256k1
+import qualified Crypto.Hash.SHA256 as SHA256
+import qualified Crypto.KDF.HMAC as HKDF
+import Data.Bits (unsafeShiftR, (.&.))
+import qualified Data.ByteString as BS
+import Data.Word (Word16, Word64)
+import GHC.Generics (Generic)
+
+-- types -----------------------------------------------------------
+
+-- | Secret key (32 bytes).
+newtype Sec = Sec BS.ByteString
+ deriving (Eq, Generic)
+
+-- | Compressed public key.
+newtype Pub = Pub Secp256k1.Projective
+
+instance Eq Pub where
+ (Pub a) == (Pub b) =
+ Secp256k1.serialize_point a
+ == Secp256k1.serialize_point b
+
+instance Show Pub where
+ show (Pub p) =
+ "Pub " ++ show (Secp256k1.serialize_point p)
+
+-- | A 32-byte key, validated at construction.
+newtype Key32 = Key32 { unKey32 :: BS.ByteString }
+ deriving (Eq, Generic)
+
+-- | Construct a 'Key32' from a 32-byte 'BS.ByteString'.
+--
+-- Returns 'Nothing' if the input is not exactly 32 bytes.
+--
+-- >>> key32 (BS.replicate 32 0x00)
+-- Just (Key32 {unKey32 = ...})
+-- >>> key32 (BS.replicate 31 0x00)
+-- Nothing
+key32 :: BS.ByteString -> Maybe Key32
+key32 bs
+ | BS.length bs == 32 = Just (Key32 bs)
+ | otherwise = Nothing
+
+-- | Construct a 'Key32' without validation.
+--
+-- For test and benchmark use only; prefer 'key32'.
+unsafeKey32 :: BS.ByteString -> Key32
+unsafeKey32 = Key32
+
+-- | Session nonce, distinguishing send from receive direction.
+newtype SessionNonce =
+ SessionNonce { unSessionNonce :: Word64 }
+ deriving (Eq, Generic)
+
+-- | Message payload (max 65535 bytes), validated at construction.
+newtype MessagePayload =
+ MessagePayload { unMessagePayload :: BS.ByteString }
+ deriving (Eq, Generic)
+
+-- | Construct a 'MessagePayload' from a 'BS.ByteString'.
+--
+-- Returns 'Left' if the payload exceeds 65535 bytes.
+mkMessagePayload
+ :: BS.ByteString -> Either Error MessagePayload
+mkMessagePayload bs
+ | BS.length bs > 65535 = Left InvalidLength
+ | otherwise = Right (MessagePayload bs)
+
+-- | Handshake errors.
+data Error =
+ InvalidKey
+ | InvalidPub
+ | InvalidMAC
+ | InvalidVersion
+ | InvalidLength
+ | DecryptionFailed
+ deriving (Eq, Show, Generic)
+
+-- | Result of attempting to decrypt a frame from a partial
+-- buffer.
+data FrameResult =
+ NeedMore {-# UNPACK #-} !Int
+ -- ^ More bytes needed; the 'Int' is the minimum
+ -- additional bytes required.
+ | FrameOk !BS.ByteString !BS.ByteString !Session
+ -- ^ Successfully decrypted: plaintext, remainder,
+ -- updated session.
+ | FrameError !Error
+ -- ^ Decryption failed with the given error.
+ deriving Generic
+
+-- | Post-handshake session state.
+data Session = Session {
+ sess_sk :: !Key32
+ -- ^ send key (32 bytes)
+ , sess_sn :: !SessionNonce
+ -- ^ send nonce
+ , sess_sck :: !Key32
+ -- ^ send chaining key
+ , sess_rk :: !Key32
+ -- ^ receive key (32 bytes)
+ , sess_rn :: !SessionNonce
+ -- ^ receive nonce
+ , sess_rck :: !Key32
+ -- ^ receive chaining key
+ }
+ deriving Generic
+
+-- | Result of a successful handshake.
+data Handshake = Handshake {
+ session :: !Session
+ -- ^ session state
+ , remote_static :: !Pub
+ -- ^ authenticated remote static pubkey
+ }
+ deriving Generic
+
+-- | Internal handshake state (exported for benchmarking).
+data HandshakeState = HandshakeState {
+ hs_h :: {-# UNPACK #-} !BS.ByteString
+ -- ^ handshake hash (32 bytes)
+ , hs_ck :: {-# UNPACK #-} !BS.ByteString
+ -- ^ chaining key (32 bytes)
+ , hs_temp_k :: {-# UNPACK #-} !BS.ByteString
+ -- ^ temp key (32 bytes)
+ , hs_e_sec :: !Sec
+ -- ^ ephemeral secret
+ , hs_e_pub :: !Pub
+ -- ^ ephemeral public
+ , hs_s_sec :: !Sec
+ -- ^ static secret
+ , hs_s_pub :: !Pub
+ -- ^ static public
+ , hs_re :: !(Maybe Pub)
+ -- ^ remote ephemeral
+ , hs_rs :: !(Maybe Pub)
+ -- ^ remote static
+ }
+ deriving Generic
+
+-- handshake roles -------------------------------------------------
+
+-- | Phantom type for initiator role.
+data Initiator
+
+-- | Phantom type for responder role.
+data Responder
+
+-- | Role-indexed handshake state.
+--
+-- The phantom type parameter prevents passing an initiator's
+-- state to a responder function and vice versa.
+data HandshakeFor a =
+ HandshakeFor { unHandshakeFor :: !HandshakeState }
+
+-- protocol constants ----------------------------------------------
+
+_PROTOCOL_NAME :: BS.ByteString
+_PROTOCOL_NAME =
+ "Noise_XK_secp256k1_ChaChaPoly_SHA256"
+
+_PROLOGUE :: BS.ByteString
+_PROLOGUE = "lightning"
+
+-- key operations --------------------------------------------------
+
+-- | Derive a keypair from 32 bytes of entropy.
+--
+-- Returns Nothing if the entropy is invalid
+-- (zero or >= curve order).
+--
+-- >>> let ent = BS.replicate 32 0x11
+-- >>> case keypair ent of
+-- ... Just _ -> "ok"
+-- ... Nothing -> "fail"
+-- "ok"
+-- >>> keypair (BS.replicate 31 0x11) -- wrong length
+-- Nothing
+keypair :: BS.ByteString -> Maybe (Sec, Pub)
+keypair ent = do
+ guard (BS.length ent == 32)
+ k <- Secp256k1.parse_int256 ent
+ p <- Secp256k1.derive_pub k
+ pure (Sec ent, Pub p)
+
+-- | Parse a 33-byte compressed public key.
+--
+-- >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
+-- >>> let bytes = serialize_pub pub
+-- >>> case parse_pub bytes of
+-- ... Just _ -> "ok"
+-- ... Nothing -> "fail"
+-- "ok"
+-- >>> parse_pub (BS.replicate 32 0x00) -- wrong length
+-- Nothing
+parse_pub :: BS.ByteString -> Maybe Pub
+parse_pub bs = do
+ guard (BS.length bs == 33)
+ p <- Secp256k1.parse_point bs
+ pure (Pub p)
+
+-- | Serialize a public key to 33-byte compressed form.
+--
+-- >>> let Just (_, pub) = keypair (BS.replicate 32 0x11)
+-- >>> BS.length (serialize_pub pub)
+-- 33
+serialize_pub :: Pub -> BS.ByteString
+serialize_pub (Pub p) = Secp256k1.serialize_point p
+
+-- cryptographic primitives ----------------------------------------
+
+-- bolt8-style ECDH
+ecdh :: Sec -> Pub -> Maybe BS.ByteString
+ecdh (Sec sec) (Pub pub) = do
+ k <- Secp256k1.parse_int256 sec
+ pt <- Secp256k1.mul pub k
+ let compressed = Secp256k1.serialize_point pt
+ pure (SHA256.hash compressed)
+
+-- h' = SHA256(h || data)
+mix_hash
+ :: BS.ByteString -> BS.ByteString -> BS.ByteString
+mix_hash h dat = SHA256.hash (h <> dat)
+
+-- Mix key: (ck', k) = HKDF(ck, input_key_material)
+--
+-- NB HKDF limits output to 255 * hashlen bytes. For SHA256
+-- that's 8160, well above the 64 bytes requested here, so
+-- 'Nothing' is impossible.
+mix_key
+ :: BS.ByteString
+ -> BS.ByteString
+ -> (BS.ByteString, BS.ByteString)
+mix_key ck ikm =
+ case HKDF.derive hmac ck mempty 64 ikm of
+ Nothing ->
+ error
+ "ppad-bolt8: internal error, please report a bug!"
+ Just output -> BS.splitAt 32 output
+ where
+ hmac k b = case SHA256.hmac k b of
+ SHA256.MAC mac -> mac
+
+-- Encrypt with associated data using ChaCha20-Poly1305
+encrypt_with_ad
+ :: BS.ByteString -- ^ key (32 bytes)
+ -> Word64 -- ^ nonce
+ -> BS.ByteString -- ^ associated data
+ -> BS.ByteString -- ^ plaintext
+ -> Maybe BS.ByteString -- ^ ciphertext || mac (16 bytes)
+encrypt_with_ad key n ad pt =
+ case AEAD.encrypt ad key (encode_nonce n) pt of
+ Left _ -> Nothing
+ Right (ct, mac) -> Just (ct <> mac)
+
+-- Decrypt with associated data using ChaCha20-Poly1305
+decrypt_with_ad
+ :: BS.ByteString -- ^ key (32 bytes)
+ -> Word64 -- ^ nonce
+ -> BS.ByteString -- ^ associated data
+ -> BS.ByteString -- ^ ciphertext || mac
+ -> Maybe BS.ByteString -- ^ plaintext
+decrypt_with_ad key n ad ctmac
+ | BS.length ctmac < 16 = Nothing
+ | otherwise =
+ let (ct, mac) =
+ BS.splitAt (BS.length ctmac - 16) ctmac
+ in case AEAD.decrypt ad key (encode_nonce n)
+ (ct, mac) of
+ Left _ -> Nothing
+ Right pt -> Just pt
+
+-- Encode nonce as 96-bit value: 4 zero bytes + 8-byte LE
+encode_nonce :: Word64 -> BS.ByteString
+encode_nonce n = BS.replicate 4 0x00 <> encode_le64 n
+
+-- Little-endian 64-bit encoding
+encode_le64 :: Word64 -> BS.ByteString
+encode_le64 n = BS.pack [
+ fi (n .&. 0xff)
+ , fi (unsafeShiftR n 8 .&. 0xff)
+ , fi (unsafeShiftR n 16 .&. 0xff)
+ , fi (unsafeShiftR n 24 .&. 0xff)
+ , fi (unsafeShiftR n 32 .&. 0xff)
+ , fi (unsafeShiftR n 40 .&. 0xff)
+ , fi (unsafeShiftR n 48 .&. 0xff)
+ , fi (unsafeShiftR n 56 .&. 0xff)
+ ]
+
+-- Big-endian 16-bit encoding
+encode_be16 :: Word16 -> BS.ByteString
+encode_be16 n =
+ BS.pack [fi (unsafeShiftR n 8), fi (n .&. 0xff)]
+
+-- Big-endian 16-bit decoding
+decode_be16 :: BS.ByteString -> Maybe Word16
+decode_be16 bs
+ | BS.length bs /= 2 = Nothing
+ | otherwise =
+ let !b0 = BS.index bs 0
+ !b1 = BS.index bs 1
+ in Just (fi b0 * 0x100 + fi b1)
+
+-- handshake -------------------------------------------------------
+
+-- Initialize handshake state
+--
+-- h = SHA256(protocol_name)
+-- ck = h
+-- h = SHA256(h || prologue)
+-- h = SHA256(h || responder_static_pubkey)
+init_handshake
+ :: Sec -- ^ local static secret
+ -> Pub -- ^ local static public
+ -> Sec -- ^ ephemeral secret
+ -> Pub -- ^ ephemeral public
+ -> Maybe Pub -- ^ remote static
+ -> Bool -- ^ True if initiator
+ -> HandshakeState
+init_handshake s_sec s_pub e_sec e_pub m_rs is_init =
+ let !h0 = SHA256.hash _PROTOCOL_NAME
+ !ck = h0
+ !h1 = mix_hash h0 _PROLOGUE
+ -- Mix in responder's static pubkey
+ !h2 = case (is_init, m_rs) of
+ (True, Just rs) ->
+ mix_hash h1 (serialize_pub rs)
+ (False, Nothing) ->
+ mix_hash h1 (serialize_pub s_pub)
+ _ -> h1 -- shouldn't happen
+ in HandshakeState {
+ hs_h = h2
+ , hs_ck = ck
+ , hs_temp_k = BS.replicate 32 0x00
+ , hs_e_sec = e_sec
+ , hs_e_pub = e_pub
+ , hs_s_sec = s_sec
+ , hs_s_pub = s_pub
+ , hs_re = Nothing
+ , hs_rs = m_rs
+ }
+
+-- | Initiator: generate Act 1 message (50 bytes).
+--
+-- Takes local static key, remote static pubkey, and 32
+-- bytes of entropy for ephemeral key generation.
+--
+-- Returns the 50-byte Act 1 message and handshake state
+-- for Act 3.
+--
+-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
+-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
+-- >>> let eph_ent = BS.replicate 32 0x12
+-- >>> case act1 i_sec i_pub r_pub eph_ent of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
+-- 50
+act1
+ :: Sec
+ -> Pub
+ -> Pub
+ -> BS.ByteString
+ -> Either Error
+ (BS.ByteString, HandshakeFor Initiator)
+act1 s_sec s_pub rs ent = do
+ (e_sec, e_pub) <- note InvalidKey (keypair ent)
+ let !hs0 = init_handshake
+ s_sec s_pub e_sec e_pub (Just rs) True
+ !e_pub_bytes = serialize_pub e_pub
+ !h1 = mix_hash (hs_h hs0) e_pub_bytes
+ es <- note InvalidKey (ecdh e_sec rs)
+ let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
+ c <- note InvalidMAC
+ (encrypt_with_ad temp_k1 0 h1 BS.empty)
+ let !h2 = mix_hash h1 c
+ !msg = BS.singleton 0x00 <> e_pub_bytes <> c
+ !hs1 = hs0 {
+ hs_h = h2
+ , hs_ck = ck1
+ , hs_temp_k = temp_k1
+ }
+ pure (msg, HandshakeFor hs1)
+
+-- | Responder: process Act 1 and generate Act 2 message
+-- (50 bytes).
+--
+-- Takes local static key and 32 bytes of entropy for
+-- ephemeral key, plus the 50-byte Act 1 message from
+-- initiator.
+--
+-- Returns the 50-byte Act 2 message and handshake state
+-- for finalize.
+--
+-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
+-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
+-- >>> let Right (msg1, _) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
+-- >>> case act2 r_sec r_pub (BS.replicate 32 0x22) msg1 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
+-- 50
+act2
+ :: Sec
+ -> Pub
+ -> BS.ByteString
+ -> BS.ByteString
+ -> Either Error
+ (BS.ByteString, HandshakeFor Responder)
+act2 s_sec s_pub ent msg1 = do
+ require (BS.length msg1 == 50) InvalidLength
+ let !version = BS.index msg1 0
+ !re_bytes = BS.take 33 (BS.drop 1 msg1)
+ !c = BS.drop 34 msg1
+ require (version == 0x00) InvalidVersion
+ re <- note InvalidPub (parse_pub re_bytes)
+ (e_sec, e_pub) <- note InvalidKey (keypair ent)
+ let !hs0 = init_handshake
+ s_sec s_pub e_sec e_pub Nothing False
+ !h1 = mix_hash (hs_h hs0) re_bytes
+ es <- note InvalidKey (ecdh s_sec re)
+ let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
+ _ <- note InvalidMAC
+ (decrypt_with_ad temp_k1 0 h1 c)
+ let !h2 = mix_hash h1 c
+ !e_pub_bytes = serialize_pub e_pub
+ !h3 = mix_hash h2 e_pub_bytes
+ ee <- note InvalidKey (ecdh e_sec re)
+ let !(ck2, temp_k2) = mix_key ck1 ee
+ c2 <- note InvalidMAC
+ (encrypt_with_ad temp_k2 0 h3 BS.empty)
+ let !h4 = mix_hash h3 c2
+ !msg = BS.singleton 0x00 <> e_pub_bytes <> c2
+ !hs1 = hs0 {
+ hs_h = h4
+ , hs_ck = ck2
+ , hs_temp_k = temp_k2
+ , hs_re = Just re
+ }
+ pure (msg, HandshakeFor hs1)
+
+-- | Initiator: process Act 2 and generate Act 3 (66 bytes),
+-- completing the handshake.
+--
+-- Returns the 66-byte Act 3 message and the handshake
+-- result.
+--
+-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
+-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
+-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
+-- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
+-- >>> case act3 i_hs msg2 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
+-- 66
+act3
+ :: HandshakeFor Initiator
+ -> BS.ByteString
+ -> Either Error (BS.ByteString, Handshake)
+act3 (HandshakeFor hs) msg2 = do
+ require (BS.length msg2 == 50) InvalidLength
+ let !version = BS.index msg2 0
+ !re_bytes = BS.take 33 (BS.drop 1 msg2)
+ !c = BS.drop 34 msg2
+ require (version == 0x00) InvalidVersion
+ re <- note InvalidPub (parse_pub re_bytes)
+ let !h1 = mix_hash (hs_h hs) re_bytes
+ ee <- note InvalidKey (ecdh (hs_e_sec hs) re)
+ let !(ck1, temp_k2) = mix_key (hs_ck hs) ee
+ _ <- note InvalidMAC
+ (decrypt_with_ad temp_k2 0 h1 c)
+ let !h2 = mix_hash h1 c
+ !s_pub_bytes = serialize_pub (hs_s_pub hs)
+ c3 <- note InvalidMAC
+ (encrypt_with_ad temp_k2 1 h2 s_pub_bytes)
+ let !h3 = mix_hash h2 c3
+ se <- note InvalidKey (ecdh (hs_s_sec hs) re)
+ let !(ck2, temp_k3) = mix_key ck1 se
+ t <- note InvalidMAC
+ (encrypt_with_ad temp_k3 0 h3 BS.empty)
+ let !(sk, rk) = mix_key ck2 BS.empty
+ !msg = BS.singleton 0x00 <> c3 <> t
+ !sess = Session {
+ sess_sk = Key32 sk
+ , sess_sn = SessionNonce 0
+ , sess_sck = Key32 ck2
+ , sess_rk = Key32 rk
+ , sess_rn = SessionNonce 0
+ , sess_rck = Key32 ck2
+ }
+ rs <- note InvalidPub (hs_rs hs)
+ let !result = Handshake {
+ session = sess
+ , remote_static = rs
+ }
+ pure (msg, result)
+
+-- | Responder: process Act 3 (66 bytes) and complete the
+-- handshake.
+--
+-- Returns the handshake result with authenticated remote
+-- static pubkey.
+--
+-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
+-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
+-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
+-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
+-- >>> let Right (msg3, _) = act3 i_hs msg2
+-- >>> case finalize r_hs msg3 of { Right _ -> "ok"; Left e -> show e }
+-- "ok"
+finalize
+ :: HandshakeFor Responder
+ -> BS.ByteString
+ -> Either Error Handshake
+finalize (HandshakeFor hs) msg3 = do
+ require (BS.length msg3 == 66) InvalidLength
+ let !version = BS.index msg3 0
+ !c = BS.take 49 (BS.drop 1 msg3)
+ !t = BS.drop 50 msg3
+ require (version == 0x00) InvalidVersion
+ rs_bytes <- note InvalidMAC
+ (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c)
+ rs <- note InvalidPub (parse_pub rs_bytes)
+ let !h1 = mix_hash (hs_h hs) c
+ se <- note InvalidKey (ecdh (hs_e_sec hs) rs)
+ let !(ck1, temp_k3) = mix_key (hs_ck hs) se
+ _ <- note InvalidMAC
+ (decrypt_with_ad temp_k3 0 h1 t)
+ -- responder swaps order (receives what initiator sends)
+ let !(rk, sk) = mix_key ck1 BS.empty
+ !sess = Session {
+ sess_sk = Key32 sk
+ , sess_sn = SessionNonce 0
+ , sess_sck = Key32 ck1
+ , sess_rk = Key32 rk
+ , sess_rn = SessionNonce 0
+ , sess_rck = Key32 ck1
+ }
+ !result = Handshake {
+ session = sess
+ , remote_static = rs
+ }
+ pure result
+
+-- message encryption ----------------------------------------------
+
+-- | Encrypt a message (max 65535 bytes).
+--
+-- Returns the encrypted packet and updated session. Key
+-- rotation is handled automatically at nonce 1000.
+--
+-- Wire format:
+-- encrypted_length (2) || MAC (16)
+-- || encrypted_body || MAC (16)
+--
+-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
+-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
+-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
+-- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
+-- >>> let Right (_, i_result) = act3 i_hs msg2
+-- >>> let sess = session i_result
+-- >>> case encrypt sess "hello" of { Right (ct, _) -> BS.length ct; Left _ -> 0 }
+-- 39
+encrypt
+ :: Session
+ -> BS.ByteString
+ -> Either Error (BS.ByteString, Session)
+encrypt sess pt = do
+ let !len = BS.length pt
+ require (len <= 65535) InvalidLength
+ let !len_bytes = encode_be16 (fi len)
+ !sk = unKey32 (sess_sk sess)
+ !sn = unSessionNonce (sess_sn sess)
+ !sck = unKey32 (sess_sck sess)
+ lc <- note InvalidMAC
+ (encrypt_with_ad sk sn BS.empty len_bytes)
+ let !(sn1, sck1, sk1) = step_nonce sn sck sk
+ bc <- note InvalidMAC
+ (encrypt_with_ad sk1 sn1 BS.empty pt)
+ let !(sn2, sck2, sk2) = step_nonce sn1 sck1 sk1
+ !packet = lc <> bc
+ !sess' = sess {
+ sess_sk = Key32 sk2
+ , sess_sn = SessionNonce sn2
+ , sess_sck = Key32 sck2
+ }
+ pure (packet, sess')
+
+-- | Decrypt a message, requiring an exact packet with no
+-- trailing bytes.
+--
+-- Returns the plaintext and updated session. Key rotation
+-- is handled automatically at nonce 1000.
+--
+-- This is a strict variant that rejects any trailing data.
+-- For streaming use cases where you need to handle multiple
+-- frames in a buffer, use 'decrypt_frame' instead.
+--
+-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
+-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
+-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
+-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
+-- >>> let Right (msg3, i_result) = act3 i_hs msg2
+-- >>> let Right r_result = finalize r_hs msg3
+-- >>> let Right (ct, _) = encrypt (session i_result) "hello"
+-- >>> case decrypt (session r_result) ct of { Right (pt, _) -> pt; Left _ -> "fail" }
+-- "hello"
+decrypt
+ :: Session
+ -> BS.ByteString
+ -> Either Error (BS.ByteString, Session)
+decrypt sess packet = do
+ (pt, remainder, sess') <- decrypt_frame sess packet
+ require (BS.null remainder) InvalidLength
+ pure (pt, sess')
+
+-- | Decrypt a single frame from a buffer, returning the
+-- remainder.
+--
+-- Returns the plaintext, any unconsumed bytes, and the
+-- updated session. Key rotation is handled automatically
+-- every 1000 messages.
+--
+-- This is useful for streaming scenarios where multiple
+-- messages may be buffered together. The remainder can be
+-- passed to the next call to 'decrypt_frame'.
+--
+-- Wire format consumed:
+-- encrypted_length (18) || encrypted_body (len + 16)
+--
+-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
+-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
+-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
+-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
+-- >>> let Right (msg3, i_result) = act3 i_hs msg2
+-- >>> let Right r_result = finalize r_hs msg3
+-- >>> let Right (ct, _) = encrypt (session i_result) "hello"
+-- >>> case decrypt_frame (session r_result) ct of { Right (pt, rem, _) -> (pt, BS.null rem); Left _ -> ("fail", False) }
+-- ("hello",True)
+decrypt_frame
+ :: Session
+ -> BS.ByteString
+ -> Either Error
+ (BS.ByteString, BS.ByteString, Session)
+decrypt_frame sess packet = do
+ require (BS.length packet >= 34) InvalidLength
+ let !lc = BS.take 18 packet
+ !rest = BS.drop 18 packet
+ !rk = unKey32 (sess_rk sess)
+ !rn = unSessionNonce (sess_rn sess)
+ !rck = unKey32 (sess_rck sess)
+ len_bytes <- note InvalidMAC
+ (decrypt_with_ad rk rn BS.empty lc)
+ len <- note InvalidLength (decode_be16 len_bytes)
+ let !(rn1, rck1, rk1) = step_nonce rn rck rk
+ !body_len = fi len + 16
+ require (BS.length rest >= body_len) InvalidLength
+ let !bc = BS.take body_len rest
+ !remainder = BS.drop body_len rest
+ pt <- note InvalidMAC
+ (decrypt_with_ad rk1 rn1 BS.empty bc)
+ let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
+ !sess' = sess {
+ sess_rk = Key32 rk2
+ , sess_rn = SessionNonce rn2
+ , sess_rck = Key32 rck2
+ }
+ pure (pt, remainder, sess')
+
+-- | Decrypt a frame from a partial buffer, indicating when
+-- more data needed.
+--
+-- Unlike 'decrypt_frame', this function handles incomplete
+-- buffers gracefully by returning 'NeedMore' with the
+-- number of additional bytes required to make progress.
+--
+-- * If the buffer has fewer than 18 bytes (encrypted
+-- length + MAC), returns @'NeedMore' n@ where @n@ is
+-- the bytes still needed.
+-- * If the length header is complete but the body is
+-- incomplete, returns @'NeedMore' n@ with bytes needed
+-- for the full frame.
+-- * MAC or decryption failures return 'FrameError'.
+-- * A complete, valid frame returns 'FrameOk' with
+-- plaintext, remainder, and updated session.
+--
+-- This is useful for non-blocking I/O where data arrives
+-- incrementally.
+decrypt_frame_partial
+ :: Session
+ -> BS.ByteString
+ -> FrameResult
+decrypt_frame_partial sess buf
+ | buflen < 18 = NeedMore (18 - buflen)
+ | otherwise =
+ let !lc = BS.take 18 buf
+ !rest = BS.drop 18 buf
+ !rk = unKey32 (sess_rk sess)
+ !rn = unSessionNonce (sess_rn sess)
+ !rck = unKey32 (sess_rck sess)
+ in case decrypt_with_ad rk rn BS.empty lc of
+ Nothing -> FrameError InvalidMAC
+ Just len_bytes ->
+ case decode_be16 len_bytes of
+ Nothing -> FrameError InvalidLength
+ Just len ->
+ let !body_len = fi len + 16
+ !(rn1, rck1, rk1) =
+ step_nonce rn rck rk
+ in if BS.length rest < body_len
+ then NeedMore
+ (body_len - BS.length rest)
+ else
+ let !bc = BS.take body_len rest
+ !remainder =
+ BS.drop body_len rest
+ in case decrypt_with_ad
+ rk1 rn1 BS.empty bc of
+ Nothing ->
+ FrameError InvalidMAC
+ Just pt ->
+ let !(rn2, rck2, rk2) =
+ step_nonce rn1 rck1 rk1
+ !sess' = sess {
+ sess_rk = Key32 rk2
+ , sess_rn =
+ SessionNonce rn2
+ , sess_rck = Key32 rck2
+ }
+ in FrameOk pt remainder sess'
+ where
+ !buflen = BS.length buf
+
+-- key rotation ----------------------------------------------------
+
+-- Key rotation occurs after nonce reaches 1000 (i.e., before
+-- using 1000)
+-- (ck', k') = HKDF(ck, k), reset nonce to 0
+step_nonce
+ :: Word64
+ -> BS.ByteString
+ -> BS.ByteString
+ -> (Word64, BS.ByteString, BS.ByteString)
+step_nonce n ck k
+ | n + 1 == 1000 =
+ let !(ck', k') = mix_key ck k
+ in (0, ck', k')
+ | otherwise = (n + 1, ck, k)
+
+-- utilities -------------------------------------------------------
+
+-- Lift Maybe to Either
+note :: e -> Maybe a -> Either e a
+note e = maybe (Left e) Right
+{-# INLINE note #-}
+
+-- Require condition or fail
+require :: Bool -> e -> Either e ()
+require cond e = unless cond (Left e)
+{-# INLINE require #-}
+
+fi :: (Integral a, Num b) => a -> b
+fi = fromIntegral
+{-# INLINE fi #-}
diff --git a/ppad-bolt8.cabal b/ppad-bolt8.cabal
@@ -25,6 +25,7 @@ library
-Wall
exposed-modules:
Lightning.Protocol.BOLT8
+ Lightning.Protocol.BOLT8.Internal
build-depends:
base >= 4.9 && < 5
, bytestring >= 0.9 && < 0.13