commit 29d72235097657273919adcf73f4030d6d7a9b7d
parent a82df07d3cf43cb46d609887b555e3cc587ebb58
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 11 Jan 2026 11:29:01 +0400
lib: minor refactoring
Diffstat:
1 file changed, 77 insertions(+), 224 deletions(-)
diff --git a/lib/Lightning/Protocol/BOLT8.hs b/lib/Lightning/Protocol/BOLT8.hs
@@ -43,11 +43,12 @@ module Lightning.Protocol.BOLT8 (
, Error(..)
) where
-import Control.Monad (guard)
+import Control.Monad (guard, unless)
import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
import qualified Crypto.Curve.Secp256k1 as Secp256k1
import qualified Crypto.Hash.SHA256 as SHA256
import qualified Crypto.KDF.HMAC as HKDF
+import Data.Bits (unsafeShiftR, (.&.))
import qualified Data.ByteString as BS
import Data.Word (Word16, Word64)
@@ -211,19 +212,19 @@ encode_nonce n = BS.replicate 4 0x00 <> encode_le64 n
-- Little-endian 64-bit encoding
encode_le64 :: Word64 -> BS.ByteString
encode_le64 n = BS.pack [
- fi n
- , fi (n `div` 0x100)
- , fi (n `div` 0x10000)
- , fi (n `div` 0x1000000)
- , fi (n `div` 0x100000000)
- , fi (n `div` 0x10000000000)
- , fi (n `div` 0x1000000000000)
- , fi (n `div` 0x100000000000000)
+ fi (n .&. 0xff)
+ , fi (unsafeShiftR n 8 .&. 0xff)
+ , fi (unsafeShiftR n 16 .&. 0xff)
+ , fi (unsafeShiftR n 24 .&. 0xff)
+ , fi (unsafeShiftR n 32 .&. 0xff)
+ , fi (unsafeShiftR n 40 .&. 0xff)
+ , fi (unsafeShiftR n 48 .&. 0xff)
+ , fi (unsafeShiftR n 56 .&. 0xff)
]
-- Big-endian 16-bit encoding
encode_be16 :: Word16 -> BS.ByteString
-encode_be16 n = BS.pack [fi (n `div` 0x100), fi n]
+encode_be16 n = BS.pack [fi (unsafeShiftR n 8), fi (n .&. 0xff)]
-- Big-endian 16-bit decoding
decode_be16 :: BS.ByteString -> Maybe Word16
@@ -290,19 +291,13 @@ initiator_act1
-> BS.ByteString -- ^ 32 bytes entropy for ephemeral
-> Either Error (BS.ByteString, HandshakeState)
initiator_act1 s_sec s_pub rs ent = do
- -- Generate ephemeral keypair
- (e_sec, e_pub) <- maybe (Left InvalidKey) Right (keypair ent)
-
+ (e_sec, e_pub) <- note InvalidKey (keypair ent)
let !hs0 = init_handshake s_sec s_pub e_sec e_pub (Just rs) True
!e_pub_bytes = serialize_pub e_pub
!h1 = mix_hash (hs_h hs0) e_pub_bytes
-
- es <- maybe (Left InvalidKey) Right (ecdh e_sec rs)
-
+ es <- note InvalidKey (ecdh e_sec rs)
let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
-
- c <- maybe (Left InvalidMAC) Right (encrypt_with_ad temp_k1 0 h1 BS.empty)
-
+ c <- note InvalidMAC (encrypt_with_ad temp_k1 0 h1 BS.empty)
let !h2 = mix_hash h1 c
!msg = BS.singleton 0x00 <> e_pub_bytes <> c
!hs1 = hs0 {
@@ -310,8 +305,7 @@ initiator_act1 s_sec s_pub rs ent = do
, hs_ck = ck1
, hs_temp_k = temp_k1
}
-
- Right (msg, hs1)
+ pure (msg, hs1)
-- | Responder: process Act 1 and generate Act 2 message (50 bytes).
--
@@ -332,73 +326,33 @@ responder_act2
-> BS.ByteString -- ^ Act 1 message (50 bytes)
-> Either Error (BS.ByteString, HandshakeState)
responder_act2 s_sec s_pub ent act1 = do
- -- Validate length
- if BS.length act1 /= 50
- then Left InvalidLength
- else pure ()
-
- -- Parse Act 1: version || e.pub || c
+ require (BS.length act1 == 50) InvalidLength
let !version = BS.index act1 0
!re_bytes = BS.take 33 (BS.drop 1 act1)
!c = BS.drop 34 act1
-
- -- Validate version
- if version /= 0x00
- then Left InvalidVersion
- else pure ()
-
- -- Parse remote ephemeral
- re <- maybe (Left InvalidPub) Right (parse_pub re_bytes)
-
- -- Generate our ephemeral keypair
- (e_sec, e_pub) <- maybe (Left InvalidKey) Right (keypair ent)
-
- -- Initialize state (responder doesn't know remote static yet)
+ require (version == 0x00) InvalidVersion
+ re <- note InvalidPub (parse_pub re_bytes)
+ (e_sec, e_pub) <- note InvalidKey (keypair ent)
let !hs0 = init_handshake s_sec s_pub e_sec e_pub Nothing False
-
- -- h = SHA256(h || re)
- let !h1 = mix_hash (hs_h hs0) re_bytes
-
- -- es = ECDH(s.priv, re)
- es <- maybe (Left InvalidKey) Right (ecdh s_sec re)
-
- -- ck, temp_k1 = HKDF(ck, es)
+ !h1 = mix_hash (hs_h hs0) re_bytes
+ es <- note InvalidKey (ecdh s_sec re)
let !(ck1, temp_k1) = mix_key (hs_ck hs0) es
-
- -- Decrypt and verify MAC
- _ <- maybe (Left InvalidMAC) Right (decrypt_with_ad temp_k1 0 h1 c)
-
- -- h = SHA256(h || c)
+ _ <- note InvalidMAC (decrypt_with_ad temp_k1 0 h1 c)
let !h2 = mix_hash h1 c
-
- -- Now generate Act 2
- -- h = SHA256(h || e.pub)
- let !e_pub_bytes = serialize_pub e_pub
+ !e_pub_bytes = serialize_pub e_pub
!h3 = mix_hash h2 e_pub_bytes
-
- -- ee = ECDH(e.priv, re)
- ee <- maybe (Left InvalidKey) Right (ecdh e_sec re)
-
- -- ck, temp_k2 = HKDF(ck, ee)
+ ee <- note InvalidKey (ecdh e_sec re)
let !(ck2, temp_k2) = mix_key ck1 ee
-
- -- c2 = encrypt(temp_k2, 0, h, "")
- c2 <- maybe (Left InvalidMAC) Right (encrypt_with_ad temp_k2 0 h3 BS.empty)
-
- -- h = SHA256(h || c2)
+ c2 <- note InvalidMAC (encrypt_with_ad temp_k2 0 h3 BS.empty)
let !h4 = mix_hash h3 c2
-
- -- Build message: version || e.pub || c2
- let !msg = BS.singleton 0x00 <> e_pub_bytes <> c2
-
- let !hs1 = hs0 {
+ !msg = BS.singleton 0x00 <> e_pub_bytes <> c2
+ !hs1 = hs0 {
hs_h = h4
, hs_ck = ck2
, hs_temp_k = temp_k2
, hs_re = Just re
}
-
- Right (msg, hs1)
+ pure (msg, hs1)
-- | Initiator: process Act 2 and generate Act 3 (66 bytes), completing
-- the handshake.
@@ -416,64 +370,26 @@ initiator_act3
-> BS.ByteString -- ^ Act 2 message (50 bytes)
-> Either Error (BS.ByteString, HandshakeResult)
initiator_act3 hs act2 = do
- -- Validate length
- if BS.length act2 /= 50
- then Left InvalidLength
- else pure ()
-
- -- Parse Act 2: version || e.pub || c
+ require (BS.length act2 == 50) InvalidLength
let !version = BS.index act2 0
!re_bytes = BS.take 33 (BS.drop 1 act2)
!c = BS.drop 34 act2
-
- -- Validate version
- if version /= 0x00
- then Left InvalidVersion
- else pure ()
-
- -- Parse remote ephemeral
- re <- maybe (Left InvalidPub) Right (parse_pub re_bytes)
-
- -- h = SHA256(h || re)
+ require (version == 0x00) InvalidVersion
+ re <- note InvalidPub (parse_pub re_bytes)
let !h1 = mix_hash (hs_h hs) re_bytes
-
- -- ee = ECDH(e.priv, re)
- ee <- maybe (Left InvalidKey) Right (ecdh (hs_e_sec hs) re)
-
- -- ck, temp_k2 = HKDF(ck, ee)
+ ee <- note InvalidKey (ecdh (hs_e_sec hs) re)
let !(ck1, temp_k2) = mix_key (hs_ck hs) ee
-
- -- Decrypt and verify MAC
- _ <- maybe (Left InvalidMAC) Right (decrypt_with_ad temp_k2 0 h1 c)
-
- -- h = SHA256(h || c)
+ _ <- note InvalidMAC (decrypt_with_ad temp_k2 0 h1 c)
let !h2 = mix_hash h1 c
-
- -- Now generate Act 3
- -- c = encrypt(temp_k2, 1, h, s.pub)
- let !s_pub_bytes = serialize_pub (hs_s_pub hs)
- c3 <- maybe (Left InvalidMAC) Right (encrypt_with_ad temp_k2 1 h2 s_pub_bytes)
-
- -- h = SHA256(h || c)
+ !s_pub_bytes = serialize_pub (hs_s_pub hs)
+ c3 <- note InvalidMAC (encrypt_with_ad temp_k2 1 h2 s_pub_bytes)
let !h3 = mix_hash h2 c3
-
- -- se = ECDH(s.priv, re)
- se <- maybe (Left InvalidKey) Right (ecdh (hs_s_sec hs) re)
-
- -- ck, temp_k3 = HKDF(ck, se)
+ se <- note InvalidKey (ecdh (hs_s_sec hs) re)
let !(ck2, temp_k3) = mix_key ck1 se
-
- -- t = encrypt(temp_k3, 0, h, "")
- t <- maybe (Left InvalidMAC) Right (encrypt_with_ad temp_k3 0 h3 BS.empty)
-
- -- Derive session keys: sk, rk = HKDF(ck, "")
+ t <- note InvalidMAC (encrypt_with_ad temp_k3 0 h3 BS.empty)
let !(sk, rk) = mix_key ck2 BS.empty
-
- -- Build message: version || c || t
- let !msg = BS.singleton 0x00 <> c3 <> t
-
- -- Build session (initiator: sk = send, rk = receive)
- let !session = Session {
+ !msg = BS.singleton 0x00 <> c3 <> t
+ !session = Session {
sess_sk = sk
, sess_sn = 0
, sess_sck = ck2
@@ -481,16 +397,12 @@ initiator_act3 hs act2 = do
, sess_rn = 0
, sess_rck = ck2
}
-
- -- Get remote static from handshake state (we knew it from the start)
- rs <- maybe (Left InvalidPub) Right (hs_rs hs)
-
+ rs <- note InvalidPub (hs_rs hs)
let !result = HandshakeResult {
hr_session = session
, hr_remote_pk = rs
}
-
- Right (msg, result)
+ pure (msg, result)
-- | Responder: process Act 3 (66 bytes) and complete the handshake.
--
@@ -508,46 +420,20 @@ responder_finalize
-> BS.ByteString -- ^ Act 3 message (66 bytes)
-> Either Error HandshakeResult
responder_finalize hs act3 = do
- -- Validate length
- if BS.length act3 /= 66
- then Left InvalidLength
- else pure ()
-
- -- Parse Act 3: version || encrypted_static (49 bytes) || t (16 bytes)
+ require (BS.length act3 == 66) InvalidLength
let !version = BS.index act3 0
!c = BS.take 49 (BS.drop 1 act3)
!t = BS.drop 50 act3
-
- -- Validate version
- if version /= 0x00
- then Left InvalidVersion
- else pure ()
-
- -- Decrypt static key: rs = decrypt(temp_k2, 1, h, c)
- rs_bytes <- maybe (Left InvalidMAC) Right
- (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c)
-
- -- Parse remote static
- rs <- maybe (Left InvalidPub) Right (parse_pub rs_bytes)
-
- -- h = SHA256(h || c)
+ require (version == 0x00) InvalidVersion
+ rs_bytes <- note InvalidMAC (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c)
+ rs <- note InvalidPub (parse_pub rs_bytes)
let !h1 = mix_hash (hs_h hs) c
-
- -- se = ECDH(e.priv, rs)
- se <- maybe (Left InvalidKey) Right (ecdh (hs_e_sec hs) rs)
-
- -- ck, temp_k3 = HKDF(ck, se)
+ se <- note InvalidKey (ecdh (hs_e_sec hs) rs)
let !(ck1, temp_k3) = mix_key (hs_ck hs) se
-
- -- Decrypt and verify final MAC
- _ <- maybe (Left InvalidMAC) Right (decrypt_with_ad temp_k3 0 h1 t)
-
- -- Derive session keys: rk, sk = HKDF(ck, "")
- -- Note: responder swaps order (receives what initiator sends)
+ _ <- note InvalidMAC (decrypt_with_ad temp_k3 0 h1 t)
+ -- responder swaps order (receives what initiator sends)
let !(rk, sk) = mix_key ck1 BS.empty
-
- -- Build session (responder: sk = send, rk = receive)
- let !session = Session {
+ !session = Session {
sess_sk = sk
, sess_sn = 0
, sess_sck = ck1
@@ -555,13 +441,11 @@ responder_finalize hs act3 = do
, sess_rn = 0
, sess_rck = ck1
}
-
- let !result = HandshakeResult {
+ !result = HandshakeResult {
hr_session = session
, hr_remote_pk = rs
}
-
- Right result
+ pure result
-- message encryption --------------------------------------------------------
@@ -585,38 +469,21 @@ encrypt_message
-> BS.ByteString -- ^ plaintext (max 65535 bytes)
-> Either Error (BS.ByteString, Session)
encrypt_message sess pt = do
- -- Validate length
let !len = BS.length pt
- if len > 65535
- then Left InvalidLength
- else pure ()
-
- -- Encrypt length (2-byte big-endian)
+ require (len <= 65535) InvalidLength
let !len_bytes = encode_be16 (fi len)
- lc <- maybe (Left InvalidMAC) Right
- (encrypt_with_ad (sess_sk sess) (sess_sn sess) BS.empty len_bytes)
-
- -- Step nonce (possibly rotate)
+ lc <- note InvalidMAC (encrypt_with_ad (sess_sk sess) (sess_sn sess)
+ BS.empty len_bytes)
let !(sn1, sck1, sk1) = step_nonce (sess_sn sess) (sess_sck sess) (sess_sk sess)
-
- -- Encrypt body
- bc <- maybe (Left InvalidMAC) Right
- (encrypt_with_ad sk1 sn1 BS.empty pt)
-
- -- Step nonce again (possibly rotate)
+ bc <- note InvalidMAC (encrypt_with_ad sk1 sn1 BS.empty pt)
let !(sn2, sck2, sk2) = step_nonce sn1 sck1 sk1
-
- -- Build packet
- let !packet = lc <> bc
-
- -- Update session
- let !sess' = sess {
+ !packet = lc <> bc
+ !sess' = sess {
sess_sk = sk2
, sess_sn = sn2
, sess_sck = sck2
}
-
- Right (packet, sess')
+ pure (packet, sess')
-- | Decrypt a message.
--
@@ -637,48 +504,24 @@ decrypt_message
-> BS.ByteString -- ^ encrypted packet
-> Either Error (BS.ByteString, Session)
decrypt_message sess packet = do
- -- Need at least length ciphertext (18 bytes) + body MAC (16 bytes)
- if BS.length packet < 34
- then Left InvalidLength
- else pure ()
-
- -- Split length ciphertext
+ require (BS.length packet >= 34) InvalidLength
let !lc = BS.take 18 packet
!rest = BS.drop 18 packet
-
- -- Decrypt length
- len_bytes <- maybe (Left InvalidMAC) Right
- (decrypt_with_ad (sess_rk sess) (sess_rn sess) BS.empty lc)
-
- len <- maybe (Left InvalidLength) Right (decode_be16 len_bytes)
-
- -- Step nonce (possibly rotate)
+ len_bytes <- note InvalidMAC (decrypt_with_ad (sess_rk sess) (sess_rn sess)
+ BS.empty lc)
+ len <- note InvalidLength (decode_be16 len_bytes)
let !(rn1, rck1, rk1) = step_nonce (sess_rn sess) (sess_rck sess) (sess_rk sess)
-
- -- Validate we have enough data for body
- let !body_len = fi len + 16
- if BS.length rest < body_len
- then Left InvalidLength
- else pure ()
-
- -- Split body ciphertext
+ !body_len = fi len + 16
+ require (BS.length rest >= body_len) InvalidLength
let !bc = BS.take body_len rest
-
- -- Decrypt body
- pt <- maybe (Left InvalidMAC) Right
- (decrypt_with_ad rk1 rn1 BS.empty bc)
-
- -- Step nonce again (possibly rotate)
+ pt <- note InvalidMAC (decrypt_with_ad rk1 rn1 BS.empty bc)
let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
-
- -- Update session
- let !sess' = sess {
+ !sess' = sess {
sess_rk = rk2
, sess_rn = rn2
, sess_rck = rck2
}
-
- Right (pt, sess')
+ pure (pt, sess')
-- key rotation --------------------------------------------------------------
@@ -697,6 +540,16 @@ step_nonce n ck k
-- utilities -----------------------------------------------------------------
+-- Lift Maybe to Either
+note :: e -> Maybe a -> Either e a
+note e = maybe (Left e) Right
+{-# INLINE note #-}
+
+-- Require condition or fail
+require :: Bool -> e -> Either e ()
+require cond e = unless cond (Left e)
+{-# INLINE require #-}
+
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}