bolt8

Encrypted and authenticated transport, per BOLT #8 (docs.ppad.tech/bolt8).
git clone git://git.ppad.tech/bolt8.git
Log | Files | Refs | README | LICENSE

commit fe70c6219be62ecdfcefae0cc52620b34ad12263
parent ca737a5e3d1970d51078cee6cc10279351759870
Author: Jared Tobin <jared@jtobin.io>
Date:   Mon, 20 Apr 2026 14:59:57 +0800

Merge branch 'impl/type-safety'

Type system improvements to prevent misuse at compile time:

- SessionNonce newtype: distinguishes send/receive nonces
- Key32 newtype: validates 32-byte key length via smart
  constructor; used for all session keys
- MessagePayload newtype: validates max 65535-byte payload
  length via smart constructor
- HandshakeFor phantom type: indexed by Initiator/Responder
  to prevent passing wrong role's state to act3/finalize

Split library into public module (safe API) and Internal
module (constructors exposed for tests/benchmarks).

No behavioural changes. All 26 tests pass.

Diffstat:
Mbench/Main.hs | 80+++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------
Mbench/Weight.hs | 63+++++++++++++++++++++++++++++++++++++++++++--------------------
Mlib/Lightning/Protocol/BOLT8.hs | 665+++++--------------------------------------------------------------------------
Alib/Lightning/Protocol/BOLT8/Internal.hs | 818+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mppad-bolt8.cabal | 1+
5 files changed, 951 insertions(+), 676 deletions(-)

