commit 82bc4e46f21942423984099fba7b2e3bc7ebe74d
parent af22e72c4b9258f4ec5f48af5c9bc4bf5260cdfb
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 25 Jan 2026 14:30:10 +0400
feat: add TlvStream smart constructor with ordering validation
- Add tlvStream smart constructor (validates strictly increasing types)
- Add unsafeTlvStream for internal/trusted use
- Hide TlvStream data constructor from public API
- Update internal uses in decode functions to use unsafeTlvStream
- Add tests for ordering validation
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
4 files changed, 76 insertions(+), 27 deletions(-)
diff --git a/lib/Lightning/Protocol/BOLT1.hs b/lib/Lightning/Protocol/BOLT1.hs
@@ -30,7 +30,10 @@ module Lightning.Protocol.BOLT1 (
-- * TLV
, TlvRecord(..)
- , TlvStream(..)
+ , TlvStream
+ , unTlvStream
+ , tlvStream
+ , unsafeTlvStream
, TlvError(..)
, encodeTlvStream
, decodeTlvStream
diff --git a/lib/Lightning/Protocol/BOLT1/Codec.hs b/lib/Lightning/Protocol/BOLT1/Codec.hs
@@ -182,11 +182,11 @@ decodeInit !bs = do
let !feat = BS.take (fromIntegral fLen) rest3
!rest4 = BS.drop (fromIntegral fLen) rest3
-- Parse optional TLV stream (consumes all remaining bytes for init)
- tlvStream <- if BS.null rest4
- then Right (TlvStream [])
+ tlvs <- if BS.null rest4
+ then Right (unsafeTlvStream [])
else either (Left . DecodeTlvError) Right (decodeTlvStream rest4)
initTlvList <- either (Left . DecodeTlvError) Right
- (parseInitTlvs tlvStream)
+ (parseInitTlvs tlvs)
-- Init consumes all bytes (TLVs are part of init, not extensions)
Right (Init gf feat initTlvList, BS.empty)
diff --git a/lib/Lightning/Protocol/BOLT1/TLV.hs b/lib/Lightning/Protocol/BOLT1/TLV.hs
@@ -14,7 +14,10 @@
module Lightning.Protocol.BOLT1.TLV (
-- * TLV types
TlvRecord(..)
- , TlvStream(..)
+ , TlvStream
+ , unTlvStream
+ , tlvStream
+ , unsafeTlvStream
, TlvError(..)
-- * TLV encoding
@@ -55,6 +58,26 @@ newtype TlvStream = TlvStream { unTlvStream :: [TlvRecord] }
instance NFData TlvStream
+-- | Smart constructor for 'TlvStream' that validates records are
+-- strictly increasing by type.
+--
+-- Returns 'Nothing' if types are not strictly increasing.
+tlvStream :: [TlvRecord] -> Maybe TlvStream
+tlvStream recs
+ | isStrictlyIncreasing (map tlvType recs) = Just (TlvStream recs)
+ | otherwise = Nothing
+ where
+ isStrictlyIncreasing :: [Word64] -> Bool
+ isStrictlyIncreasing [] = True
+ isStrictlyIncreasing [_] = True
+ isStrictlyIncreasing (x:y:rest) = x < y && isStrictlyIncreasing (y:rest)
+
+-- | Unsafe constructor for 'TlvStream' that skips validation.
+--
+-- Use only when ordering is already guaranteed (e.g., in decode functions).
+unsafeTlvStream :: [TlvRecord] -> TlvStream
+unsafeTlvStream = TlvStream
+
-- | TLV decoding errors.
data TlvError
= TlvNonMinimalEncoding
@@ -97,7 +120,7 @@ decodeTlvStreamRaw = go Nothing []
go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString
-> Either TlvError TlvStream
go !_ !acc !bs
- | BS.null bs = Right (TlvStream (reverse acc))
+ | BS.null bs = Right (unsafeTlvStream (reverse acc))
go !mPrevType !acc !bs = do
(typ, rest1) <- maybe (Left TlvNonMinimalEncoding) Right
(decodeBigSize bs)
@@ -133,7 +156,7 @@ decodeTlvStreamWith isKnown = go Nothing []
go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString
-> Either TlvError TlvStream
go !_ !acc !bs
- | BS.null bs = Right (TlvStream (reverse acc))
+ | BS.null bs = Right (unsafeTlvStream (reverse acc))
go !mPrevType !acc !bs = do
(typ, rest1) <- maybe (Left TlvNonMinimalEncoding) Right
(decodeBigSize bs)
@@ -201,7 +224,7 @@ chunksOf !n !bs
-- | Encode init TLVs to a TLV stream.
encodeInitTlvs :: [InitTlv] -> TlvStream
-encodeInitTlvs = TlvStream . map toRecord
+encodeInitTlvs = unsafeTlvStream . map toRecord
where
toRecord (InitNetworks chains) =
TlvRecord 1 (mconcat chains)
diff --git a/test/Main.hs b/test/Main.hs
@@ -259,8 +259,25 @@ minsigned_tests = testGroup "Minimal signed (Appendix D)" [
tlv_tests :: TestTree
tlv_tests = testGroup "TLV" [
- testCase "empty stream" $
- decodeTlvStream "" @?= Right (TlvStream [])
+ testGroup "tlvStream smart constructor" [
+ testCase "empty list succeeds" $
+ tlvStream [] @?= Just (unsafeTlvStream [])
+ , testCase "single record succeeds" $
+ tlvStream [TlvRecord 1 "a"] @?= Just (unsafeTlvStream [TlvRecord 1 "a"])
+ , testCase "strictly increasing succeeds" $
+ tlvStream [TlvRecord 1 "a", TlvRecord 3 "b", TlvRecord 5 "c"] @?=
+ Just (unsafeTlvStream [TlvRecord 1 "a", TlvRecord 3 "b",
+ TlvRecord 5 "c"])
+ , testCase "non-increasing fails" $
+ tlvStream [TlvRecord 5 "a", TlvRecord 3 "b"] @?= Nothing
+ , testCase "duplicate types fails" $
+ tlvStream [TlvRecord 1 "a", TlvRecord 1 "b"] @?= Nothing
+ , testCase "equal adjacent types fails" $
+ tlvStream [TlvRecord 1 "a", TlvRecord 2 "b", TlvRecord 2 "c"] @?=
+ Nothing
+ ]
+ , testCase "empty stream" $
+ decodeTlvStream "" @?= Right (unsafeTlvStream [])
, testCase "single record type 1" $ do
let bs = mconcat [
encodeBigSize 1 -- type
@@ -268,17 +285,19 @@ tlv_tests = testGroup "TLV" [
, BS.replicate 32 0x00 -- value (chain hash)
]
case decodeTlvStream bs of
- Right (TlvStream [r]) -> do
- tlvType r @?= 1
- BS.length (tlvValue r) @?= 32
- other -> assertFailure $ "unexpected: " ++ show other
+ Right stream -> case unTlvStream stream of
+ [r] -> do
+ tlvType r @?= 1
+ BS.length (tlvValue r) @?= 32
+ _ -> assertFailure "expected single record"
+ Left e -> assertFailure $ "unexpected error: " ++ show e
, testCase "strictly increasing types" $ do
let bs = mconcat [
encodeBigSize 1, encodeBigSize 0
, encodeBigSize 3, encodeBigSize 4, "test"
]
case decodeTlvStream bs of
- Right (TlvStream recs) -> length recs @?= 2
+ Right stream -> length (unTlvStream stream) @?= 2
Left e -> assertFailure $ "unexpected error: " ++ show e
, testCase "non-increasing types fails" $ do
let bs = mconcat [
@@ -309,7 +328,7 @@ tlv_tests = testGroup "TLV" [
, encodeBigSize 7, encodeBigSize 0
]
case decodeTlvStream bs of
- Right (TlvStream []) -> pure () -- both skipped (unknown odd)
+ Right stream | null (unTlvStream stream) -> pure () -- both skipped
other -> assertFailure $ "expected empty stream: " ++ show other
, testCase "length exceeds bounds fails" $ do
let bs = mconcat [encodeBigSize 1, encodeBigSize 100, "short"]
@@ -324,15 +343,17 @@ tlv_tests = testGroup "TLV" [
encodeBigSize 5, encodeBigSize 2, "hi"
]
case decodeTlvStreamWith isKnown bs of
- Right (TlvStream [r]) -> tlvType r @?= 5
- other -> assertFailure $ "unexpected: " ++ show other
+ Right stream -> case unTlvStream stream of
+ [r] -> tlvType r @?= 5
+ _ -> assertFailure "expected single record"
+ Left e -> assertFailure $ "unexpected error: " ++ show e
, testCase "decodeTlvStreamRaw returns all records" $ do
let bs = mconcat [
encodeBigSize 2, encodeBigSize 1, "a" -- even type
, encodeBigSize 5, encodeBigSize 1, "b" -- odd type
]
case decodeTlvStreamRaw bs of
- Right (TlvStream recs) -> length recs @?= 2
+ Right stream -> length (unTlvStream stream) @?= 2
Left e -> assertFailure $ "unexpected error: " ++ show e
]
@@ -491,7 +512,7 @@ extension_tests :: TestTree
extension_tests = testGroup "Extension TLV" [
testCase "encode envelope with extension (odd type)" $ do
let msg = MsgPingVal (Ping 10 "")
- ext = TlvStream [TlvRecord 101 "extension data"] -- odd type
+ ext = unsafeTlvStream [TlvRecord 101 "extension data"] -- odd type
case encodeEnvelope msg (Just ext) of
Left e -> assertFailure $ "encode failed: " ++ show e
Right encoded -> do
@@ -500,13 +521,14 @@ extension_tests = testGroup "Extension TLV" [
, testCase "decode envelope with odd extension - skipped per BOLT#1" $ do
-- Per BOLT #1: unknown odd types are ignored (skipped)
let msg = MsgPingVal (Ping 10 "")
- ext = TlvStream [TlvRecord 101 "ext"] -- odd type
+ ext = unsafeTlvStream [TlvRecord 101 "ext"] -- odd type
case encodeEnvelope msg (Just ext) of
Left e -> assertFailure $ "encode failed: " ++ show e
Right encoded -> case decodeEnvelope encoded of
- Right (Just decoded, Just (TlvStream [])) -> do
- -- Extension is empty because unknown odd types are skipped
- decoded @?= msg
+ Right (Just decoded, Just stream)
+ | null (unTlvStream stream) -> do
+ -- Extension is empty because unknown odd types are skipped
+ decoded @?= msg
other -> assertFailure $ "unexpected: " ++ show other
, testCase "decode envelope with unknown even extension fails" $ do
-- Per BOLT #1: unknown even types must cause failure
@@ -573,7 +595,7 @@ bounds_tests = testGroup "Bounds checking" [
-- Create a message that fits in encodeMessage but combined with
-- extension exceeds 65535 bytes total
let msg = MsgPongVal (Pong (BS.replicate 60000 0x00))
- ext = TlvStream [TlvRecord 101 (BS.replicate 10000 0x00)]
+ ext = unsafeTlvStream [TlvRecord 101 (BS.replicate 10000 0x00)]
case encodeEnvelope msg (Just ext) of
Left EncodeMessageTooLarge -> pure ()
other -> assertFailure $ "expected message too large: " ++ show other
@@ -634,12 +656,13 @@ property_tests = testGroup "Properties" [
-- Unknown odd types in extensions are skipped per BOLT #1
let msg = MsgPingVal (Ping 42 "")
extData = BS.pack (take 100 bs)
- ext = TlvStream [TlvRecord 101 extData] -- odd type, will be skipped
+ ext = unsafeTlvStream [TlvRecord 101 extData] -- odd type, skipped
in case encodeEnvelope msg (Just ext) of
Left _ -> False
Right encoded -> case decodeEnvelope encoded of
-- Extension should be empty (odd types skipped)
- Right (Just decoded, Just (TlvStream [])) -> decoded == msg
+ Right (Just decoded, Just stream) ->
+ null (unTlvStream stream) && decoded == msg
_ -> False
]