bolt8

Encrypted and authenticated transport, per BOLT #8.
git clone git://git.ppad.tech/bolt8.git
Log | Files | Refs | README | LICENSE

commit b35594af9fe488daffa823331c81339e4f5c9a93
parent c47ceec3852147cc3d10f1ca586ad37e8e4aace2
Author: Jared Tobin <jared@jtobin.io>
Date:   Sun, 25 Jan 2026 09:31:34 +0400

lib: add decrypt_frame, make decrypt strict

- decrypt_frame decrypts a single frame and returns the remainder,
  useful for streaming scenarios with buffered data
- decrypt now rejects trailing bytes (wraps decrypt_frame)
- added tests for framing behavior

Diffstat:
Mlib/Lightning/Protocol/BOLT8.hs | 43+++++++++++++++++++++++++++++++++++++++----
Mtest/Main.hs | 119+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 158 insertions(+), 4 deletions(-)

diff --git a/lib/Lightning/Protocol/BOLT8.hs b/lib/Lightning/Protocol/BOLT8.hs @@ -81,6 +81,7 @@ module Lightning.Protocol.BOLT8 ( , Handshake(..) , encrypt , decrypt + , decrypt_frame -- * Errors , Error(..) @@ -532,10 +533,14 @@ encrypt sess pt = do } pure (packet, sess') --- | Decrypt a message. +-- | Decrypt a message, requiring an exact packet with no trailing bytes. -- -- Returns the plaintext and updated session. Key rotation --- is handled automatically every 500 messages. +-- is handled automatically every 1000 messages. +-- +-- This is a strict variant that rejects any trailing data. For +-- streaming use cases where you need to handle multiple frames in a +-- buffer, use 'decrypt_frame' instead. -- -- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) -- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) @@ -548,9 +553,38 @@ encrypt sess pt = do -- "hello" decrypt :: Session - -> BS.ByteString -- ^ encrypted packet + -> BS.ByteString -- ^ encrypted packet (exact length required) -> Either Error (BS.ByteString, Session) decrypt sess packet = do + (pt, remainder, sess') <- decrypt_frame sess packet + require (BS.null remainder) InvalidLength + pure (pt, sess') + +-- | Decrypt a single frame from a buffer, returning the remainder. +-- +-- Returns the plaintext, any unconsumed bytes, and the updated session. +-- Key rotation is handled automatically every 1000 messages. +-- +-- This is useful for streaming scenarios where multiple messages may +-- be buffered together. The remainder can be passed to the next call +-- to 'decrypt_frame'. +-- +-- Wire format consumed: encrypted_length (18) || encrypted_body (len + 16) +-- +-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11) +-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21) +-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12) +-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1 +-- >>> let Right (msg3, i_result) = act3 i_hs msg2 +-- >>> let Right r_result = finalize r_hs msg3 +-- >>> let Right (ct, _) = encrypt (session i_result) "hello" +-- >>> case decrypt_frame (session r_result) ct of { Right (pt, rem, _) -> (pt, BS.null rem); Left _ -> ("fail", False) } +-- ("hello",True) +decrypt_frame + :: Session + -> BS.ByteString -- ^ buffer containing at least one encrypted frame + -> Either Error (BS.ByteString, BS.ByteString, Session) +decrypt_frame sess packet = do require (BS.length packet >= 34) InvalidLength let !lc = BS.take 18 packet !rest = BS.drop 18 packet @@ -561,6 +595,7 @@ decrypt sess packet = do !body_len = fi len + 16 require (BS.length rest >= body_len) InvalidLength let !bc = BS.take body_len rest + !remainder = BS.drop body_len rest pt <- note InvalidMAC (decrypt_with_ad rk1 rn1 BS.empty bc) let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1 !sess' = sess { @@ -568,7 +603,7 @@ decrypt sess packet = do , sess_rn = rn2 , sess_rck = rck2 } - pure (pt, sess') + pure (pt, remainder, sess') -- key rotation -------------------------------------------------------------- diff --git a/test/Main.hs b/test/Main.hs @@ -12,6 +12,7 @@ main :: IO () main = defaultMain $ testGroup "ppad-bolt8" [ handshake_tests , message_tests + , framing_tests ] -- test vectors from BOLT #8 specification ----------------------------------- @@ -268,6 +269,124 @@ test_decrypt_roundtrip = do assertFailure $ "decrypt failed: " ++ show err Right (pt, _) -> pt @?= hello +-- framing tests ------------------------------------------------------------- + +framing_tests :: TestTree +framing_tests = testGroup "Packet Framing" [ + testCase "decrypt rejects trailing bytes" test_decrypt_trailing + , testCase "decrypt_frame returns remainder" test_decrypt_frame_remainder + , testCase "decrypt_frame handles multiple frames" test_decrypt_frame_multi + ] + +test_decrypt_trailing :: Assertion +test_decrypt_trailing = do + let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv + Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv + Just rs = BOLT8.parse_pub responder_s_pub + + case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of + Left err -> assertFailure $ "act1 failed: " ++ show err + Right (msg1, i_hs) -> + case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of + Left err -> assertFailure $ "act2 failed: " ++ show err + Right (msg2, r_hs) -> + case BOLT8.act3 i_hs msg2 of + Left err -> assertFailure $ "act3 failed: " ++ show err + Right (msg3, i_result) -> + case BOLT8.finalize r_hs msg3 of + Left err -> assertFailure $ "finalize failed: " ++ show err + Right r_result -> do + let i_sess = BOLT8.session i_result + r_sess = BOLT8.session r_result + case BOLT8.encrypt i_sess hello of + Left err -> assertFailure $ "encrypt failed: " ++ show err + Right (ct, _) -> do + -- append trailing bytes + let ct_with_trailing = ct <> "extra" + case BOLT8.decrypt r_sess ct_with_trailing of + Left BOLT8.InvalidLength -> pure () + Left err -> + assertFailure $ "expected InvalidLength, got: " + ++ show err + Right _ -> + assertFailure "decrypt should reject trailing bytes" + +test_decrypt_frame_remainder :: Assertion +test_decrypt_frame_remainder = do + let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv + Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv + Just rs = BOLT8.parse_pub responder_s_pub + + case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of + Left err -> assertFailure $ "act1 failed: " ++ show err + Right (msg1, i_hs) -> + case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of + Left err -> assertFailure $ "act2 failed: " ++ show err + Right (msg2, r_hs) -> + case BOLT8.act3 i_hs msg2 of + Left err -> assertFailure $ "act3 failed: " ++ show err + Right (msg3, i_result) -> + case BOLT8.finalize r_hs msg3 of + Left err -> assertFailure $ "finalize failed: " ++ show err + Right r_result -> do + let i_sess = BOLT8.session i_result + r_sess = BOLT8.session r_result + case BOLT8.encrypt i_sess hello of + Left err -> assertFailure $ "encrypt failed: " ++ show err + Right (ct, _) -> do + let trailing = "remainder" + ct_with_trailing = ct <> trailing + case BOLT8.decrypt_frame r_sess ct_with_trailing of + Left err -> + assertFailure $ "decrypt_frame failed: " ++ show err + Right (pt, remainder, _) -> do + pt @?= hello + remainder @?= trailing + +test_decrypt_frame_multi :: Assertion +test_decrypt_frame_multi = do + let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv + Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv + Just rs = BOLT8.parse_pub responder_s_pub + + case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of + Left err -> assertFailure $ "act1 failed: " ++ show err + Right (msg1, i_hs) -> + case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of + Left err -> assertFailure $ "act2 failed: " ++ show err + Right (msg2, r_hs) -> + case BOLT8.act3 i_hs msg2 of + Left err -> assertFailure $ "act3 failed: " ++ show err + Right (msg3, i_result) -> + case BOLT8.finalize r_hs msg3 of + Left err -> assertFailure $ "finalize failed: " ++ show err + Right r_result -> do + let i_sess = BOLT8.session i_result + r_sess = BOLT8.session r_result + -- encrypt two messages + case BOLT8.encrypt i_sess "first" of + Left err -> assertFailure $ "encrypt 1 failed: " ++ show err + Right (ct1, i_sess') -> + case BOLT8.encrypt i_sess' "second" of + Left err -> + assertFailure $ "encrypt 2 failed: " ++ show err + Right (ct2, _) -> do + -- concatenate frames + let buffer = ct1 <> ct2 + -- decrypt first frame + case BOLT8.decrypt_frame r_sess buffer of + Left err -> + assertFailure $ "frame 1 failed: " ++ show err + Right (pt1, rest, r_sess') -> do + pt1 @?= "first" + -- decrypt second frame from remainder + case BOLT8.decrypt_frame r_sess' rest of + Left err -> + assertFailure $ "frame 2 failed: " ++ show err + Right (pt2, rest2, _) -> do + pt2 @?= "second" + rest2 @?= BS.empty + -- utilities ----------------------------------------------------------------- hex :: BS.ByteString -> BS.ByteString