diff --git a/bench/Main.hs b/bench/Main.hs @@ -7,16 +7,20 @@ module Main where import Control.DeepSeq import Criterion.Main import qualified Data.ByteString as BS -import qualified Lightning.Protocol.BOLT8 as BOLT8 +import qualified Lightning.Protocol.BOLT8.Internal as BOLT8 instance NFData BOLT8.Pub where rnf p = rnf (BOLT8.serialize_pub p) instance NFData BOLT8.Sec instance NFData BOLT8.Error +instance NFData BOLT8.Key32 +instance NFData BOLT8.SessionNonce instance NFData BOLT8.Session instance NFData BOLT8.HandshakeState instance NFData BOLT8.Handshake +instance NFData (BOLT8.HandshakeFor a) where + rnf (BOLT8.HandshakeFor s) = rnf s main :: IO () main = defaultMain [ @@ -36,50 +40,74 @@ keys :: Benchmark keys = bgroup "keys" [ bench "keypair" $ nf BOLT8.keypair i_s_ent , bench "parse_pub" $ nf BOLT8.parse_pub r_s_pub_bs - , bench "serialize_pub" $ nf BOLT8.serialize_pub r_s_pub + , bench "serialize_pub" $ + nf BOLT8.serialize_pub r_s_pub ] where Just (_, r_s_pub) = BOLT8.keypair r_s_ent r_s_pub_bs = BOLT8.serialize_pub r_s_pub handshake :: Benchmark -handshake = env setup $ \ ~(i_s_sec, i_s_pub, r_s_sec, r_s_pub, msg1, i_hs, - msg2, r_hs, msg3) -> +handshake = env setup $ + \ ~(i_s_sec, i_s_pub, r_s_sec, r_s_pub, + msg1, i_hs, msg2, r_hs, msg3) -> bgroup "handshake" [ - bench "act1" $ nf (BOLT8.act1 i_s_sec i_s_pub r_s_pub) i_e_ent - , bench "act2" $ nf (BOLT8.act2 r_s_sec r_s_pub r_e_ent) msg1 - , bench "act3" $ nf (BOLT8.act3 i_hs) msg2 - , bench "finalize" $ nf (BOLT8.finalize r_hs) msg3 + bench "act1" $ + nf (BOLT8.act1 i_s_sec i_s_pub r_s_pub) i_e_ent + , bench "act2" $ + nf (BOLT8.act2 r_s_sec r_s_pub r_e_ent) msg1 + , bench "act3" $ + nf (BOLT8.act3 i_hs) msg2 + , bench "finalize" $ + nf (BOLT8.finalize r_hs) msg3 ] where setup = do - let Just (!i_s_sec, !i_s_pub) = BOLT8.keypair i_s_ent - Just (!r_s_sec, !r_s_pub) = BOLT8.keypair r_s_ent - Right (!msg1, !i_hs) = BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent - Right (!msg2, !r_hs) = BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1 + let Just (!i_s_sec, !i_s_pub) = + BOLT8.keypair i_s_ent + Just (!r_s_sec, !r_s_pub) = + BOLT8.keypair r_s_ent + Right (!msg1, !i_hs) = + BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent + Right (!msg2, !r_hs) = + BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1 Right (!msg3, _) = BOLT8.act3 i_hs msg2 - pure (i_s_sec, i_s_pub, r_s_sec, r_s_pub, msg1, i_hs, msg2, r_hs, msg3) + pure ( i_s_sec, i_s_pub, r_s_sec, r_s_pub + , msg1, i_hs, msg2, r_hs, msg3 ) messages :: Benchmark -messages = env setup $ \ ~(i_sess, r_sess, ct_small, ct_large) -> +messages = env setup $ + \ ~(i_sess, r_sess, ct_small, ct_large) -> bgroup "messages" [ - bench "encrypt (32B)" $ nf (BOLT8.encrypt i_sess) small_msg - , bench "encrypt (1KB)" $ nf (BOLT8.encrypt i_sess) large_msg - , bench "decrypt (32B)" $ nf (BOLT8.decrypt r_sess) ct_small - , bench "decrypt (1KB)" $ nf (BOLT8.decrypt r_sess) ct_large + bench "encrypt (32B)" $ + nf (BOLT8.encrypt i_sess) small_msg + , bench "encrypt (1KB)" $ + nf (BOLT8.encrypt i_sess) large_msg + , bench "decrypt (32B)" $ + nf (BOLT8.decrypt r_sess) ct_small + , bench "decrypt (1KB)" $ + nf (BOLT8.decrypt r_sess) ct_large ] where small_msg = BS.replicate 32 0x00 large_msg = BS.replicate 1024 0x00 setup = do - let Just (!i_s_sec, !i_s_pub) = BOLT8.keypair i_s_ent - Just (!r_s_sec, !r_s_pub) = BOLT8.keypair r_s_ent - Right (msg1, i_hs) = BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent - Right (msg2, r_hs) = BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1 - Right (msg3, i_result) = BOLT8.act3 i_hs msg2 - Right r_result = BOLT8.finalize r_hs msg3 + let Just (!i_s_sec, !i_s_pub) = + BOLT8.keypair i_s_ent + Just (!r_s_sec, !r_s_pub) = + BOLT8.keypair r_s_ent + Right (msg1, i_hs) = + BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent + Right (msg2, r_hs) = + BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1 + Right (msg3, i_result) = + BOLT8.act3 i_hs msg2 + Right r_result = + BOLT8.finalize r_hs msg3 !i_sess = BOLT8.session i_result !r_sess = BOLT8.session r_result - Right (!ct_small, _) = BOLT8.encrypt i_sess small_msg - Right (!ct_large, _) = BOLT8.encrypt i_sess large_msg + Right (!ct_small, _) = + BOLT8.encrypt i_sess small_msg + Right (!ct_large, _) = + BOLT8.encrypt i_sess large_msg pure (i_sess, r_sess, ct_small, ct_large) diff --git a/bench/Weight.hs b/bench/Weight.hs @@ -6,7 +6,7 @@ module Main where import Control.DeepSeq import qualified Data.ByteString as BS -import qualified Lightning.Protocol.BOLT8 as BOLT8 +import qualified Lightning.Protocol.BOLT8.Internal as BOLT8 import Weigh instance NFData BOLT8.Pub where @@ -14,9 +14,13 @@ instance NFData BOLT8.Pub where instance NFData BOLT8.Sec instance NFData BOLT8.Error +instance NFData BOLT8.Key32 +instance NFData BOLT8.SessionNonce instance NFData BOLT8.Session instance NFData BOLT8.HandshakeState instance NFData BOLT8.Handshake +instance NFData (BOLT8.HandshakeFor a) where + rnf (BOLT8.HandshakeFor s) = rnf s -- note that 'weigh' doesn't work properly in a repl main :: IO () @@ -39,37 +43,56 @@ keys = in wgroup "keys" $ do func "keypair" BOLT8.keypair i_s_ent func "parse_pub" BOLT8.parse_pub r_s_pub_bs - func "serialize_pub" BOLT8.serialize_pub r_s_pub + func "serialize_pub" + BOLT8.serialize_pub r_s_pub handshake :: Weigh () handshake = - let Just (!i_s_sec, !i_s_pub) = BOLT8.keypair i_s_ent - Just (!r_s_sec, !r_s_pub) = BOLT8.keypair r_s_ent - Right (!msg1, !i_hs) = BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent - Right (!msg2, !r_hs) = BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1 + let Just (!i_s_sec, !i_s_pub) = + BOLT8.keypair i_s_ent + Just (!r_s_sec, !r_s_pub) = + BOLT8.keypair r_s_ent + Right (!msg1, !i_hs) = + BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent + Right (!msg2, !r_hs) = + BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1 Right (!msg3, _) = BOLT8.act3 i_hs msg2 in wgroup "handshake" $ do - func "act1" (BOLT8.act1 i_s_sec i_s_pub r_s_pub) i_e_ent - func "act2" (BOLT8.act2 r_s_sec r_s_pub r_e_ent) msg1 + func "act1" + (BOLT8.act1 i_s_sec i_s_pub r_s_pub) i_e_ent + func "act2" + (BOLT8.act2 r_s_sec r_s_pub r_e_ent) msg1 func "act3" (BOLT8.act3 i_hs) msg2 func "finalize" (BOLT8.finalize r_hs) msg3 messages :: Weigh () messages = - let Just (!i_s_sec, !i_s_pub) = BOLT8.keypair i_s_ent - Just (!r_s_sec, !r_s_pub) = BOLT8.keypair r_s_ent - Right (msg1, i_hs) = BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent - Right (msg2, r_hs) = BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1 - Right (msg3, i_result) = BOLT8.act3 i_hs msg2 - Right r_result = BOLT8.finalize r_hs msg3 + let Just (!i_s_sec, !i_s_pub) = + BOLT8.keypair i_s_ent + Just (!r_s_sec, !r_s_pub) = + BOLT8.keypair r_s_ent + Right (msg1, i_hs) = + BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent + Right (msg2, r_hs) = + BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1 + Right (msg3, i_result) = + BOLT8.act3 i_hs msg2 + Right r_result = + BOLT8.finalize r_hs msg3 !i_sess = BOLT8.session i_result !r_sess = BOLT8.session r_result !small_msg = BS.replicate 32 0x00 !large_msg = BS.replicate 1024 0x00 - Right (!ct_small, _) = BOLT8.encrypt i_sess small_msg - Right (!ct_large, _) = BOLT8.encrypt i_sess large_msg + Right (!ct_small, _) = + BOLT8.encrypt i_sess small_msg + Right (!ct_large, _) = + BOLT8.encrypt i_sess large_msg in wgroup "messages" $ do - func "encrypt (32B)" (BOLT8.encrypt i_sess) small_msg - func "encrypt (1KB)" (BOLT8.encrypt i_sess) large_msg - func "decrypt (32B)" (BOLT8.decrypt r_sess) ct_small - func "decrypt (1KB)" (BOLT8.decrypt r_sess) ct_large + func "encrypt (32B)" + (BOLT8.encrypt i_sess) small_msg + func "encrypt (1KB)" + (BOLT8.encrypt i_sess) large_msg + func "decrypt (32B)" + (BOLT8.decrypt r_sess) ct_small + func "decrypt (1KB)" + (BOLT8.decrypt r_sess) ct_large diff --git a/lib/Lightning/Protocol/BOLT8.hs b/lib/Lightning/Protocol/BOLT8.hs @@ -1,10 +1,4 @@ {-# OPTIONS_HADDOCK prune #-} -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE ViewPatterns #-} -- | -- Module: Lightning.Protocol.BOLT8 @@ -12,16 +6,19 @@ -- License: MIT -- Maintainer: Jared Tobin <jared@ppad.tech> -- --- Encrypted and authenticated transport for the Lightning Network, per +-- Encrypted and authenticated transport for the Lightning +-- Network, per -- [BOLT #8](https://github.com/lightning/bolts/blob/master/08-transport.md). -- --- This module implements the Noise_XK_secp256k1_ChaChaPoly_SHA256 --- handshake and subsequent encrypted message transport. +-- This module implements the +-- Noise_XK_secp256k1_ChaChaPoly_SHA256 handshake and +-- subsequent encrypted message transport. -- -- = Handshake -- --- A BOLT #8 handshake consists of three acts. The /initiator/ knows the --- responder's static public key in advance and initiates the connection: +-- A BOLT #8 handshake consists of three acts. The +-- /initiator/ knows the responder's static public key in +-- advance and initiates the connection: -- -- @ -- (msg1, state) <- act1 i_sec i_pub r_pub entropy @@ -32,7 +29,8 @@ -- let session = 'session' result -- @ -- --- The /responder/ receives the connection and authenticates the initiator: +-- The /responder/ receives the connection and authenticates +-- the initiator: -- -- @ -- -- receive msg1 (50 bytes) from initiator @@ -45,9 +43,10 @@ -- -- = Message Transport -- --- After a successful handshake, use 'encrypt' and 'decrypt' to exchange --- messages. Each returns an updated 'Session' that must be used for the --- next operation (keys rotate every 1000 messages): +-- After a successful handshake, use 'encrypt' and 'decrypt' +-- to exchange messages. Each returns an updated 'Session' +-- that must be used for the next operation (keys rotate +-- every 1000 messages): -- -- @ -- -- sender @@ -59,10 +58,11 @@ -- -- = Message Framing -- --- BOLT #8 runs over a byte stream, so callers often need to deal with --- partial buffers. Use 'decrypt_frame' when you have exactly one frame, --- or 'decrypt_frame_partial' to handle incremental reads and return how --- many bytes are still needed. +-- BOLT #8 runs over a byte stream, so callers often need to +-- deal with partial buffers. Use 'decrypt_frame' when you +-- have exactly one frame, or 'decrypt_frame_partial' to +-- handle incremental reads and return how many bytes are +-- still needed. -- -- Maximum plaintext size is 65535 bytes. @@ -74,6 +74,21 @@ module Lightning.Protocol.BOLT8 ( , parse_pub , serialize_pub + -- * Newtypes + , Key32 + , key32 + , unKey32 + , SessionNonce + , unSessionNonce + , MessagePayload + , unMessagePayload + , mkMessagePayload + + -- * Handshake roles + , Initiator + , Responder + , HandshakeFor + -- * Handshake (initiator) , act1 , act3 @@ -96,614 +111,4 @@ module Lightning.Protocol.BOLT8 ( , Error(..) ) where -import Control.Monad (guard, unless) -import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD -import qualified Crypto.Curve.Secp256k1 as Secp256k1 -import qualified Crypto.Hash.SHA256 as SHA256 -import qualified Crypto.KDF.HMAC as HKDF -import Data.Bits (unsafeShiftR, (.&.)) -import qualified Data.ByteString as BS -import Data.Word (Word16, Word64) -import GHC.Generics (Generic) - --- types --------------------------------------------------------------------- - --- | Secret key (32 bytes). -newtype Sec = Sec BS.ByteString - deriving (Eq, Generic) - --- | Compressed public key. -newtype Pub = Pub Secp256k1.Projective - -instance Eq Pub where - (Pub a) == (Pub b) = - Secp256k1.serialize_point a == Secp256k1.serialize_point b - -instance Show Pub where - show (Pub p) = "Pub " ++ show (Secp256k1.serialize_point p) - --- | Handshake errors. -data Error = - InvalidKey - | InvalidPub - | InvalidMAC - | InvalidVersion - | InvalidLength - | DecryptionFailed - deriving (Eq, Show, Generic) - --- | Result of attempting to decrypt a frame from a partial buffer. -data FrameResult = - NeedMore {-# UNPACK #-} !Int - -- ^ More bytes needed; the 'Int' is the minimum additional bytes required. - | FrameOk !BS.ByteString !BS.ByteString !Session - -- ^ Successfully decrypted: plaintext, remainder, updated session. - | FrameError !Error - -- ^ Decryption failed with the given error. - deriving Generic - --- | Post-handshake session state. -data Session = Session { - sess_sk :: {-# UNPACK #-} !BS.ByteString -- ^ send key (32 bytes) - , sess_sn :: {-# UNPACK #-} !Word64 -- ^ send nonce - , sess_sck :: {-# UNPACK #-} !BS.ByteString -- ^ send chaining key - , sess_rk :: {-# UNPACK #-} !BS.ByteString -- ^ receive key (32 bytes) - , sess_rn :: {-# UNPACK #-} !Word64 -- ^ receive nonce - , sess_rck :: {-# UNPACK #-} !BS.ByteString -- ^ receive chaining key - } - deriving Generic - --- | Result of a successful handshake. -data Handshake = Handshake { - session :: !Session -- ^ session state - , remote_static :: !Pub -- ^ authenticated remote static pubkey - } - deriving Generic - --- | Internal handshake state (exported for benchmarking). -data HandshakeState = HandshakeState { - hs_h :: {-# UNPACK #-} !BS.ByteString -- handshake hash (32 bytes) - , hs_ck :: {-# UNPACK #-} !BS.ByteString -- chaining key (32 bytes) - , hs_temp_k :: {-# UNPACK #-} !BS.ByteString -- temp key (32 bytes) - , hs_e_sec :: !Sec -- ephemeral secret - , hs_e_pub :: !Pub -- ephemeral public - , hs_s_sec :: !Sec -- static secret - , hs_s_pub :: !Pub -- static public - , hs_re :: !(Maybe Pub) -- remote ephemeral - , hs_rs :: !(Maybe Pub) -- remote static - } - deriving Generic - --- protocol constants -------------------------------------------------------- - -_PROTOCOL_NAME :: BS.ByteString -_PROTOCOL_NAME = "Noise_XK_secp256k1_ChaChaPoly_SHA256" - -_PROLOGUE :: BS.ByteString -_PROLOGUE = "lightning" - --- key operations ------------------------------------------------------------ - --- | Derive a keypair from 32 bytes of entropy. --- --- Returns Nothing if the entropy is invalid (zero or >= curve order). --- --- >>> let ent = BS.replicate 32 0x11 --- >>> case keypair ent of { Just _ -> "ok"; Nothing -> "fail" } --- "ok" --- >>> keypair (BS.replicate 31 0x11) -- wrong length --- Nothing -keypair :: BS.ByteString -> Maybe (Sec, Pub) -keypair ent = do - guard (BS.length ent == 32) - k <- Secp256k1.parse_int256 ent - p <- Secp256k1.derive_pub k - pure (Sec ent, Pub p) - --- | Parse a 33-byte compressed public key. --- --- >>> let Just (_, pub) = keypair (BS.replicate 32 0x11) --- >>> let bytes = serialize_pub pub --- >>> case parse_pub bytes of { Just _ -> "ok"; Nothing -> "fail" } --- "ok" --- >>> parse_pub (BS.replicate 32 0x00) -- wrong length --- Nothing -parse_pub :: BS.ByteString -> Maybe Pub -parse_pub bs = do - guard (BS.length bs == 33) - p <- Secp256k1.parse_point bs - pure (Pub p) - --- | Serialize a public key to 33-byte compressed form. --- --- >>> let Just (_, pub) = keypair (BS.replicate 32 0x11) --- >>> BS.length (serialize_pub pub) --- 33 -serialize_pub :: Pub -> BS.ByteString -serialize_pub (Pub p) = Secp256k1.serialize_point p - --- cryptographic primitives -------------------------------------------------- - --- bolt8-style ECDH -ecdh :: Sec -> Pub -> Maybe BS.ByteString -ecdh (Sec sec) (Pub pub) = do - k <- Secp256k1.parse_int256 sec - pt <- Secp256k1.mul pub k - let compressed = Secp256k1.serialize_point pt - pure (SHA256.hash compressed) - --- h' = SHA256(h || data) -mix_hash :: BS.ByteString -> BS.ByteString -> BS.ByteString -mix_hash h dat = SHA256.hash (h <> dat) - --- Mix key: (ck', k) = HKDF(ck, input_key_material) --- --- NB HKDF limits output to 255 * hashlen bytes. For SHA256 that's 8160, --- well above the 64 bytes requested here, so 'Nothing' is impossible. -mix_key :: BS.ByteString -> BS.ByteString -> (BS.ByteString, BS.ByteString) -mix_key ck ikm = case HKDF.derive hmac ck mempty 64 ikm of - Nothing -> error "ppad-bolt8: internal error, please report a bug!" - Just output -> BS.splitAt 32 output - where - hmac k b = case SHA256.hmac k b of - SHA256.MAC mac -> mac - --- Encrypt with associated data using ChaCha20-Poly1305 -encrypt_with_ad - :: BS.ByteString -- ^ key (32 bytes) - -> Word64 -- ^ nonce - -> BS.ByteString -- ^ associated data - -> BS.ByteString -- ^ plaintext - -> Maybe BS.ByteString -- ^ ciphertext || mac (16 bytes) -encrypt_with_ad key n ad pt = - case AEAD.encrypt ad key (encode_nonce n) pt of - Left _ -> Nothing - Right (ct, mac) -> Just (ct <> mac) - --- Decrypt with associated data using ChaCha20-Poly1305 -decrypt_with_ad - :: BS.ByteString -- ^ key (32 bytes) - -> Word64 -- ^ nonce - -> BS.ByteString -- ^ associated data - -> BS.ByteString -- ^ ciphertext || mac - -> Maybe BS.ByteString -- ^ plaintext -decrypt_with_ad key n ad ctmac - | BS.length ctmac < 16 = Nothing - | otherwise = - let (ct, mac) = BS.splitAt (BS.length ctmac - 16) ctmac - in case AEAD.decrypt ad key (encode_nonce n) (ct, mac) of - Left _ -> Nothing - Right pt -> Just pt - --- Encode nonce as 96-bit value: 4 zero bytes + 8-byte little-endian -encode_nonce :: Word64 -> BS.ByteString -encode_nonce n = BS.replicate 4 0x00 <> encode_le64 n - --- Little-endian 64-bit encoding -encode_le64 :: Word64 -> BS.ByteString -encode_le64 n = BS.pack [ - fi (n .&. 0xff) - , fi (unsafeShiftR n 8 .&. 0xff) - , fi (unsafeShiftR n 16 .&. 0xff) - , fi (unsafeShiftR n 24 .&. 0xff) - , fi (unsafeShiftR n 32 .&. 0xff) - , fi (unsafeShiftR n 40 .&. 0xff) - , fi (unsafeShiftR n 48 .&. 0xff) - , fi (unsafeShiftR n 56 .&. 0xff) - ] - --- Big-endian 16-bit encoding -encode_be16 :: Word16 -> BS.ByteString -encode_be16 n = BS.pack [fi (unsafeShiftR n 8), fi (n .&. 0xff)] - --- Big-endian 16-bit decoding -decode_be16 :: BS.ByteString -> Maybe Word16 -decode_be16 bs - | BS.length bs /= 2 = Nothing - | otherwise = - let !b0 = BS.index bs 0 - !b1 = BS.index bs 1 - in Just (fi b0 * 0x100 + fi b1) - --- handshake ----------------------------------------------------------------- - --- Initialize handshake state --- --- h = SHA256(protocol_name) --- ck = h --- h = SHA256(h || prologue) --- h = SHA256(h || responder_static_pubkey) -init_handshake - :: Sec -- ^ local static secret - -> Pub -- ^ local static public - -> Sec -- ^ ephemeral secret - -> Pub -- ^ ephemeral public - -> Maybe Pub -- ^ remote static (initiator knows, responder doesn't) - -> Bool -- ^ True if initiator - -> HandshakeState -init_handshake s_sec s_pub e_sec e_pub m_rs is_initiator = - let !h0 = SHA256.hash _PROTOCOL_NAME - !ck = h0 - !h1 = mix_hash h0 _PROLOGUE - -- Mix in responder's static pubkey - !h2 = case (is_initiator, m_rs) of - (True, Just rs) -> mix_hash h1 (serialize_pub rs) - (False, Nothing) -> mix_hash h1 (serialize_pub s_pub) - _ -> h1 -- shouldn't happen - in HandshakeState { - hs_h = h2 - , hs_ck = ck - , hs_temp_k = BS.replicate 32 0x00 - , hs_e_sec = e_sec - , hs_e_pub = e_pub - , hs_s_sec = s_sec - , hs_s_pub = s_pub - , hs_re = Nothing - , hs_rs = m_rs - } - --- | Initiator: generate Act 1 message (50 bytes). --- --- Takes local static key, remote static pubkey, and 32 bytes of --- entropy for ephemeral key generation. --- --- Returns the 50-byte Act 1 message and handshake state for Act 3. --- --- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) --- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) --- >>> let eph_ent = BS.replicate 32 0x12 --- >>> case act1 i_sec i_pub r_pub eph_ent of { Right (msg, _) -> BS.length msg; Left _ -> 0 } --- 50 -act1 - :: Sec -- ^ local static secret - -> Pub -- ^ local static public - -> Pub -- ^ remote static public (responder's) - -> BS.ByteString -- ^ 32 bytes entropy for ephemeral - -> Either Error (BS.ByteString, HandshakeState) -act1 s_sec s_pub rs ent = do - (e_sec, e_pub) <- note InvalidKey (keypair ent) - let !hs0 = init_handshake s_sec s_pub e_sec e_pub (Just rs) True - !e_pub_bytes = serialize_pub e_pub - !h1 = mix_hash (hs_h hs0) e_pub_bytes - es <- note InvalidKey (ecdh e_sec rs) - let !(ck1, temp_k1) = mix_key (hs_ck hs0) es - c <- note InvalidMAC (encrypt_with_ad temp_k1 0 h1 BS.empty) - let !h2 = mix_hash h1 c - !msg = BS.singleton 0x00 <> e_pub_bytes <> c - !hs1 = hs0 { - hs_h = h2 - , hs_ck = ck1 - , hs_temp_k = temp_k1 - } - pure (msg, hs1) - --- | Responder: process Act 1 and generate Act 2 message (50 bytes). --- --- Takes local static key and 32 bytes of entropy for ephemeral key, --- plus the 50-byte Act 1 message from initiator. --- --- Returns the 50-byte Act 2 message and handshake state for finalize. --- --- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) --- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) --- >>> let Right (msg1, _) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) --- >>> case act2 r_sec r_pub (BS.replicate 32 0x22) msg1 of { Right (msg, _) -> BS.length msg; Left _ -> 0 } --- 50 -act2 - :: Sec -- ^ local static secret - -> Pub -- ^ local static public - -> BS.ByteString -- ^ 32 bytes entropy for ephemeral - -> BS.ByteString -- ^ Act 1 message (50 bytes) - -> Either Error (BS.ByteString, HandshakeState) -act2 s_sec s_pub ent msg1 = do - require (BS.length msg1 == 50) InvalidLength - let !version = BS.index msg1 0 - !re_bytes = BS.take 33 (BS.drop 1 msg1) - !c = BS.drop 34 msg1 - require (version == 0x00) InvalidVersion - re <- note InvalidPub (parse_pub re_bytes) - (e_sec, e_pub) <- note InvalidKey (keypair ent) - let !hs0 = init_handshake s_sec s_pub e_sec e_pub Nothing False - !h1 = mix_hash (hs_h hs0) re_bytes - es <- note InvalidKey (ecdh s_sec re) - let !(ck1, temp_k1) = mix_key (hs_ck hs0) es - _ <- note InvalidMAC (decrypt_with_ad temp_k1 0 h1 c) - let !h2 = mix_hash h1 c - !e_pub_bytes = serialize_pub e_pub - !h3 = mix_hash h2 e_pub_bytes - ee <- note InvalidKey (ecdh e_sec re) - let !(ck2, temp_k2) = mix_key ck1 ee - c2 <- note InvalidMAC (encrypt_with_ad temp_k2 0 h3 BS.empty) - let !h4 = mix_hash h3 c2 - !msg = BS.singleton 0x00 <> e_pub_bytes <> c2 - !hs1 = hs0 { - hs_h = h4 - , hs_ck = ck2 - , hs_temp_k = temp_k2 - , hs_re = Just re - } - pure (msg, hs1) - --- | Initiator: process Act 2 and generate Act 3 (66 bytes), completing --- the handshake. --- --- Returns the 66-byte Act 3 message and the handshake result. --- --- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) --- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) --- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) --- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1 --- >>> case act3 i_hs msg2 of { Right (msg, _) -> BS.length msg; Left _ -> 0 } --- 66 -act3 - :: HandshakeState -- ^ state after Act 1 - -> BS.ByteString -- ^ Act 2 message (50 bytes) - -> Either Error (BS.ByteString, Handshake) -act3 hs msg2 = do - require (BS.length msg2 == 50) InvalidLength - let !version = BS.index msg2 0 - !re_bytes = BS.take 33 (BS.drop 1 msg2) - !c = BS.drop 34 msg2 - require (version == 0x00) InvalidVersion - re <- note InvalidPub (parse_pub re_bytes) - let !h1 = mix_hash (hs_h hs) re_bytes - ee <- note InvalidKey (ecdh (hs_e_sec hs) re) - let !(ck1, temp_k2) = mix_key (hs_ck hs) ee - _ <- note InvalidMAC (decrypt_with_ad temp_k2 0 h1 c) - let !h2 = mix_hash h1 c - !s_pub_bytes = serialize_pub (hs_s_pub hs) - c3 <- note InvalidMAC (encrypt_with_ad temp_k2 1 h2 s_pub_bytes) - let !h3 = mix_hash h2 c3 - se <- note InvalidKey (ecdh (hs_s_sec hs) re) - let !(ck2, temp_k3) = mix_key ck1 se - t <- note InvalidMAC (encrypt_with_ad temp_k3 0 h3 BS.empty) - let !(sk, rk) = mix_key ck2 BS.empty - !msg = BS.singleton 0x00 <> c3 <> t - !sess = Session { - sess_sk = sk - , sess_sn = 0 - , sess_sck = ck2 - , sess_rk = rk - , sess_rn = 0 - , sess_rck = ck2 - } - rs <- note InvalidPub (hs_rs hs) - let !result = Handshake { - session = sess - , remote_static = rs - } - pure (msg, result) - --- | Responder: process Act 3 (66 bytes) and complete the handshake. --- --- Returns the handshake result with authenticated remote static pubkey. --- --- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) --- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) --- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) --- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1 --- >>> let Right (msg3, _) = act3 i_hs msg2 --- >>> case finalize r_hs msg3 of { Right _ -> "ok"; Left e -> show e } --- "ok" -finalize - :: HandshakeState -- ^ state after Act 2 - -> BS.ByteString -- ^ Act 3 message (66 bytes) - -> Either Error Handshake -finalize hs msg3 = do - require (BS.length msg3 == 66) InvalidLength - let !version = BS.index msg3 0 - !c = BS.take 49 (BS.drop 1 msg3) - !t = BS.drop 50 msg3 - require (version == 0x00) InvalidVersion - rs_bytes <- note InvalidMAC (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c) - rs <- note InvalidPub (parse_pub rs_bytes) - let !h1 = mix_hash (hs_h hs) c - se <- note InvalidKey (ecdh (hs_e_sec hs) rs) - let !(ck1, temp_k3) = mix_key (hs_ck hs) se - _ <- note InvalidMAC (decrypt_with_ad temp_k3 0 h1 t) - -- responder swaps order (receives what initiator sends) - let !(rk, sk) = mix_key ck1 BS.empty - !sess = Session { - sess_sk = sk - , sess_sn = 0 - , sess_sck = ck1 - , sess_rk = rk - , sess_rn = 0 - , sess_rck = ck1 - } - !result = Handshake { - session = sess - , remote_static = rs - } - pure result - --- message encryption -------------------------------------------------------- - --- | Encrypt a message (max 65535 bytes). --- --- Returns the encrypted packet and updated session. Key rotation --- is handled automatically at nonce 1000. --- --- Wire format: encrypted_length (2) || MAC (16) || encrypted_body || MAC (16) --- --- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) --- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) --- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) --- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1 --- >>> let Right (_, i_result) = act3 i_hs msg2 --- >>> let sess = session i_result --- >>> case encrypt sess "hello" of { Right (ct, _) -> BS.length ct; Left _ -> 0 } --- 39 -encrypt - :: Session - -> BS.ByteString -- ^ plaintext (max 65535 bytes) - -> Either Error (BS.ByteString, Session) -encrypt sess pt = do - let !len = BS.length pt - require (len <= 65535) InvalidLength - let !len_bytes = encode_be16 (fi len) - lc <- note InvalidMAC (encrypt_with_ad (sess_sk sess) (sess_sn sess) - BS.empty len_bytes) - let !(sn1, sck1, sk1) = step_nonce (sess_sn sess) (sess_sck sess) (sess_sk sess) - bc <- note InvalidMAC (encrypt_with_ad sk1 sn1 BS.empty pt) - let !(sn2, sck2, sk2) = step_nonce sn1 sck1 sk1 - !packet = lc <> bc - !sess' = sess { - sess_sk = sk2 - , sess_sn = sn2 - , sess_sck = sck2 - } - pure (packet, sess') - --- | Decrypt a message, requiring an exact packet with no trailing bytes. --- --- Returns the plaintext and updated session. Key rotation --- is handled automatically at nonce 1000. --- --- This is a strict variant that rejects any trailing data. For --- streaming use cases where you need to handle multiple frames in a --- buffer, use 'decrypt_frame' instead. --- --- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) --- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) --- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) --- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1 --- >>> let Right (msg3, i_result) = act3 i_hs msg2 --- >>> let Right r_result = finalize r_hs msg3 --- >>> let Right (ct, _) = encrypt (session i_result) "hello" --- >>> case decrypt (session r_result) ct of { Right (pt, _) -> pt; Left _ -> "fail" } --- "hello" -decrypt - :: Session - -> BS.ByteString -- ^ encrypted packet (exact length required) - -> Either Error (BS.ByteString, Session) -decrypt sess packet = do - (pt, remainder, sess') <- decrypt_frame sess packet - require (BS.null remainder) InvalidLength - pure (pt, sess') - --- | Decrypt a single frame from a buffer, returning the remainder. --- --- Returns the plaintext, any unconsumed bytes, and the updated session. --- Key rotation is handled automatically every 1000 messages. --- --- This is useful for streaming scenarios where multiple messages may --- be buffered together. The remainder can be passed to the next call --- to 'decrypt_frame'. --- --- Wire format consumed: encrypted_length (18) || encrypted_body (len + 16) --- --- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) --- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) --- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) --- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1 --- >>> let Right (msg3, i_result) = act3 i_hs msg2 --- >>> let Right r_result = finalize r_hs msg3 --- >>> let Right (ct, _) = encrypt (session i_result) "hello" --- >>> case decrypt_frame (session r_result) ct of { Right (pt, rem, _) -> (pt, BS.null rem); Left _ -> ("fail", False) } --- ("hello",True) -decrypt_frame - :: Session - -> BS.ByteString -- ^ buffer containing at least one encrypted frame - -> Either Error (BS.ByteString, BS.ByteString, Session) -decrypt_frame sess packet = do - require (BS.length packet >= 34) InvalidLength - let !lc = BS.take 18 packet - !rest = BS.drop 18 packet - len_bytes <- note InvalidMAC (decrypt_with_ad (sess_rk sess) (sess_rn sess) - BS.empty lc) - len <- note InvalidLength (decode_be16 len_bytes) - let !(rn1, rck1, rk1) = step_nonce (sess_rn sess) (sess_rck sess) (sess_rk sess) - !body_len = fi len + 16 - require (BS.length rest >= body_len) InvalidLength - let !bc = BS.take body_len rest - !remainder = BS.drop body_len rest - pt <- note InvalidMAC (decrypt_with_ad rk1 rn1 BS.empty bc) - let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1 - !sess' = sess { - sess_rk = rk2 - , sess_rn = rn2 - , sess_rck = rck2 - } - pure (pt, remainder, sess') - --- | Decrypt a frame from a partial buffer, indicating when more data needed. --- --- Unlike 'decrypt_frame', this function handles incomplete buffers --- gracefully by returning 'NeedMore' with the number of additional --- bytes required to make progress. --- --- * If the buffer has fewer than 18 bytes (encrypted length + MAC), --- returns @'NeedMore' n@ where @n@ is the bytes still needed. --- * If the length header is complete but the body is incomplete, --- returns @'NeedMore' n@ with bytes needed for the full frame. --- * MAC or decryption failures return 'FrameError'. --- * A complete, valid frame returns 'FrameOk' with plaintext, --- remainder, and updated session. --- --- This is useful for non-blocking I/O where data arrives incrementally. -decrypt_frame_partial - :: Session - -> BS.ByteString -- ^ buffer (possibly incomplete) - -> FrameResult -decrypt_frame_partial sess buf - | buflen < 18 = NeedMore (18 - buflen) - | otherwise = - let !lc = BS.take 18 buf - !rest = BS.drop 18 buf - in case decrypt_with_ad (sess_rk sess) (sess_rn sess) BS.empty lc of - Nothing -> FrameError InvalidMAC - Just len_bytes -> case decode_be16 len_bytes of - Nothing -> FrameError InvalidLength - Just len -> - let !body_len = fi len + 16 - !(rn1, rck1, rk1) = step_nonce (sess_rn sess) - (sess_rck sess) (sess_rk sess) - in if BS.length rest < body_len - then NeedMore (body_len - BS.length rest) - else - let !bc = BS.take body_len rest - !remainder = BS.drop body_len rest - in case decrypt_with_ad rk1 rn1 BS.empty bc of - Nothing -> FrameError InvalidMAC - Just pt -> - let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1 - !sess' = sess { - sess_rk = rk2 - , sess_rn = rn2 - , sess_rck = rck2 - } - in FrameOk pt remainder sess' - where - !buflen = BS.length buf - --- key rotation -------------------------------------------------------------- - --- Key rotation occurs after nonce reaches 1000 (i.e., before using 1000) --- (ck', k') = HKDF(ck, k), reset nonce to 0 -step_nonce - :: Word64 - -> BS.ByteString - -> BS.ByteString - -> (Word64, BS.ByteString, BS.ByteString) -step_nonce n ck k - | n + 1 == 1000 = - let !(ck', k') = mix_key ck k - in (0, ck', k') - | otherwise = (n + 1, ck, k) - --- utilities ----------------------------------------------------------------- - --- Lift Maybe to Either -note :: e -> Maybe a -> Either e a -note e = maybe (Left e) Right -{-# INLINE note #-} - --- Require condition or fail -require :: Bool -> e -> Either e () -require cond e = unless cond (Left e) -{-# INLINE require #-} - -fi :: (Integral a, Num b) => a -> b -fi = fromIntegral -{-# INLINE fi #-} +import Lightning.Protocol.BOLT8.Internal diff --git a/lib/Lightning/Protocol/BOLT8/Internal.hs b/lib/Lightning/Protocol/BOLT8/Internal.hs @@ -0,0 +1,818 @@ +{-# OPTIONS_HADDOCK prune #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ViewPatterns #-} + +-- | +-- Module: Lightning.Protocol.BOLT8.Internal +-- Copyright: (c) 2025 Jared Tobin +-- License: MIT +-- Maintainer: Jared Tobin <jared@ppad.tech> +-- +-- Internal module exporting all constructors for testing and +-- benchmarking. Prefer "Lightning.Protocol.BOLT8" for general use. + +module Lightning.Protocol.BOLT8.Internal ( + -- * Keys + Sec(..) + , Pub(..) + , keypair + , parse_pub + , serialize_pub + + -- * Newtypes + , Key32(..) + , key32 + , unsafeKey32 + , SessionNonce(..) + , MessagePayload(..) + , mkMessagePayload + + -- * Handshake roles + , Initiator + , Responder + , HandshakeFor(..) + + -- * Handshake (initiator) + , act1 + , act3 + + -- * Handshake (responder) + , act2 + , finalize + + -- * Session + , Session(..) + , HandshakeState(..) + , Handshake(..) + , encrypt + , decrypt + , decrypt_frame + , decrypt_frame_partial + , FrameResult(..) + + -- * Errors + , Error(..) + ) where + +import Control.Monad (guard, unless) +import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD +import qualified Crypto.Curve.Secp256k1 as Secp256k1 +import qualified Crypto.Hash.SHA256 as SHA256 +import qualified Crypto.KDF.HMAC as HKDF +import Data.Bits (unsafeShiftR, (.&.)) +import qualified Data.ByteString as BS +import Data.Word (Word16, Word64) +import GHC.Generics (Generic) + +-- types ----------------------------------------------------------- + +-- | Secret key (32 bytes). +newtype Sec = Sec BS.ByteString + deriving (Eq, Generic) + +-- | Compressed public key. +newtype Pub = Pub Secp256k1.Projective + +instance Eq Pub where + (Pub a) == (Pub b) = + Secp256k1.serialize_point a + == Secp256k1.serialize_point b + +instance Show Pub where + show (Pub p) = + "Pub " ++ show (Secp256k1.serialize_point p) + +-- | A 32-byte key, validated at construction. +newtype Key32 = Key32 { unKey32 :: BS.ByteString } + deriving (Eq, Generic) + +-- | Construct a 'Key32' from a 32-byte 'BS.ByteString'. +-- +-- Returns 'Nothing' if the input is not exactly 32 bytes. +-- +-- >>> key32 (BS.replicate 32 0x00) +-- Just (Key32 {unKey32 = ...}) +-- >>> key32 (BS.replicate 31 0x00) +-- Nothing +key32 :: BS.ByteString -> Maybe Key32 +key32 bs + | BS.length bs == 32 = Just (Key32 bs) + | otherwise = Nothing + +-- | Construct a 'Key32' without validation. +-- +-- For test and benchmark use only; prefer 'key32'. +unsafeKey32 :: BS.ByteString -> Key32 +unsafeKey32 = Key32 + +-- | Session nonce, distinguishing send from receive direction. +newtype SessionNonce = + SessionNonce { unSessionNonce :: Word64 } + deriving (Eq, Generic) + +-- | Message payload (max 65535 bytes), validated at construction. +newtype MessagePayload = + MessagePayload { unMessagePayload :: BS.ByteString } + deriving (Eq, Generic) + +-- | Construct a 'MessagePayload' from a 'BS.ByteString'. +-- +-- Returns 'Left' if the payload exceeds 65535 bytes. +mkMessagePayload + :: BS.ByteString -> Either Error MessagePayload +mkMessagePayload bs + | BS.length bs > 65535 = Left InvalidLength + | otherwise = Right (MessagePayload bs) + +-- | Handshake errors. +data Error = + InvalidKey + | InvalidPub + | InvalidMAC + | InvalidVersion + | InvalidLength + | DecryptionFailed + deriving (Eq, Show, Generic) + +-- | Result of attempting to decrypt a frame from a partial +-- buffer. +data FrameResult = + NeedMore {-# UNPACK #-} !Int + -- ^ More bytes needed; the 'Int' is the minimum + -- additional bytes required. + | FrameOk !BS.ByteString !BS.ByteString !Session + -- ^ Successfully decrypted: plaintext, remainder, + -- updated session. + | FrameError !Error + -- ^ Decryption failed with the given error. + deriving Generic + +-- | Post-handshake session state. +data Session = Session { + sess_sk :: !Key32 + -- ^ send key (32 bytes) + , sess_sn :: !SessionNonce + -- ^ send nonce + , sess_sck :: !Key32 + -- ^ send chaining key + , sess_rk :: !Key32 + -- ^ receive key (32 bytes) + , sess_rn :: !SessionNonce + -- ^ receive nonce + , sess_rck :: !Key32 + -- ^ receive chaining key + } + deriving Generic + +-- | Result of a successful handshake. +data Handshake = Handshake { + session :: !Session + -- ^ session state + , remote_static :: !Pub + -- ^ authenticated remote static pubkey + } + deriving Generic + +-- | Internal handshake state (exported for benchmarking). +data HandshakeState = HandshakeState { + hs_h :: {-# UNPACK #-} !BS.ByteString + -- ^ handshake hash (32 bytes) + , hs_ck :: {-# UNPACK #-} !BS.ByteString + -- ^ chaining key (32 bytes) + , hs_temp_k :: {-# UNPACK #-} !BS.ByteString + -- ^ temp key (32 bytes) + , hs_e_sec :: !Sec + -- ^ ephemeral secret + , hs_e_pub :: !Pub + -- ^ ephemeral public + , hs_s_sec :: !Sec + -- ^ static secret + , hs_s_pub :: !Pub + -- ^ static public + , hs_re :: !(Maybe Pub) + -- ^ remote ephemeral + , hs_rs :: !(Maybe Pub) + -- ^ remote static + } + deriving Generic + +-- handshake roles ------------------------------------------------- + +-- | Phantom type for initiator role. +data Initiator + +-- | Phantom type for responder role. +data Responder + +-- | Role-indexed handshake state. +-- +-- The phantom type parameter prevents passing an initiator's +-- state to a responder function and vice versa. +data HandshakeFor a = + HandshakeFor { unHandshakeFor :: !HandshakeState } + +-- protocol constants ---------------------------------------------- + +_PROTOCOL_NAME :: BS.ByteString +_PROTOCOL_NAME = + "Noise_XK_secp256k1_ChaChaPoly_SHA256" + +_PROLOGUE :: BS.ByteString +_PROLOGUE = "lightning" + +-- key operations -------------------------------------------------- + +-- | Derive a keypair from 32 bytes of entropy. +-- +-- Returns Nothing if the entropy is invalid +-- (zero or >= curve order). +-- +-- >>> let ent = BS.replicate 32 0x11 +-- >>> case keypair ent of +-- ... Just _ -> "ok" +-- ... Nothing -> "fail" +-- "ok" +-- >>> keypair (BS.replicate 31 0x11) -- wrong length +-- Nothing +keypair :: BS.ByteString -> Maybe (Sec, Pub) +keypair ent = do + guard (BS.length ent == 32) + k <- Secp256k1.parse_int256 ent + p <- Secp256k1.derive_pub k + pure (Sec ent, Pub p) + +-- | Parse a 33-byte compressed public key. +-- +-- >>> let Just (_, pub) = keypair (BS.replicate 32 0x11) +-- >>> let bytes = serialize_pub pub +-- >>> case parse_pub bytes of +-- ... Just _ -> "ok" +-- ... Nothing -> "fail" +-- "ok" +-- >>> parse_pub (BS.replicate 32 0x00) -- wrong length +-- Nothing +parse_pub :: BS.ByteString -> Maybe Pub +parse_pub bs = do + guard (BS.length bs == 33) + p <- Secp256k1.parse_point bs + pure (Pub p) + +-- | Serialize a public key to 33-byte compressed form. +-- +-- >>> let Just (_, pub) = keypair (BS.replicate 32 0x11) +-- >>> BS.length (serialize_pub pub) +-- 33 +serialize_pub :: Pub -> BS.ByteString +serialize_pub (Pub p) = Secp256k1.serialize_point p + +-- cryptographic primitives ---------------------------------------- + +-- bolt8-style ECDH +ecdh :: Sec -> Pub -> Maybe BS.ByteString +ecdh (Sec sec) (Pub pub) = do + k <- Secp256k1.parse_int256 sec + pt <- Secp256k1.mul pub k + let compressed = Secp256k1.serialize_point pt + pure (SHA256.hash compressed) + +-- h' = SHA256(h || data) +mix_hash + :: BS.ByteString -> BS.ByteString -> BS.ByteString +mix_hash h dat = SHA256.hash (h <> dat) + +-- Mix key: (ck', k) = HKDF(ck, input_key_material) +-- +-- NB HKDF limits output to 255 * hashlen bytes. For SHA256 +-- that's 8160, well above the 64 bytes requested here, so +-- 'Nothing' is impossible. +mix_key + :: BS.ByteString + -> BS.ByteString + -> (BS.ByteString, BS.ByteString) +mix_key ck ikm = + case HKDF.derive hmac ck mempty 64 ikm of + Nothing -> + error + "ppad-bolt8: internal error, please report a bug!" + Just output -> BS.splitAt 32 output + where + hmac k b = case SHA256.hmac k b of + SHA256.MAC mac -> mac + +-- Encrypt with associated data using ChaCha20-Poly1305 +encrypt_with_ad + :: BS.ByteString -- ^ key (32 bytes) + -> Word64 -- ^ nonce + -> BS.ByteString -- ^ associated data + -> BS.ByteString -- ^ plaintext + -> Maybe BS.ByteString -- ^ ciphertext || mac (16 bytes) +encrypt_with_ad key n ad pt = + case AEAD.encrypt ad key (encode_nonce n) pt of + Left _ -> Nothing + Right (ct, mac) -> Just (ct <> mac) + +-- Decrypt with associated data using ChaCha20-Poly1305 +decrypt_with_ad + :: BS.ByteString -- ^ key (32 bytes) + -> Word64 -- ^ nonce + -> BS.ByteString -- ^ associated data + -> BS.ByteString -- ^ ciphertext || mac + -> Maybe BS.ByteString -- ^ plaintext +decrypt_with_ad key n ad ctmac + | BS.length ctmac < 16 = Nothing + | otherwise = + let (ct, mac) = + BS.splitAt (BS.length ctmac - 16) ctmac + in case AEAD.decrypt ad key (encode_nonce n) + (ct, mac) of + Left _ -> Nothing + Right pt -> Just pt + +-- Encode nonce as 96-bit value: 4 zero bytes + 8-byte LE +encode_nonce :: Word64 -> BS.ByteString +encode_nonce n = BS.replicate 4 0x00 <> encode_le64 n + +-- Little-endian 64-bit encoding +encode_le64 :: Word64 -> BS.ByteString +encode_le64 n = BS.pack [ + fi (n .&. 0xff) + , fi (unsafeShiftR n 8 .&. 0xff) + , fi (unsafeShiftR n 16 .&. 0xff) + , fi (unsafeShiftR n 24 .&. 0xff) + , fi (unsafeShiftR n 32 .&. 0xff) + , fi (unsafeShiftR n 40 .&. 0xff) + , fi (unsafeShiftR n 48 .&. 0xff) + , fi (unsafeShiftR n 56 .&. 0xff) + ] + +-- Big-endian 16-bit encoding +encode_be16 :: Word16 -> BS.ByteString +encode_be16 n = + BS.pack [fi (unsafeShiftR n 8), fi (n .&. 0xff)] + +-- Big-endian 16-bit decoding +decode_be16 :: BS.ByteString -> Maybe Word16 +decode_be16 bs + | BS.length bs /= 2 = Nothing + | otherwise = + let !b0 = BS.index bs 0 + !b1 = BS.index bs 1 + in Just (fi b0 * 0x100 + fi b1) + +-- handshake ------------------------------------------------------- + +-- Initialize handshake state +-- +-- h = SHA256(protocol_name) +-- ck = h +-- h = SHA256(h || prologue) +-- h = SHA256(h || responder_static_pubkey) +init_handshake + :: Sec -- ^ local static secret + -> Pub -- ^ local static public + -> Sec -- ^ ephemeral secret + -> Pub -- ^ ephemeral public + -> Maybe Pub -- ^ remote static + -> Bool -- ^ True if initiator + -> HandshakeState +init_handshake s_sec s_pub e_sec e_pub m_rs is_init = + let !h0 = SHA256.hash _PROTOCOL_NAME + !ck = h0 + !h1 = mix_hash h0 _PROLOGUE + -- Mix in responder's static pubkey + !h2 = case (is_init, m_rs) of + (True, Just rs) -> + mix_hash h1 (serialize_pub rs) + (False, Nothing) -> + mix_hash h1 (serialize_pub s_pub) + _ -> h1 -- shouldn't happen + in HandshakeState { + hs_h = h2 + , hs_ck = ck + , hs_temp_k = BS.replicate 32 0x00 + , hs_e_sec = e_sec + , hs_e_pub = e_pub + , hs_s_sec = s_sec + , hs_s_pub = s_pub + , hs_re = Nothing + , hs_rs = m_rs + } + +-- | Initiator: generate Act 1 message (50 bytes). +-- +-- Takes local static key, remote static pubkey, and 32 +-- bytes of entropy for ephemeral key generation. +-- +-- Returns the 50-byte Act 1 message and handshake state +-- for Act 3. +-- +-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) +-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) +-- >>> let eph_ent = BS.replicate 32 0x12 +-- >>> case act1 i_sec i_pub r_pub eph_ent of { Right (msg, _) -> BS.length msg; Left _ -> 0 } +-- 50 +act1 + :: Sec + -> Pub + -> Pub + -> BS.ByteString + -> Either Error + (BS.ByteString, HandshakeFor Initiator) +act1 s_sec s_pub rs ent = do + (e_sec, e_pub) <- note InvalidKey (keypair ent) + let !hs0 = init_handshake + s_sec s_pub e_sec e_pub (Just rs) True + !e_pub_bytes = serialize_pub e_pub + !h1 = mix_hash (hs_h hs0) e_pub_bytes + es <- note InvalidKey (ecdh e_sec rs) + let !(ck1, temp_k1) = mix_key (hs_ck hs0) es + c <- note InvalidMAC + (encrypt_with_ad temp_k1 0 h1 BS.empty) + let !h2 = mix_hash h1 c + !msg = BS.singleton 0x00 <> e_pub_bytes <> c + !hs1 = hs0 { + hs_h = h2 + , hs_ck = ck1 + , hs_temp_k = temp_k1 + } + pure (msg, HandshakeFor hs1) + +-- | Responder: process Act 1 and generate Act 2 message +-- (50 bytes). +-- +-- Takes local static key and 32 bytes of entropy for +-- ephemeral key, plus the 50-byte Act 1 message from +-- initiator. +-- +-- Returns the 50-byte Act 2 message and handshake state +-- for finalize. +-- +-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) +-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) +-- >>> let Right (msg1, _) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) +-- >>> case act2 r_sec r_pub (BS.replicate 32 0x22) msg1 of { Right (msg, _) -> BS.length msg; Left _ -> 0 } +-- 50 +act2 + :: Sec + -> Pub + -> BS.ByteString + -> BS.ByteString + -> Either Error + (BS.ByteString, HandshakeFor Responder) +act2 s_sec s_pub ent msg1 = do + require (BS.length msg1 == 50) InvalidLength + let !version = BS.index msg1 0 + !re_bytes = BS.take 33 (BS.drop 1 msg1) + !c = BS.drop 34 msg1 + require (version == 0x00) InvalidVersion + re <- note InvalidPub (parse_pub re_bytes) + (e_sec, e_pub) <- note InvalidKey (keypair ent) + let !hs0 = init_handshake + s_sec s_pub e_sec e_pub Nothing False + !h1 = mix_hash (hs_h hs0) re_bytes + es <- note InvalidKey (ecdh s_sec re) + let !(ck1, temp_k1) = mix_key (hs_ck hs0) es + _ <- note InvalidMAC + (decrypt_with_ad temp_k1 0 h1 c) + let !h2 = mix_hash h1 c + !e_pub_bytes = serialize_pub e_pub + !h3 = mix_hash h2 e_pub_bytes + ee <- note InvalidKey (ecdh e_sec re) + let !(ck2, temp_k2) = mix_key ck1 ee + c2 <- note InvalidMAC + (encrypt_with_ad temp_k2 0 h3 BS.empty) + let !h4 = mix_hash h3 c2 + !msg = BS.singleton 0x00 <> e_pub_bytes <> c2 + !hs1 = hs0 { + hs_h = h4 + , hs_ck = ck2 + , hs_temp_k = temp_k2 + , hs_re = Just re + } + pure (msg, HandshakeFor hs1) + +-- | Initiator: process Act 2 and generate Act 3 (66 bytes), +-- completing the handshake. +-- +-- Returns the 66-byte Act 3 message and the handshake +-- result. +-- +-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) +-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) +-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) +-- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1 +-- >>> case act3 i_hs msg2 of { Right (msg, _) -> BS.length msg; Left _ -> 0 } +-- 66 +act3 + :: HandshakeFor Initiator + -> BS.ByteString + -> Either Error (BS.ByteString, Handshake) +act3 (HandshakeFor hs) msg2 = do + require (BS.length msg2 == 50) InvalidLength + let !version = BS.index msg2 0 + !re_bytes = BS.take 33 (BS.drop 1 msg2) + !c = BS.drop 34 msg2 + require (version == 0x00) InvalidVersion + re <- note InvalidPub (parse_pub re_bytes) + let !h1 = mix_hash (hs_h hs) re_bytes + ee <- note InvalidKey (ecdh (hs_e_sec hs) re) + let !(ck1, temp_k2) = mix_key (hs_ck hs) ee + _ <- note InvalidMAC + (decrypt_with_ad temp_k2 0 h1 c) + let !h2 = mix_hash h1 c + !s_pub_bytes = serialize_pub (hs_s_pub hs) + c3 <- note InvalidMAC + (encrypt_with_ad temp_k2 1 h2 s_pub_bytes) + let !h3 = mix_hash h2 c3 + se <- note InvalidKey (ecdh (hs_s_sec hs) re) + let !(ck2, temp_k3) = mix_key ck1 se + t <- note InvalidMAC + (encrypt_with_ad temp_k3 0 h3 BS.empty) + let !(sk, rk) = mix_key ck2 BS.empty + !msg = BS.singleton 0x00 <> c3 <> t + !sess = Session { + sess_sk = Key32 sk + , sess_sn = SessionNonce 0 + , sess_sck = Key32 ck2 + , sess_rk = Key32 rk + , sess_rn = SessionNonce 0 + , sess_rck = Key32 ck2 + } + rs <- note InvalidPub (hs_rs hs) + let !result = Handshake { + session = sess + , remote_static = rs + } + pure (msg, result) + +-- | Responder: process Act 3 (66 bytes) and complete the +-- handshake. +-- +-- Returns the handshake result with authenticated remote +-- static pubkey. +-- +-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) +-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) +-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) +-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1 +-- >>> let Right (msg3, _) = act3 i_hs msg2 +-- >>> case finalize r_hs msg3 of { Right _ -> "ok"; Left e -> show e } +-- "ok" +finalize + :: HandshakeFor Responder + -> BS.ByteString + -> Either Error Handshake +finalize (HandshakeFor hs) msg3 = do + require (BS.length msg3 == 66) InvalidLength + let !version = BS.index msg3 0 + !c = BS.take 49 (BS.drop 1 msg3) + !t = BS.drop 50 msg3 + require (version == 0x00) InvalidVersion + rs_bytes <- note InvalidMAC + (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c) + rs <- note InvalidPub (parse_pub rs_bytes) + let !h1 = mix_hash (hs_h hs) c + se <- note InvalidKey (ecdh (hs_e_sec hs) rs) + let !(ck1, temp_k3) = mix_key (hs_ck hs) se + _ <- note InvalidMAC + (decrypt_with_ad temp_k3 0 h1 t) + -- responder swaps order (receives what initiator sends) + let !(rk, sk) = mix_key ck1 BS.empty + !sess = Session { + sess_sk = Key32 sk + , sess_sn = SessionNonce 0 + , sess_sck = Key32 ck1 + , sess_rk = Key32 rk + , sess_rn = SessionNonce 0 + , sess_rck = Key32 ck1 + } + !result = Handshake { + session = sess + , remote_static = rs + } + pure result + +-- message encryption ---------------------------------------------- + +-- | Encrypt a message (max 65535 bytes). +-- +-- Returns the encrypted packet and updated session. Key +-- rotation is handled automatically at nonce 1000. +-- +-- Wire format: +-- encrypted_length (2) || MAC (16) +-- || encrypted_body || MAC (16) +-- +-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) +-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) +-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) +-- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1 +-- >>> let Right (_, i_result) = act3 i_hs msg2 +-- >>> let sess = session i_result +-- >>> case encrypt sess "hello" of { Right (ct, _) -> BS.length ct; Left _ -> 0 } +-- 39 +encrypt + :: Session + -> BS.ByteString + -> Either Error (BS.ByteString, Session) +encrypt sess pt = do + let !len = BS.length pt + require (len <= 65535) InvalidLength + let !len_bytes = encode_be16 (fi len) + !sk = unKey32 (sess_sk sess) + !sn = unSessionNonce (sess_sn sess) + !sck = unKey32 (sess_sck sess) + lc <- note InvalidMAC + (encrypt_with_ad sk sn BS.empty len_bytes) + let !(sn1, sck1, sk1) = step_nonce sn sck sk + bc <- note InvalidMAC + (encrypt_with_ad sk1 sn1 BS.empty pt) + let !(sn2, sck2, sk2) = step_nonce sn1 sck1 sk1 + !packet = lc <> bc + !sess' = sess { + sess_sk = Key32 sk2 + , sess_sn = SessionNonce sn2 + , sess_sck = Key32 sck2 + } + pure (packet, sess') + +-- | Decrypt a message, requiring an exact packet with no +-- trailing bytes. +-- +-- Returns the plaintext and updated session. Key rotation +-- is handled automatically at nonce 1000. +-- +-- This is a strict variant that rejects any trailing data. +-- For streaming use cases where you need to handle multiple +-- frames in a buffer, use 'decrypt_frame' instead. +-- +-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) +-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) +-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) +-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1 +-- >>> let Right (msg3, i_result) = act3 i_hs msg2 +-- >>> let Right r_result = finalize r_hs msg3 +-- >>> let Right (ct, _) = encrypt (session i_result) "hello" +-- >>> case decrypt (session r_result) ct of { Right (pt, _) -> pt; Left _ -> "fail" } +-- "hello" +decrypt + :: Session + -> BS.ByteString + -> Either Error (BS.ByteString, Session) +decrypt sess packet = do + (pt, remainder, sess') <- decrypt_frame sess packet + require (BS.null remainder) InvalidLength + pure (pt, sess') + +-- | Decrypt a single frame from a buffer, returning the +-- remainder. +-- +-- Returns the plaintext, any unconsumed bytes, and the +-- updated session. Key rotation is handled automatically +-- every 1000 messages. +-- +-- This is useful for streaming scenarios where multiple +-- messages may be buffered together. The remainder can be +-- passed to the next call to 'decrypt_frame'. +-- +-- Wire format consumed: +-- encrypted_length (18) || encrypted_body (len + 16) +-- +-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) +-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) +-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) +-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1 +-- >>> let Right (msg3, i_result) = act3 i_hs msg2 +-- >>> let Right r_result = finalize r_hs msg3 +-- >>> let Right (ct, _) = encrypt (session i_result) "hello" +-- >>> case decrypt_frame (session r_result) ct of { Right (pt, rem, _) -> (pt, BS.null rem); Left _ -> ("fail", False) } +-- ("hello",True) +decrypt_frame + :: Session + -> BS.ByteString + -> Either Error + (BS.ByteString, BS.ByteString, Session) +decrypt_frame sess packet = do + require (BS.length packet >= 34) InvalidLength + let !lc = BS.take 18 packet + !rest = BS.drop 18 packet + !rk = unKey32 (sess_rk sess) + !rn = unSessionNonce (sess_rn sess) + !rck = unKey32 (sess_rck sess) + len_bytes <- note InvalidMAC + (decrypt_with_ad rk rn BS.empty lc) + len <- note InvalidLength (decode_be16 len_bytes) + let !(rn1, rck1, rk1) = step_nonce rn rck rk + !body_len = fi len + 16 + require (BS.length rest >= body_len) InvalidLength + let !bc = BS.take body_len rest + !remainder = BS.drop body_len rest + pt <- note InvalidMAC + (decrypt_with_ad rk1 rn1 BS.empty bc) + let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1 + !sess' = sess { + sess_rk = Key32 rk2 + , sess_rn = SessionNonce rn2 + , sess_rck = Key32 rck2 + } + pure (pt, remainder, sess') + +-- | Decrypt a frame from a partial buffer, indicating when +-- more data needed. +-- +-- Unlike 'decrypt_frame', this function handles incomplete +-- buffers gracefully by returning 'NeedMore' with the +-- number of additional bytes required to make progress. +-- +-- * If the buffer has fewer than 18 bytes (encrypted +-- length + MAC), returns @'NeedMore' n@ where @n@ is +-- the bytes still needed. +-- * If the length header is complete but the body is +-- incomplete, returns @'NeedMore' n@ with bytes needed +-- for the full frame. +-- * MAC or decryption failures return 'FrameError'. +-- * A complete, valid frame returns 'FrameOk' with +-- plaintext, remainder, and updated session. +-- +-- This is useful for non-blocking I/O where data arrives +-- incrementally. +decrypt_frame_partial + :: Session + -> BS.ByteString + -> FrameResult +decrypt_frame_partial sess buf + | buflen < 18 = NeedMore (18 - buflen) + | otherwise = + let !lc = BS.take 18 buf + !rest = BS.drop 18 buf + !rk = unKey32 (sess_rk sess) + !rn = unSessionNonce (sess_rn sess) + !rck = unKey32 (sess_rck sess) + in case decrypt_with_ad rk rn BS.empty lc of + Nothing -> FrameError InvalidMAC + Just len_bytes -> + case decode_be16 len_bytes of + Nothing -> FrameError InvalidLength + Just len -> + let !body_len = fi len + 16 + !(rn1, rck1, rk1) = + step_nonce rn rck rk + in if BS.length rest < body_len + then NeedMore + (body_len - BS.length rest) + else + let !bc = BS.take body_len rest + !remainder = + BS.drop body_len rest + in case decrypt_with_ad + rk1 rn1 BS.empty bc of + Nothing -> + FrameError InvalidMAC + Just pt -> + let !(rn2, rck2, rk2) = + step_nonce rn1 rck1 rk1 + !sess' = sess { + sess_rk = Key32 rk2 + , sess_rn = + SessionNonce rn2 + , sess_rck = Key32 rck2 + } + in FrameOk pt remainder sess' + where + !buflen = BS.length buf + +-- key rotation ---------------------------------------------------- + +-- Key rotation occurs after nonce reaches 1000 (i.e., before +-- using 1000) +-- (ck', k') = HKDF(ck, k), reset nonce to 0 +step_nonce + :: Word64 + -> BS.ByteString + -> BS.ByteString + -> (Word64, BS.ByteString, BS.ByteString) +step_nonce n ck k + | n + 1 == 1000 = + let !(ck', k') = mix_key ck k + in (0, ck', k') + | otherwise = (n + 1, ck, k) + +-- utilities ------------------------------------------------------- + +-- Lift Maybe to Either +note :: e -> Maybe a -> Either e a +note e = maybe (Left e) Right +{-# INLINE note #-} + +-- Require condition or fail +require :: Bool -> e -> Either e () +require cond e = unless cond (Left e) +{-# INLINE require #-} + +fi :: (Integral a, Num b) => a -> b +fi = fromIntegral +{-# INLINE fi #-} diff --git a/ppad-bolt8.cabal b/ppad-bolt8.cabal @@ -25,6 +25,7 @@ library -Wall exposed-modules: Lightning.Protocol.BOLT8 + Lightning.Protocol.BOLT8.Internal build-depends: base >= 4.9 && < 5 , bytestring >= 0.9 && < 0.13