commit c890b7a7ca091a1a684253480c940a485b7af8e9
parent 73edeb4503c449cfdafd7a0817ce686b297960a8
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 25 Jan 2026 09:44:20 +0400
(IMPL4): Recoverable partial framing
Add FrameResult ADT and decrypt_frame_partial for non-blocking I/O:
- NeedMore indicates how many additional bytes are required
- FrameOk returns plaintext, remainder, and updated session
- FrameError wraps decryption errors
This enables incremental buffer processing without exceptions,
useful for async/event-driven network code.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
2 files changed, 166 insertions(+), 0 deletions(-)
diff --git a/lib/Lightning/Protocol/BOLT8.hs b/lib/Lightning/Protocol/BOLT8.hs
@@ -82,6 +82,8 @@ module Lightning.Protocol.BOLT8 (
, encrypt
, decrypt
, decrypt_frame
+ , decrypt_frame_partial
+ , FrameResult(..)
-- * Errors
, Error(..)
@@ -123,6 +125,16 @@ data Error =
| DecryptionFailed
deriving (Eq, Show, Generic)
+-- | Result of attempting to decrypt a frame from a partial buffer.
+data FrameResult =
+ NeedMore {-# UNPACK #-} !Int
+ -- ^ More bytes needed; the 'Int' is the minimum additional bytes required.
+ | FrameOk !BS.ByteString !BS.ByteString !Session
+ -- ^ Successfully decrypted: plaintext, remainder, updated session.
+ | FrameError !Error
+ -- ^ Decryption failed with the given error.
+ deriving Generic
+
-- | Post-handshake session state.
data Session = Session {
sess_sk :: {-# UNPACK #-} !BS.ByteString -- ^ send key (32 bytes)
@@ -608,6 +620,56 @@ decrypt_frame sess packet = do
}
pure (pt, remainder, sess')
+-- | Decrypt a frame from a partial buffer, indicating when more data needed.
+--
+-- Unlike 'decrypt_frame', this function handles incomplete buffers
+-- gracefully by returning 'NeedMore' with the number of additional
+-- bytes required to make progress.
+--
+-- * If the buffer has fewer than 18 bytes (encrypted length + MAC),
+-- returns @'NeedMore' n@ where @n@ is the bytes still needed.
+-- * If the length header is complete but the body is incomplete,
+-- returns @'NeedMore' n@ with bytes needed for the full frame.
+-- * MAC or decryption failures return 'FrameError'.
+-- * A complete, valid frame returns 'FrameOk' with plaintext,
+-- remainder, and updated session.
+--
+-- This is useful for non-blocking I/O where data arrives incrementally.
+decrypt_frame_partial
+ :: Session
+ -> BS.ByteString -- ^ buffer (possibly incomplete)
+ -> FrameResult
+decrypt_frame_partial sess buf
+ | buflen < 18 = NeedMore (18 - buflen)
+ | otherwise =
+ let !lc = BS.take 18 buf
+ !rest = BS.drop 18 buf
+ in case decrypt_with_ad (sess_rk sess) (sess_rn sess) BS.empty lc of
+ Nothing -> FrameError InvalidMAC
+ Just len_bytes -> case decode_be16 len_bytes of
+ Nothing -> FrameError InvalidLength
+ Just len ->
+ let !body_len = fi len + 16
+ !(rn1, rck1, rk1) = step_nonce (sess_rn sess)
+ (sess_rck sess) (sess_rk sess)
+ in if BS.length rest < body_len
+ then NeedMore (body_len - BS.length rest)
+ else
+ let !bc = BS.take body_len rest
+ !remainder = BS.drop body_len rest
+ in case decrypt_with_ad rk1 rn1 BS.empty bc of
+ Nothing -> FrameError InvalidMAC
+ Just pt ->
+ let !(rn2, rck2, rk2) = step_nonce rn1 rck1 rk1
+ !sess' = sess {
+ sess_rk = rk2
+ , sess_rn = rn2
+ , sess_rck = rck2
+ }
+ in FrameOk pt remainder sess'
+ where
+ !buflen = BS.length buf
+
-- key rotation --------------------------------------------------------------
-- Key rotation occurs after nonce reaches 1000 (i.e., before using 1000)
diff --git a/test/Main.hs b/test/Main.hs
@@ -14,6 +14,7 @@ main = defaultMain $ testGroup "ppad-bolt8" [
handshake_tests
, message_tests
, framing_tests
+ , partial_framing_tests
, negative_tests
]
@@ -389,6 +390,109 @@ test_decrypt_frame_multi = do
pt2 @?= "second"
rest2 @?= BS.empty
+-- partial framing tests -----------------------------------------------------
+
+partial_framing_tests :: TestTree
+partial_framing_tests = testGroup "Partial Framing" [
+ testCase "short buffer returns NeedMore" test_partial_short_buffer
+ , testCase "partial body returns NeedMore" test_partial_body
+ , testCase "full frame returns FrameOk" test_partial_full_frame
+ ]
+
+test_partial_short_buffer :: Assertion
+test_partial_short_buffer = do
+ let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
+ Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
+ Just rs = BOLT8.parse_pub responder_s_pub
+ case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
+ Left err -> assertFailure $ "act1 failed: " ++ show err
+ Right (msg1, i_hs) ->
+ case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
+ Left err -> assertFailure $ "act2 failed: " ++ show err
+ Right (msg2, r_hs) ->
+ case BOLT8.act3 i_hs msg2 of
+ Left err -> assertFailure $ "act3 failed: " ++ show err
+ Right (msg3, _) ->
+ case BOLT8.finalize r_hs msg3 of
+ Left err -> assertFailure $ "finalize failed: " ++ show err
+ Right r_result -> do
+ let r_sess = BOLT8.session r_result
+ short_buf = BS.replicate 10 0x00
+ case BOLT8.decrypt_frame_partial r_sess short_buf of
+ BOLT8.NeedMore n -> n @?= 8
+ BOLT8.FrameOk {} ->
+ assertFailure "expected NeedMore, got FrameOk"
+ BOLT8.FrameError err ->
+ assertFailure $ "expected NeedMore, got: " ++ show err
+
+test_partial_body :: Assertion
+test_partial_body = do
+ let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
+ Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
+ Just rs = BOLT8.parse_pub responder_s_pub
+ case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
+ Left err -> assertFailure $ "act1 failed: " ++ show err
+ Right (msg1, i_hs) ->
+ case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
+ Left err -> assertFailure $ "act2 failed: " ++ show err
+ Right (msg2, r_hs) ->
+ case BOLT8.act3 i_hs msg2 of
+ Left err -> assertFailure $ "act3 failed: " ++ show err
+ Right (msg3, i_result) ->
+ case BOLT8.finalize r_hs msg3 of
+ Left err -> assertFailure $ "finalize failed: " ++ show err
+ Right r_result -> do
+ let i_sess = BOLT8.session i_result
+ r_sess = BOLT8.session r_result
+ case BOLT8.encrypt i_sess hello of
+ Left err -> assertFailure $ "encrypt failed: " ++ show err
+ Right (ct, _) -> do
+ -- take only length header (18 bytes) + 5 bytes of body
+ let partial = BS.take 23 ct
+ case BOLT8.decrypt_frame_partial r_sess partial of
+ BOLT8.NeedMore n -> do
+ -- "hello" = 5 bytes, so body = 5 + 16 = 21
+ -- we have 5 bytes of body, need 16 more
+ n @?= 16
+ BOLT8.FrameOk {} ->
+ assertFailure "expected NeedMore, got FrameOk"
+ BOLT8.FrameError err ->
+ assertFailure $ "expected NeedMore, got: " ++ show err
+
+test_partial_full_frame :: Assertion
+test_partial_full_frame = do
+ let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
+ Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
+ Just rs = BOLT8.parse_pub responder_s_pub
+ case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
+ Left err -> assertFailure $ "act1 failed: " ++ show err
+ Right (msg1, i_hs) ->
+ case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
+ Left err -> assertFailure $ "act2 failed: " ++ show err
+ Right (msg2, r_hs) ->
+ case BOLT8.act3 i_hs msg2 of
+ Left err -> assertFailure $ "act3 failed: " ++ show err
+ Right (msg3, i_result) ->
+ case BOLT8.finalize r_hs msg3 of
+ Left err -> assertFailure $ "finalize failed: " ++ show err
+ Right r_result -> do
+ let i_sess = BOLT8.session i_result
+ r_sess = BOLT8.session r_result
+ case BOLT8.encrypt i_sess hello of
+ Left err -> assertFailure $ "encrypt failed: " ++ show err
+ Right (ct, _) -> do
+ let trailing = "extra"
+ buf = ct <> trailing
+ case BOLT8.decrypt_frame_partial r_sess buf of
+ BOLT8.FrameOk pt remainder _ -> do
+ pt @?= hello
+ remainder @?= trailing
+ BOLT8.NeedMore n ->
+ assertFailure $ "expected FrameOk, got NeedMore "
+ ++ show n
+ BOLT8.FrameError err ->
+ assertFailure $ "expected FrameOk, got: " ++ show err
+
-- negative tests ------------------------------------------------------------
negative_tests :: TestTree