commit c3040f68303ff39e02a8b61074f02074662315d6
parent c47ceec3852147cc3d10f1ca586ad37e8e4aace2
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 25 Jan 2026 09:32:09 +0400
(IMPL1): Packet framing for decrypt
Add decrypt_frame that returns the remainder after consuming one frame,
enabling streaming decryption. Update decrypt to be strict, rejecting
trailing bytes.
Changes:
- decrypt_frame :: Session -> ByteString
-> Either Error (ByteString, ByteString, Session)
- decrypt now wraps decrypt_frame and rejects non-empty remainder
- Added tests for framing behavior
Diffstat:
2 files changed, 158 insertions(+), 4 deletions(-)
diff --git a/lib/Lightning/Protocol/BOLT8.hs b/lib/Lightning/Protocol/BOLT8.hs
@@ -81,6 +81,7 @@ module Lightning.Protocol.BOLT8 (
, Handshake(..)
, encrypt
, decrypt
+ , decrypt_frame
-- * Errors
, Error(..)
@@ -532,10 +533,14 @@ encrypt sess pt = do
}
pure (packet, sess')
--- | Decrypt a message.
+-- | Decrypt a message, requiring an exact packet with no trailing bytes.
--
-- Returns the plaintext and updated session. Key rotation
--- is handled automatically every 500 messages.
+-- is handled automatically every 1000 messages.
+--
+-- This is a strict variant that rejects any trailing data. For
+-- streaming use cases where you need to handle multiple frames in a
+-- buffer, use 'decrypt_frame' instead.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
@@ -548,9 +553,38 @@ encrypt sess pt = do
-- "hello"
decrypt
:: Session
- -> BS.ByteString -- ^ encrypted packet
+ -> BS.ByteString -- ^ encrypted packet (exact length required)
-> Either Error (BS.ByteString, Session)
decrypt sess packet = do
+ (pt, remainder, sess') <- decrypt_frame sess packet
+ require (BS.null remainder) InvalidLength
+ pure (pt, sess')
+
+-- | Decrypt a single frame from a buffer, returning the remainder.
+--
+-- Returns the plaintext, any unconsumed bytes, and the updated session.
+-- Key rotation is handled automatically every 1000 messages.
+--
+-- This is useful for streaming scenarios where multiple messages may
+-- be buffered together. The remainder can be passed to the next call
+-- to 'decrypt_frame'.
+--
+-- Wire format consumed: encrypted_length (18) || encrypted_body (len + 16)
+--
+-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
+-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
+-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
+-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
+-- >>> let Right (msg3, i_result) = act3 i_hs msg2
+-- >>> let Right r_result = finalize r_hs msg3
+-- >>> let Right (ct, _) = encrypt (session i_result) "hello"
+-- >>> case decrypt_frame (session r_result) ct of { Right (pt, rem, _) -> (pt, BS.null rem); Left _ -> ("fail", False) }
+-- ("hello",True)
+decrypt_frame
+ :: Session
+ -> BS.ByteString -- ^ buffer containing at least one encrypted frame
+ -> Either Error (BS.ByteString, BS.ByteString, Session)
+decrypt_frame sess packet = do
require (BS.length packet >= 34) InvalidLength
let !lc = BS.take 18 packet
!rest = BS.drop 18 packet
@@ -561,6 +595,7 @@ decrypt sess packet = do
!body_len = fi len + 16
require (BS.length rest >= body_len) InvalidLength
let !bc = BS.take body_len rest
+ !remainder = BS.drop body_len rest
pt <- note InvalidMAC (decrypt_with_ad rk1 rn1 BS.empty bc)
let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
!sess' = sess {
@@ -568,7 +603,7 @@ decrypt sess packet = do
, sess_rn = rn2
, sess_rck = rck2
}
- pure (pt, sess')
+ pure (pt, remainder, sess')
-- key rotation --------------------------------------------------------------
diff --git a/test/Main.hs b/test/Main.hs
@@ -12,6 +12,7 @@ main :: IO ()
main = defaultMain $ testGroup "ppad-bolt8" [
handshake_tests
, message_tests
+ , framing_tests
]
-- test vectors from BOLT #8 specification -----------------------------------
@@ -268,6 +269,124 @@ test_decrypt_roundtrip = do
assertFailure $ "decrypt failed: " ++ show err
Right (pt, _) -> pt @?= hello
+-- framing tests -------------------------------------------------------------
+
+framing_tests :: TestTree
+framing_tests = testGroup "Packet Framing" [
+ testCase "decrypt rejects trailing bytes" test_decrypt_trailing
+ , testCase "decrypt_frame returns remainder" test_decrypt_frame_remainder
+ , testCase "decrypt_frame handles multiple frames" test_decrypt_frame_multi
+ ]
+
+test_decrypt_trailing :: Assertion
+test_decrypt_trailing = do
+ let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
+ Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
+ Just rs = BOLT8.parse_pub responder_s_pub
+
+ case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
+ Left err -> assertFailure $ "act1 failed: " ++ show err
+ Right (msg1, i_hs) ->
+ case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
+ Left err -> assertFailure $ "act2 failed: " ++ show err
+ Right (msg2, r_hs) ->
+ case BOLT8.act3 i_hs msg2 of
+ Left err -> assertFailure $ "act3 failed: " ++ show err
+ Right (msg3, i_result) ->
+ case BOLT8.finalize r_hs msg3 of
+ Left err -> assertFailure $ "finalize failed: " ++ show err
+ Right r_result -> do
+ let i_sess = BOLT8.session i_result
+ r_sess = BOLT8.session r_result
+ case BOLT8.encrypt i_sess hello of
+ Left err -> assertFailure $ "encrypt failed: " ++ show err
+ Right (ct, _) -> do
+ -- append trailing bytes
+ let ct_with_trailing = ct <> "extra"
+ case BOLT8.decrypt r_sess ct_with_trailing of
+ Left BOLT8.InvalidLength -> pure ()
+ Left err ->
+ assertFailure $ "expected InvalidLength, got: "
+ ++ show err
+ Right _ ->
+ assertFailure "decrypt should reject trailing bytes"
+
+test_decrypt_frame_remainder :: Assertion
+test_decrypt_frame_remainder = do
+ let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
+ Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
+ Just rs = BOLT8.parse_pub responder_s_pub
+
+ case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
+ Left err -> assertFailure $ "act1 failed: " ++ show err
+ Right (msg1, i_hs) ->
+ case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
+ Left err -> assertFailure $ "act2 failed: " ++ show err
+ Right (msg2, r_hs) ->
+ case BOLT8.act3 i_hs msg2 of
+ Left err -> assertFailure $ "act3 failed: " ++ show err
+ Right (msg3, i_result) ->
+ case BOLT8.finalize r_hs msg3 of
+ Left err -> assertFailure $ "finalize failed: " ++ show err
+ Right r_result -> do
+ let i_sess = BOLT8.session i_result
+ r_sess = BOLT8.session r_result
+ case BOLT8.encrypt i_sess hello of
+ Left err -> assertFailure $ "encrypt failed: " ++ show err
+ Right (ct, _) -> do
+ let trailing = "remainder"
+ ct_with_trailing = ct <> trailing
+ case BOLT8.decrypt_frame r_sess ct_with_trailing of
+ Left err ->
+ assertFailure $ "decrypt_frame failed: " ++ show err
+ Right (pt, remainder, _) -> do
+ pt @?= hello
+ remainder @?= trailing
+
+test_decrypt_frame_multi :: Assertion
+test_decrypt_frame_multi = do
+ let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
+ Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
+ Just rs = BOLT8.parse_pub responder_s_pub
+
+ case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
+ Left err -> assertFailure $ "act1 failed: " ++ show err
+ Right (msg1, i_hs) ->
+ case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
+ Left err -> assertFailure $ "act2 failed: " ++ show err
+ Right (msg2, r_hs) ->
+ case BOLT8.act3 i_hs msg2 of
+ Left err -> assertFailure $ "act3 failed: " ++ show err
+ Right (msg3, i_result) ->
+ case BOLT8.finalize r_hs msg3 of
+ Left err -> assertFailure $ "finalize failed: " ++ show err
+ Right r_result -> do
+ let i_sess = BOLT8.session i_result
+ r_sess = BOLT8.session r_result
+ -- encrypt two messages
+ case BOLT8.encrypt i_sess "first" of
+ Left err -> assertFailure $ "encrypt 1 failed: " ++ show err
+ Right (ct1, i_sess') ->
+ case BOLT8.encrypt i_sess' "second" of
+ Left err ->
+ assertFailure $ "encrypt 2 failed: " ++ show err
+ Right (ct2, _) -> do
+ -- concatenate frames
+ let buffer = ct1 <> ct2
+ -- decrypt first frame
+ case BOLT8.decrypt_frame r_sess buffer of
+ Left err ->
+ assertFailure $ "frame 1 failed: " ++ show err
+ Right (pt1, rest, r_sess') -> do
+ pt1 @?= "first"
+ -- decrypt second frame from remainder
+ case BOLT8.decrypt_frame r_sess' rest of
+ Left err ->
+ assertFailure $ "frame 2 failed: " ++ show err
+ Right (pt2, rest2, _) -> do
+ pt2 @?= "second"
+ rest2 @?= BS.empty
+
-- utilities -----------------------------------------------------------------
hex :: BS.ByteString -> BS.ByteString