bolt1

Base Lightning protocol, per BOLT #1.
git clone git://git.ppad.tech/bolt1.git
Log | Files | Refs | README | LICENSE

commit 82bc4e46f21942423984099fba7b2e3bc7ebe74d
parent af22e72c4b9258f4ec5f48af5c9bc4bf5260cdfb
Author: Jared Tobin <jared@jtobin.io>
Date:   Sun, 25 Jan 2026 14:30:10 +0400

feat: add TlvStream smart constructor with ordering validation

- Add tlvStream smart constructor (validates strictly increasing types)
- Add unsafeTlvStream for internal/trusted use
- Hide TlvStream data constructor from public API
- Update internal uses in decode functions to use unsafeTlvStream
- Add tests for ordering validation

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Diffstat:
Mlib/Lightning/Protocol/BOLT1.hs | 5++++-
Mlib/Lightning/Protocol/BOLT1/Codec.hs | 6+++---
Mlib/Lightning/Protocol/BOLT1/TLV.hs | 31+++++++++++++++++++++++++++----
Mtest/Main.hs | 61++++++++++++++++++++++++++++++++++++++++++-------------------
4 files changed, 76 insertions(+), 27 deletions(-)

diff --git a/lib/Lightning/Protocol/BOLT1.hs b/lib/Lightning/Protocol/BOLT1.hs @@ -30,7 +30,10 @@ module Lightning.Protocol.BOLT1 ( -- * TLV , TlvRecord(..) - , TlvStream(..) + , TlvStream + , unTlvStream + , tlvStream + , unsafeTlvStream , TlvError(..) , encodeTlvStream , decodeTlvStream diff --git a/lib/Lightning/Protocol/BOLT1/Codec.hs b/lib/Lightning/Protocol/BOLT1/Codec.hs @@ -182,11 +182,11 @@ decodeInit !bs = do let !feat = BS.take (fromIntegral fLen) rest3 !rest4 = BS.drop (fromIntegral fLen) rest3 -- Parse optional TLV stream (consumes all remaining bytes for init) - tlvStream <- if BS.null rest4 - then Right (TlvStream []) + tlvs <- if BS.null rest4 + then Right (unsafeTlvStream []) else either (Left . DecodeTlvError) Right (decodeTlvStream rest4) initTlvList <- either (Left . DecodeTlvError) Right - (parseInitTlvs tlvStream) + (parseInitTlvs tlvs) -- Init consumes all bytes (TLVs are part of init, not extensions) Right (Init gf feat initTlvList, BS.empty) diff --git a/lib/Lightning/Protocol/BOLT1/TLV.hs b/lib/Lightning/Protocol/BOLT1/TLV.hs @@ -14,7 +14,10 @@ module Lightning.Protocol.BOLT1.TLV ( -- * TLV types TlvRecord(..) - , TlvStream(..) + , TlvStream + , unTlvStream + , tlvStream + , unsafeTlvStream , TlvError(..) -- * TLV encoding @@ -55,6 +58,26 @@ newtype TlvStream = TlvStream { unTlvStream :: [TlvRecord] } instance NFData TlvStream +-- | Smart constructor for 'TlvStream' that validates records are +-- strictly increasing by type. +-- +-- Returns 'Nothing' if types are not strictly increasing. +tlvStream :: [TlvRecord] -> Maybe TlvStream +tlvStream recs + | isStrictlyIncreasing (map tlvType recs) = Just (TlvStream recs) + | otherwise = Nothing + where + isStrictlyIncreasing :: [Word64] -> Bool + isStrictlyIncreasing [] = True + isStrictlyIncreasing [_] = True + isStrictlyIncreasing (x:y:rest) = x < y && isStrictlyIncreasing (y:rest) + +-- | Unsafe constructor for 'TlvStream' that skips validation. +-- +-- Use only when ordering is already guaranteed (e.g., in decode functions). +unsafeTlvStream :: [TlvRecord] -> TlvStream +unsafeTlvStream = TlvStream + -- | TLV decoding errors. data TlvError = TlvNonMinimalEncoding @@ -97,7 +120,7 @@ decodeTlvStreamRaw = go Nothing [] go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString -> Either TlvError TlvStream go !_ !acc !bs - | BS.null bs = Right (TlvStream (reverse acc)) + | BS.null bs = Right (unsafeTlvStream (reverse acc)) go !mPrevType !acc !bs = do (typ, rest1) <- maybe (Left TlvNonMinimalEncoding) Right (decodeBigSize bs) @@ -133,7 +156,7 @@ decodeTlvStreamWith isKnown = go Nothing [] go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString -> Either TlvError TlvStream go !_ !acc !bs - | BS.null bs = Right (TlvStream (reverse acc)) + | BS.null bs = Right (unsafeTlvStream (reverse acc)) go !mPrevType !acc !bs = do (typ, rest1) <- maybe (Left TlvNonMinimalEncoding) Right (decodeBigSize bs) @@ -201,7 +224,7 @@ chunksOf !n !bs -- | Encode init TLVs to a TLV stream. encodeInitTlvs :: [InitTlv] -> TlvStream -encodeInitTlvs = TlvStream . map toRecord +encodeInitTlvs = unsafeTlvStream . map toRecord where toRecord (InitNetworks chains) = TlvRecord 1 (mconcat chains) diff --git a/test/Main.hs b/test/Main.hs @@ -259,8 +259,25 @@ minsigned_tests = testGroup "Minimal signed (Appendix D)" [ tlv_tests :: TestTree tlv_tests = testGroup "TLV" [ - testCase "empty stream" $ - decodeTlvStream "" @?= Right (TlvStream []) + testGroup "tlvStream smart constructor" [ + testCase "empty list succeeds" $ + tlvStream [] @?= Just (unsafeTlvStream []) + , testCase "single record succeeds" $ + tlvStream [TlvRecord 1 "a"] @?= Just (unsafeTlvStream [TlvRecord 1 "a"]) + , testCase "strictly increasing succeeds" $ + tlvStream [TlvRecord 1 "a", TlvRecord 3 "b", TlvRecord 5 "c"] @?= + Just (unsafeTlvStream [TlvRecord 1 "a", TlvRecord 3 "b", + TlvRecord 5 "c"]) + , testCase "non-increasing fails" $ + tlvStream [TlvRecord 5 "a", TlvRecord 3 "b"] @?= Nothing + , testCase "duplicate types fails" $ + tlvStream [TlvRecord 1 "a", TlvRecord 1 "b"] @?= Nothing + , testCase "equal adjacent types fails" $ + tlvStream [TlvRecord 1 "a", TlvRecord 2 "b", TlvRecord 2 "c"] @?= + Nothing + ] + , testCase "empty stream" $ + decodeTlvStream "" @?= Right (unsafeTlvStream []) , testCase "single record type 1" $ do let bs = mconcat [ encodeBigSize 1 -- type @@ -268,17 +285,19 @@ tlv_tests = testGroup "TLV" [ , BS.replicate 32 0x00 -- value (chain hash) ] case decodeTlvStream bs of - Right (TlvStream [r]) -> do - tlvType r @?= 1 - BS.length (tlvValue r) @?= 32 - other -> assertFailure $ "unexpected: " ++ show other + Right stream -> case unTlvStream stream of + [r] -> do + tlvType r @?= 1 + BS.length (tlvValue r) @?= 32 + _ -> assertFailure "expected single record" + Left e -> assertFailure $ "unexpected error: " ++ show e , testCase "strictly increasing types" $ do let bs = mconcat [ encodeBigSize 1, encodeBigSize 0 , encodeBigSize 3, encodeBigSize 4, "test" ] case decodeTlvStream bs of - Right (TlvStream recs) -> length recs @?= 2 + Right stream -> length (unTlvStream stream) @?= 2 Left e -> assertFailure $ "unexpected error: " ++ show e , testCase "non-increasing types fails" $ do let bs = mconcat [ @@ -309,7 +328,7 @@ tlv_tests = testGroup "TLV" [ , encodeBigSize 7, encodeBigSize 0 ] case decodeTlvStream bs of - Right (TlvStream []) -> pure () -- both skipped (unknown odd) + Right stream | null (unTlvStream stream) -> pure () -- both skipped other -> assertFailure $ "expected empty stream: " ++ show other , testCase "length exceeds bounds fails" $ do let bs = mconcat [encodeBigSize 1, encodeBigSize 100, "short"] @@ -324,15 +343,17 @@ tlv_tests = testGroup "TLV" [ encodeBigSize 5, encodeBigSize 2, "hi" ] case decodeTlvStreamWith isKnown bs of - Right (TlvStream [r]) -> tlvType r @?= 5 - other -> assertFailure $ "unexpected: " ++ show other + Right stream -> case unTlvStream stream of + [r] -> tlvType r @?= 5 + _ -> assertFailure "expected single record" + Left e -> assertFailure $ "unexpected error: " ++ show e , testCase "decodeTlvStreamRaw returns all records" $ do let bs = mconcat [ encodeBigSize 2, encodeBigSize 1, "a" -- even type , encodeBigSize 5, encodeBigSize 1, "b" -- odd type ] case decodeTlvStreamRaw bs of - Right (TlvStream recs) -> length recs @?= 2 + Right stream -> length (unTlvStream stream) @?= 2 Left e -> assertFailure $ "unexpected error: " ++ show e ] @@ -491,7 +512,7 @@ extension_tests :: TestTree extension_tests = testGroup "Extension TLV" [ testCase "encode envelope with extension (odd type)" $ do let msg = MsgPingVal (Ping 10 "") - ext = TlvStream [TlvRecord 101 "extension data"] -- odd type + ext = unsafeTlvStream [TlvRecord 101 "extension data"] -- odd type case encodeEnvelope msg (Just ext) of Left e -> assertFailure $ "encode failed: " ++ show e Right encoded -> do @@ -500,13 +521,14 @@ extension_tests = testGroup "Extension TLV" [ , testCase "decode envelope with odd extension - skipped per BOLT#1" $ do -- Per BOLT #1: unknown odd types are ignored (skipped) let msg = MsgPingVal (Ping 10 "") - ext = TlvStream [TlvRecord 101 "ext"] -- odd type + ext = unsafeTlvStream [TlvRecord 101 "ext"] -- odd type case encodeEnvelope msg (Just ext) of Left e -> assertFailure $ "encode failed: " ++ show e Right encoded -> case decodeEnvelope encoded of - Right (Just decoded, Just (TlvStream [])) -> do - -- Extension is empty because unknown odd types are skipped - decoded @?= msg + Right (Just decoded, Just stream) + | null (unTlvStream stream) -> do + -- Extension is empty because unknown odd types are skipped + decoded @?= msg other -> assertFailure $ "unexpected: " ++ show other , testCase "decode envelope with unknown even extension fails" $ do -- Per BOLT #1: unknown even types must cause failure @@ -573,7 +595,7 @@ bounds_tests = testGroup "Bounds checking" [ -- Create a message that fits in encodeMessage but combined with -- extension exceeds 65535 bytes total let msg = MsgPongVal (Pong (BS.replicate 60000 0x00)) - ext = TlvStream [TlvRecord 101 (BS.replicate 10000 0x00)] + ext = unsafeTlvStream [TlvRecord 101 (BS.replicate 10000 0x00)] case encodeEnvelope msg (Just ext) of Left EncodeMessageTooLarge -> pure () other -> assertFailure $ "expected message too large: " ++ show other @@ -634,12 +656,13 @@ property_tests = testGroup "Properties" [ -- Unknown odd types in extensions are skipped per BOLT #1 let msg = MsgPingVal (Ping 42 "") extData = BS.pack (take 100 bs) - ext = TlvStream [TlvRecord 101 extData] -- odd type, will be skipped + ext = unsafeTlvStream [TlvRecord 101 extData] -- odd type, skipped in case encodeEnvelope msg (Just ext) of Left _ -> False Right encoded -> case decodeEnvelope encoded of -- Extension should be empty (odd types skipped) - Right (Just decoded, Just (TlvStream [])) -> decoded == msg + Right (Just decoded, Just stream) -> + null (unTlvStream stream) && decoded == msg _ -> False ]