commit 56d27f6673b4a125218f55ff1ee0d2d74889183b
parent 80d0966d9fc9cf42d98d8f3d1d470defbdad6a01
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 25 Jan 2026 10:24:13 +0400
Merge branch 'impl/review-fixes'
Address review findings for BOLT #1 message codec (REVIEW-80d0966):
High severity fixes:
- decodeEnvelope now parses and returns extension TLVs
- Add decodeTlvStreamWith for configurable known-type predicates
- Add decodeTlvStreamRaw for raw TLV parsing without type validation
Medium severity fixes:
- Add bounds checking for u16 length fields (EncodeLengthOverflow error)
- Encoding functions now return Either EncodeError ByteString
- decodeMessage returns DecodeUnknownOddType for unknown odd types
API changes:
- encodeMessage/encodeEnvelope return Either EncodeError ByteString
- decodeMessage returns (Message, ByteString) tuple for remaining bytes
- decodeEnvelope returns (Maybe Message, Maybe TlvStream)
- New exports: decodeTlvStreamWith, decodeTlvStreamRaw, EncodeError
Tests expanded from 55 to 75 cases covering extension handling,
bounds checking, and proper error types.
Diffstat:
| M | lib/Lightning/Protocol/BOLT1.hs | | | 305 | ++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------- |
| M | test/Main.hs | | | 288 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------- |
2 files changed, 419 insertions(+), 174 deletions(-)
diff --git a/lib/Lightning/Protocol/BOLT1.hs b/lib/Lightning/Protocol/BOLT1.hs
@@ -39,6 +39,8 @@ module Lightning.Protocol.BOLT1 (
, TlvError(..)
, encodeTlvStream
, decodeTlvStream
+ , decodeTlvStreamWith
+ , decodeTlvStreamRaw
-- ** Init TLVs
, InitTlv(..)
@@ -47,6 +49,7 @@ module Lightning.Protocol.BOLT1 (
, Envelope(..)
-- * Encoding
+ , EncodeError(..)
, encodeMessage
, encodeEnvelope
@@ -121,6 +124,15 @@ encodeBigSize !x
| otherwise = BS.cons 0xff (encodeU64 x)
{-# INLINE encodeBigSize #-}
+-- | Encode a length as u16, checking bounds.
+--
+-- Returns Nothing if the length exceeds 65535.
+encodeLength :: BS.ByteString -> Maybe BS.ByteString
+encodeLength !bs
+ | BS.length bs > 65535 = Nothing
+ | otherwise = Just (encodeU16 (fromIntegral (BS.length bs)))
+{-# INLINE encodeLength #-}
+
-- Primitive decoding ----------------------------------------------------------
-- | Decode a 16-bit unsigned integer (big-endian).
@@ -232,13 +244,52 @@ data TlvError
instance NFData TlvError
--- | Decode a TLV stream with BOLT #1 validation.
+-- | Decode a TLV stream without any known-type validation.
+--
+-- This decoder only enforces structural validity:
+-- - Types must be strictly increasing
+-- - Lengths must not exceed bounds
+--
+-- All records are returned regardless of type. Use this for parsing
+-- extension TLVs or when you want to handle type validation separately.
+decodeTlvStreamRaw :: BS.ByteString -> Either TlvError TlvStream
+decodeTlvStreamRaw = go Nothing []
+ where
+ go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString
+ -> Either TlvError TlvStream
+ go !_ !acc !bs
+ | BS.null bs = Right (TlvStream (reverse acc))
+ go !mPrevType !acc !bs = do
+ (typ, rest1) <- maybe (Left TlvNonMinimalEncoding) Right
+ (decodeBigSize bs)
+ -- Strictly increasing check
+ case mPrevType of
+ Just prevType -> when (typ <= prevType) $
+ Left TlvNotStrictlyIncreasing
+ Nothing -> pure ()
+ (len, rest2) <- maybe (Left TlvNonMinimalEncoding) Right
+ (decodeBigSize rest1)
+ -- Length bounds check
+ when (fromIntegral len > BS.length rest2) $
+ Left TlvLengthExceedsBounds
+ let !val = BS.take (fromIntegral len) rest2
+ !rest3 = BS.drop (fromIntegral len) rest2
+ !rec = TlvRecord typ val
+ go (Just typ) (rec : acc) rest3
+
+-- | Decode a TLV stream with configurable known-type predicate.
--
+-- Per BOLT #1:
-- - Types must be strictly increasing
-- - Unknown even types cause failure
-- - Unknown odd types are skipped
-decodeTlvStream :: BS.ByteString -> Either TlvError TlvStream
-decodeTlvStream = go Nothing []
+--
+-- The predicate determines which types are "known" for the context.
+decodeTlvStreamWith
+ :: (Word64 -> Bool) -- ^ Predicate: is this type known?
+ -> BS.ByteString
+ -> Either TlvError TlvStream
+decodeTlvStreamWith isKnown = go Nothing []
where
go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString
-> Either TlvError TlvStream
@@ -261,18 +312,24 @@ decodeTlvStream = go Nothing []
!rest3 = BS.drop (fromIntegral len) rest2
!rec = TlvRecord typ val
-- Unknown type handling: even = fail, odd = skip
- if isKnownTlvType typ
+ if isKnown typ
then go (Just typ) (rec : acc) rest3
else if even typ
then Left (TlvUnknownEvenType typ)
else go (Just typ) acc rest3 -- skip unknown odd
--- | Check if a TLV type is known (for init_tlvs).
--- Types 1 (networks) and 3 (remote_addr) are known.
-isKnownTlvType :: Word64 -> Bool
-isKnownTlvType 1 = True -- networks
-isKnownTlvType 3 = True -- remote_addr
-isKnownTlvType _ = False
+-- | Decode a TLV stream with BOLT #1 init_tlvs validation.
+--
+-- This uses the default known types for init messages (1 and 3).
+-- For other contexts, use 'decodeTlvStreamWith' with an appropriate
+-- predicate.
+decodeTlvStream :: BS.ByteString -> Either TlvError TlvStream
+decodeTlvStream = decodeTlvStreamWith isInitTlvType
+ where
+ isInitTlvType :: Word64 -> Bool
+ isInitTlvType 1 = True -- networks
+ isInitTlvType 3 = True -- remote_addr
+ isInitTlvType _ = False
-- Init TLV types --------------------------------------------------------------
@@ -432,89 +489,93 @@ instance NFData Envelope
-- Message encoding ------------------------------------------------------------
+-- | Encoding errors.
+data EncodeError
+ = EncodeLengthOverflow -- ^ Payload exceeds u16 max (65535 bytes)
+ deriving stock (Eq, Show, Generic)
+
+instance NFData EncodeError
+
-- | Encode an Init message payload.
-encodeInit :: Init -> BS.ByteString
-encodeInit (Init gf feat tlvs) = mconcat
- [ encodeU16 (fromIntegral (BS.length gf))
- , gf
- , encodeU16 (fromIntegral (BS.length feat))
- , feat
- , encodeTlvStream (encodeInitTlvs tlvs)
- ]
+encodeInit :: Init -> Either EncodeError BS.ByteString
+encodeInit (Init gf feat tlvs) = do
+ gfLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength gf)
+ featLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength feat)
+ Right $ mconcat
+ [ gfLen
+ , gf
+ , featLen
+ , feat
+ , encodeTlvStream (encodeInitTlvs tlvs)
+ ]
-- | Encode an Error message payload.
-encodeError :: Error -> BS.ByteString
-encodeError (Error cid dat) = mconcat
- [ cid -- 32 bytes
- , encodeU16 (fromIntegral (BS.length dat))
- , dat
- ]
+encodeError :: Error -> Either EncodeError BS.ByteString
+encodeError (Error cid dat) = do
+ datLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength dat)
+ Right $ mconcat [cid, datLen, dat]
-- | Encode a Warning message payload.
-encodeWarning :: Warning -> BS.ByteString
-encodeWarning (Warning cid dat) = mconcat
- [ cid -- 32 bytes
- , encodeU16 (fromIntegral (BS.length dat))
- , dat
- ]
+encodeWarning :: Warning -> Either EncodeError BS.ByteString
+encodeWarning (Warning cid dat) = do
+ datLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength dat)
+ Right $ mconcat [cid, datLen, dat]
-- | Encode a Ping message payload.
-encodePing :: Ping -> BS.ByteString
-encodePing (Ping numPong ignored) = mconcat
- [ encodeU16 numPong
- , encodeU16 (fromIntegral (BS.length ignored))
- , ignored
- ]
+encodePing :: Ping -> Either EncodeError BS.ByteString
+encodePing (Ping numPong ignored) = do
+ ignoredLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength ignored)
+ Right $ mconcat [encodeU16 numPong, ignoredLen, ignored]
-- | Encode a Pong message payload.
-encodePong :: Pong -> BS.ByteString
-encodePong (Pong ignored) = mconcat
- [ encodeU16 (fromIntegral (BS.length ignored))
- , ignored
- ]
+encodePong :: Pong -> Either EncodeError BS.ByteString
+encodePong (Pong ignored) = do
+ ignoredLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength ignored)
+ Right $ mconcat [ignoredLen, ignored]
-- | Encode a PeerStorage message payload.
-encodePeerStorage :: PeerStorage -> BS.ByteString
-encodePeerStorage (PeerStorage blob) = mconcat
- [ encodeU16 (fromIntegral (BS.length blob))
- , blob
- ]
+encodePeerStorage :: PeerStorage -> Either EncodeError BS.ByteString
+encodePeerStorage (PeerStorage blob) = do
+ blobLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength blob)
+ Right $ mconcat [blobLen, blob]
-- | Encode a PeerStorageRetrieval message payload.
-encodePeerStorageRetrieval :: PeerStorageRetrieval -> BS.ByteString
-encodePeerStorageRetrieval (PeerStorageRetrieval blob) = mconcat
- [ encodeU16 (fromIntegral (BS.length blob))
- , blob
- ]
+encodePeerStorageRetrieval
+ :: PeerStorageRetrieval -> Either EncodeError BS.ByteString
+encodePeerStorageRetrieval (PeerStorageRetrieval blob) = do
+ blobLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength blob)
+ Right $ mconcat [blobLen, blob]
-- | Encode a message to its payload bytes.
-encodeMessage :: Message -> BS.ByteString
+encodeMessage :: Message -> Either EncodeError BS.ByteString
encodeMessage = \case
- MsgInitVal m -> encodeInit m
- MsgErrorVal m -> encodeError m
- MsgWarningVal m -> encodeWarning m
- MsgPingVal m -> encodePing m
- MsgPongVal m -> encodePong m
- MsgPeerStorageVal m -> encodePeerStorage m
+ MsgInitVal m -> encodeInit m
+ MsgErrorVal m -> encodeError m
+ MsgWarningVal m -> encodeWarning m
+ MsgPingVal m -> encodePing m
+ MsgPongVal m -> encodePong m
+ MsgPeerStorageVal m -> encodePeerStorage m
MsgPeerStorageRetrievalVal m -> encodePeerStorageRetrieval m
-- | Get the message type for a message.
messageType :: Message -> MsgType
messageType = \case
- MsgInitVal _ -> MsgInit
- MsgErrorVal _ -> MsgError
- MsgWarningVal _ -> MsgWarning
- MsgPingVal _ -> MsgPing
- MsgPongVal _ -> MsgPong
- MsgPeerStorageVal _ -> MsgPeerStorage
+ MsgInitVal _ -> MsgInit
+ MsgErrorVal _ -> MsgError
+ MsgWarningVal _ -> MsgWarning
+ MsgPingVal _ -> MsgPing
+ MsgPongVal _ -> MsgPong
+ MsgPeerStorageVal _ -> MsgPeerStorage
MsgPeerStorageRetrievalVal _ -> MsgPeerStorageRet
--- | Encode a message as a complete envelope (type + payload).
-encodeEnvelope :: Message -> Maybe TlvStream -> BS.ByteString
-encodeEnvelope msg mext = mconcat $
- [ encodeU16 (msgTypeWord (messageType msg))
- , encodeMessage msg
- ] ++ maybe [] (\ext -> [encodeTlvStream ext]) mext
+-- | Encode a message as a complete envelope (type + payload + extension).
+encodeEnvelope :: Message -> Maybe TlvStream -> Either EncodeError BS.ByteString
+encodeEnvelope msg mext = do
+ payload <- encodeMessage msg
+ Right $ mconcat $
+ [ encodeU16 (msgTypeWord (messageType msg))
+ , payload
+ ] ++ maybe [] (\ext -> [encodeTlvStream ext]) mext
-- Message decoding ------------------------------------------------------------
@@ -523,14 +584,18 @@ data DecodeError
= DecodeInsufficientBytes
| DecodeInvalidLength
| DecodeUnknownEvenType !Word16
+ | DecodeUnknownOddType !Word16
| DecodeTlvError !TlvError
| DecodeInvalidChannelId
+ | DecodeInvalidExtension !TlvError
deriving stock (Eq, Show, Generic)
instance NFData DecodeError
-- | Decode an Init message from payload bytes.
-decodeInit :: BS.ByteString -> Either DecodeError Init
+--
+-- Returns the decoded message and any remaining bytes.
+decodeInit :: BS.ByteString -> Either DecodeError (Init, BS.ByteString)
decodeInit !bs = do
(gfLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right
(decodeU16 bs)
@@ -544,16 +609,17 @@ decodeInit !bs = do
Left DecodeInsufficientBytes
let !feat = BS.take (fromIntegral fLen) rest3
!rest4 = BS.drop (fromIntegral fLen) rest3
- -- Parse optional TLV stream
+ -- Parse optional TLV stream (consumes all remaining bytes for init)
tlvStream <- if BS.null rest4
then Right (TlvStream [])
else either (Left . DecodeTlvError) Right (decodeTlvStream rest4)
initTlvList <- either (Left . DecodeTlvError) Right
(parseInitTlvs tlvStream)
- Right (Init gf feat initTlvList)
+ -- Init consumes all bytes (TLVs are part of init, not extensions)
+ Right (Init gf feat initTlvList, BS.empty)
-- | Decode an Error message from payload bytes.
-decodeError :: BS.ByteString -> Either DecodeError Error
+decodeError :: BS.ByteString -> Either DecodeError (Error, BS.ByteString)
decodeError !bs = do
unless (BS.length bs >= 32) $ Left DecodeInsufficientBytes
let !cid = BS.take 32 bs
@@ -563,10 +629,11 @@ decodeError !bs = do
unless (BS.length rest2 >= fromIntegral dLen) $
Left DecodeInsufficientBytes
let !dat = BS.take (fromIntegral dLen) rest2
- Right (Error cid dat)
+ !rest3 = BS.drop (fromIntegral dLen) rest2
+ Right (Error cid dat, rest3)
-- | Decode a Warning message from payload bytes.
-decodeWarning :: BS.ByteString -> Either DecodeError Warning
+decodeWarning :: BS.ByteString -> Either DecodeError (Warning, BS.ByteString)
decodeWarning !bs = do
unless (BS.length bs >= 32) $ Left DecodeInsufficientBytes
let !cid = BS.take 32 bs
@@ -576,10 +643,11 @@ decodeWarning !bs = do
unless (BS.length rest2 >= fromIntegral dLen) $
Left DecodeInsufficientBytes
let !dat = BS.take (fromIntegral dLen) rest2
- Right (Warning cid dat)
+ !rest3 = BS.drop (fromIntegral dLen) rest2
+ Right (Warning cid dat, rest3)
-- | Decode a Ping message from payload bytes.
-decodePing :: BS.ByteString -> Either DecodeError Ping
+decodePing :: BS.ByteString -> Either DecodeError (Ping, BS.ByteString)
decodePing !bs = do
(numPong, rest1) <- maybe (Left DecodeInsufficientBytes) Right
(decodeU16 bs)
@@ -588,60 +656,87 @@ decodePing !bs = do
unless (BS.length rest2 >= fromIntegral bLen) $
Left DecodeInsufficientBytes
let !ignored = BS.take (fromIntegral bLen) rest2
- Right (Ping numPong ignored)
+ !rest3 = BS.drop (fromIntegral bLen) rest2
+ Right (Ping numPong ignored, rest3)
-- | Decode a Pong message from payload bytes.
-decodePong :: BS.ByteString -> Either DecodeError Pong
+decodePong :: BS.ByteString -> Either DecodeError (Pong, BS.ByteString)
decodePong !bs = do
(bLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right
(decodeU16 bs)
unless (BS.length rest1 >= fromIntegral bLen) $
Left DecodeInsufficientBytes
let !ignored = BS.take (fromIntegral bLen) rest1
- Right (Pong ignored)
+ !rest2 = BS.drop (fromIntegral bLen) rest1
+ Right (Pong ignored, rest2)
-- | Decode a PeerStorage message from payload bytes.
-decodePeerStorage :: BS.ByteString -> Either DecodeError PeerStorage
+decodePeerStorage
+ :: BS.ByteString -> Either DecodeError (PeerStorage, BS.ByteString)
decodePeerStorage !bs = do
(bLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right
(decodeU16 bs)
unless (BS.length rest1 >= fromIntegral bLen) $
Left DecodeInsufficientBytes
let !blob = BS.take (fromIntegral bLen) rest1
- Right (PeerStorage blob)
+ !rest2 = BS.drop (fromIntegral bLen) rest1
+ Right (PeerStorage blob, rest2)
-- | Decode a PeerStorageRetrieval message from payload bytes.
-decodePeerStorageRetrieval :: BS.ByteString
- -> Either DecodeError PeerStorageRetrieval
+decodePeerStorageRetrieval
+ :: BS.ByteString
+ -> Either DecodeError (PeerStorageRetrieval, BS.ByteString)
decodePeerStorageRetrieval !bs = do
(bLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right
(decodeU16 bs)
unless (BS.length rest1 >= fromIntegral bLen) $
Left DecodeInsufficientBytes
let !blob = BS.take (fromIntegral bLen) rest1
- Right (PeerStorageRetrieval blob)
+ !rest2 = BS.drop (fromIntegral bLen) rest1
+ Right (PeerStorageRetrieval blob, rest2)
-- | Decode a message from its type and payload.
-decodeMessage :: MsgType -> BS.ByteString -> Either DecodeError Message
-decodeMessage MsgInit bs = MsgInitVal <$> decodeInit bs
-decodeMessage MsgError bs = MsgErrorVal <$> decodeError bs
-decodeMessage MsgWarning bs = MsgWarningVal <$> decodeWarning bs
-decodeMessage MsgPing bs = MsgPingVal <$> decodePing bs
-decodeMessage MsgPong bs = MsgPongVal <$> decodePong bs
-decodeMessage MsgPeerStorage bs = MsgPeerStorageVal <$> decodePeerStorage bs
-decodeMessage MsgPeerStorageRet bs =
- MsgPeerStorageRetrievalVal <$> decodePeerStorageRetrieval bs
+--
+-- Returns the decoded message and any remaining bytes (for extensions).
+-- For unknown types, returns an appropriate error.
+decodeMessage
+ :: MsgType -> BS.ByteString -> Either DecodeError (Message, BS.ByteString)
+decodeMessage MsgInit bs = do
+ (m, rest) <- decodeInit bs
+ Right (MsgInitVal m, rest)
+decodeMessage MsgError bs = do
+ (m, rest) <- decodeError bs
+ Right (MsgErrorVal m, rest)
+decodeMessage MsgWarning bs = do
+ (m, rest) <- decodeWarning bs
+ Right (MsgWarningVal m, rest)
+decodeMessage MsgPing bs = do
+ (m, rest) <- decodePing bs
+ Right (MsgPingVal m, rest)
+decodeMessage MsgPong bs = do
+ (m, rest) <- decodePong bs
+ Right (MsgPongVal m, rest)
+decodeMessage MsgPeerStorage bs = do
+ (m, rest) <- decodePeerStorage bs
+ Right (MsgPeerStorageVal m, rest)
+decodeMessage MsgPeerStorageRet bs = do
+ (m, rest) <- decodePeerStorageRetrieval bs
+ Right (MsgPeerStorageRetrievalVal m, rest)
decodeMessage (MsgUnknown w) _
| even w = Left (DecodeUnknownEvenType w)
- | otherwise = Left DecodeInsufficientBytes
+ | otherwise = Left (DecodeUnknownOddType w)
-- | Decode a complete envelope (type + payload + optional extension).
--
-- Per BOLT #1:
--- - Unknown odd message types are ignored (returns Nothing)
+-- - Unknown odd message types are ignored (returns Nothing for message)
-- - Unknown even message types cause connection close (returns error)
--- - Invalid extension TLV causes connection close
-decodeEnvelope :: BS.ByteString -> Either DecodeError (Maybe Message)
+-- - Invalid extension TLV causes connection close (returns error)
+--
+-- Returns the decoded message (if known) and any extension TLVs.
+decodeEnvelope
+ :: BS.ByteString
+ -> Either DecodeError (Maybe Message, Maybe TlvStream)
decodeEnvelope !bs = do
(typeWord, rest1) <- maybe (Left DecodeInsufficientBytes) Right
(decodeU16 bs)
@@ -649,7 +744,13 @@ decodeEnvelope !bs = do
case msgType of
MsgUnknown w
| even w -> Left (DecodeUnknownEvenType w)
- | otherwise -> Right Nothing -- Ignore unknown odd types
+ | otherwise -> Right (Nothing, Nothing) -- Ignore unknown odd types
_ -> do
- msg <- decodeMessage msgType rest1
- Right (Just msg)
+ (msg, rest2) <- decodeMessage msgType rest1
+ -- Parse any remaining bytes as extension TLV
+ ext <- if BS.null rest2
+ then Right Nothing
+ else case decodeTlvStreamRaw rest2 of
+ Left e -> Left (DecodeInvalidExtension e)
+ Right s -> Right (Just s)
+ Right (Just msg, ext)
diff --git a/test/Main.hs b/test/Main.hs
@@ -16,6 +16,8 @@ main = defaultMain $ testGroup "ppad-bolt1" [
, tlv_tests
, message_tests
, envelope_tests
+ , extension_tests
+ , bounds_tests
, property_tests
]
@@ -152,6 +154,23 @@ tlv_tests = testGroup "TLV" [
Left TlvLengthExceedsBounds -> pure ()
other -> assertFailure $ "expected TlvLengthExceedsBounds: " ++
show other
+ , testCase "decodeTlvStreamWith custom predicate" $ do
+ -- Use a predicate that only knows type 5
+ let isKnown t = t == 5
+ bs = mconcat [
+ encodeBigSize 5, encodeBigSize 2, "hi"
+ ]
+ case decodeTlvStreamWith isKnown bs of
+ Right (TlvStream [r]) -> tlvType r @?= 5
+ other -> assertFailure $ "unexpected: " ++ show other
+ , testCase "decodeTlvStreamRaw returns all records" $ do
+ let bs = mconcat [
+ encodeBigSize 2, encodeBigSize 1, "a" -- even type
+ , encodeBigSize 5, encodeBigSize 1, "b" -- odd type
+ ]
+ case decodeTlvStreamRaw bs of
+ Right (TlvStream recs) -> length recs @?= 2
+ Left e -> assertFailure $ "unexpected error: " ++ show e
]
-- Message encode/decode tests -------------------------------------------------
@@ -161,32 +180,36 @@ message_tests = testGroup "Messages" [
testGroup "Init" [
testCase "encode/decode minimal init" $ do
let msg = Init "" "" []
- encoded = encodeMessage (MsgInitVal msg)
- case decodeMessage MsgInit encoded of
- Right (MsgInitVal decoded) -> decoded @?= msg
- other -> assertFailure $ "unexpected: " ++ show other
+ case encodeMessage (MsgInitVal msg) of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeMessage MsgInit encoded of
+ Right (MsgInitVal decoded, _) -> decoded @?= msg
+ other -> assertFailure $ "unexpected: " ++ show other
, testCase "encode/decode init with features" $ do
let msg = Init (BS.pack [0x01]) (BS.pack [0x02, 0x0a]) []
- encoded = encodeMessage (MsgInitVal msg)
- case decodeMessage MsgInit encoded of
- Right (MsgInitVal decoded) -> decoded @?= msg
- other -> assertFailure $ "unexpected: " ++ show other
+ case encodeMessage (MsgInitVal msg) of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeMessage MsgInit encoded of
+ Right (MsgInitVal decoded, _) -> decoded @?= msg
+ other -> assertFailure $ "unexpected: " ++ show other
, testCase "encode/decode init with networks TLV" $ do
let chainHash = BS.replicate 32 0xab
msg = Init "" "" [InitNetworks [chainHash]]
- encoded = encodeMessage (MsgInitVal msg)
- case decodeMessage MsgInit encoded of
- Right (MsgInitVal decoded) -> decoded @?= msg
- other -> assertFailure $ "unexpected: " ++ show other
+ case encodeMessage (MsgInitVal msg) of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeMessage MsgInit encoded of
+ Right (MsgInitVal decoded, _) -> decoded @?= msg
+ other -> assertFailure $ "unexpected: " ++ show other
]
, testGroup "Error" [
testCase "encode/decode error" $ do
let cid = BS.replicate 32 0xff
msg = Error cid "something went wrong"
- encoded = encodeMessage (MsgErrorVal msg)
- case decodeMessage MsgError encoded of
- Right (MsgErrorVal decoded) -> decoded @?= msg
- other -> assertFailure $ "unexpected: " ++ show other
+ case encodeMessage (MsgErrorVal msg) of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeMessage MsgError encoded of
+ Right (MsgErrorVal decoded, _) -> decoded @?= msg
+ other -> assertFailure $ "unexpected: " ++ show other
, testCase "error insufficient channel_id" $ do
case decodeMessage MsgError (BS.replicate 31 0x00) of
Left DecodeInsufficientBytes -> pure ()
@@ -196,48 +219,64 @@ message_tests = testGroup "Messages" [
testCase "encode/decode warning" $ do
let cid = BS.replicate 32 0x00
msg = Warning cid "be careful"
- encoded = encodeMessage (MsgWarningVal msg)
- case decodeMessage MsgWarning encoded of
- Right (MsgWarningVal decoded) -> decoded @?= msg
- other -> assertFailure $ "unexpected: " ++ show other
+ case encodeMessage (MsgWarningVal msg) of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeMessage MsgWarning encoded of
+ Right (MsgWarningVal decoded, _) -> decoded @?= msg
+ other -> assertFailure $ "unexpected: " ++ show other
]
, testGroup "Ping" [
testCase "encode/decode ping" $ do
let msg = Ping 100 (BS.replicate 10 0x00)
- encoded = encodeMessage (MsgPingVal msg)
- case decodeMessage MsgPing encoded of
- Right (MsgPingVal decoded) -> decoded @?= msg
- other -> assertFailure $ "unexpected: " ++ show other
+ case encodeMessage (MsgPingVal msg) of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeMessage MsgPing encoded of
+ Right (MsgPingVal decoded, _) -> decoded @?= msg
+ other -> assertFailure $ "unexpected: " ++ show other
, testCase "ping with zero ignored" $ do
let msg = Ping 50 ""
- encoded = encodeMessage (MsgPingVal msg)
- case decodeMessage MsgPing encoded of
- Right (MsgPingVal decoded) -> decoded @?= msg
- other -> assertFailure $ "unexpected: " ++ show other
+ case encodeMessage (MsgPingVal msg) of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeMessage MsgPing encoded of
+ Right (MsgPingVal decoded, _) -> decoded @?= msg
+ other -> assertFailure $ "unexpected: " ++ show other
]
, testGroup "Pong" [
testCase "encode/decode pong" $ do
let msg = Pong (BS.replicate 100 0x00)
- encoded = encodeMessage (MsgPongVal msg)
- case decodeMessage MsgPong encoded of
- Right (MsgPongVal decoded) -> decoded @?= msg
- other -> assertFailure $ "unexpected: " ++ show other
+ case encodeMessage (MsgPongVal msg) of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeMessage MsgPong encoded of
+ Right (MsgPongVal decoded, _) -> decoded @?= msg
+ other -> assertFailure $ "unexpected: " ++ show other
]
, testGroup "PeerStorage" [
testCase "encode/decode peer_storage" $ do
let msg = PeerStorage "encrypted blob data"
- encoded = encodeMessage (MsgPeerStorageVal msg)
- case decodeMessage MsgPeerStorage encoded of
- Right (MsgPeerStorageVal decoded) -> decoded @?= msg
- other -> assertFailure $ "unexpected: " ++ show other
+ case encodeMessage (MsgPeerStorageVal msg) of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeMessage MsgPeerStorage encoded of
+ Right (MsgPeerStorageVal decoded, _) -> decoded @?= msg
+ other -> assertFailure $ "unexpected: " ++ show other
]
, testGroup "PeerStorageRetrieval" [
testCase "encode/decode peer_storage_retrieval" $ do
let msg = PeerStorageRetrieval "retrieved blob"
- encoded = encodeMessage (MsgPeerStorageRetrievalVal msg)
- case decodeMessage MsgPeerStorageRet encoded of
- Right (MsgPeerStorageRetrievalVal decoded) -> decoded @?= msg
- other -> assertFailure $ "unexpected: " ++ show other
+ case encodeMessage (MsgPeerStorageRetrievalVal msg) of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeMessage MsgPeerStorageRet encoded of
+ Right (MsgPeerStorageRetrievalVal decoded, _) -> decoded @?= msg
+ other -> assertFailure $ "unexpected: " ++ show other
+ ]
+ , testGroup "Unknown types" [
+ testCase "decodeMessage unknown even type" $ do
+ case decodeMessage (MsgUnknown 100) "payload" of
+ Left (DecodeUnknownEvenType 100) -> pure ()
+ other -> assertFailure $ "expected unknown even: " ++ show other
+ , testCase "decodeMessage unknown odd type" $ do
+ case decodeMessage (MsgUnknown 101) "payload" of
+ Left (DecodeUnknownOddType 101) -> pure ()
+ other -> assertFailure $ "expected unknown odd: " ++ show other
]
]
@@ -247,16 +286,18 @@ envelope_tests :: TestTree
envelope_tests = testGroup "Envelope" [
testCase "encode/decode init envelope" $ do
let msg = MsgInitVal (Init "" "" [])
- encoded = encodeEnvelope msg Nothing
- case decodeEnvelope encoded of
- Right (Just decoded) -> decoded @?= msg
- other -> assertFailure $ "unexpected: " ++ show other
+ case encodeEnvelope msg Nothing of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeEnvelope encoded of
+ Right (Just decoded, _) -> decoded @?= msg
+ other -> assertFailure $ "unexpected: " ++ show other
, testCase "encode/decode ping envelope" $ do
let msg = MsgPingVal (Ping 10 "")
- encoded = encodeEnvelope msg Nothing
- case decodeEnvelope encoded of
- Right (Just decoded) -> decoded @?= msg
- other -> assertFailure $ "unexpected: " ++ show other
+ case encodeEnvelope msg Nothing of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeEnvelope encoded of
+ Right (Just decoded, _) -> decoded @?= msg
+ other -> assertFailure $ "unexpected: " ++ show other
, testCase "unknown even type fails" $ do
let bs = encodeU16 100 <> "payload" -- 100 is even, unknown
case decodeEnvelope bs of
@@ -265,8 +306,8 @@ envelope_tests = testGroup "Envelope" [
, testCase "unknown odd type ignored" $ do
let bs = encodeU16 101 <> "payload" -- 101 is odd, unknown
case decodeEnvelope bs of
- Right Nothing -> pure () -- ignored
- other -> assertFailure $ "expected Nothing: " ++ show other
+ Right (Nothing, Nothing) -> pure () -- ignored
+ other -> assertFailure $ "expected (Nothing, Nothing): " ++ show other
, testCase "insufficient bytes for type" $ do
case decodeEnvelope (BS.pack [0x00]) of
Left DecodeInsufficientBytes -> pure ()
@@ -281,6 +322,83 @@ envelope_tests = testGroup "Envelope" [
msgTypeWord MsgPeerStorageRet @?= 9
]
+-- Extension TLV tests ---------------------------------------------------------
+
+extension_tests :: TestTree
+extension_tests = testGroup "Extension TLV" [
+ testCase "encode envelope with extension" $ do
+ let msg = MsgPingVal (Ping 10 "")
+ ext = TlvStream [TlvRecord 100 "extension data"]
+ case encodeEnvelope msg (Just ext) of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> do
+ -- Should contain message + extension
+ assertBool "encoded should be longer" (BS.length encoded > 6)
+ , testCase "decode envelope with extension roundtrip" $ do
+ let msg = MsgPingVal (Ping 10 "")
+ ext = TlvStream [TlvRecord 101 "ext"]
+ case encodeEnvelope msg (Just ext) of
+ Left e -> assertFailure $ "encode failed: " ++ show e
+ Right encoded -> case decodeEnvelope encoded of
+ Right (Just decoded, Just decodedExt) -> do
+ decoded @?= msg
+ length (unTlvStream decodedExt) @?= 1
+ other -> assertFailure $ "unexpected: " ++ show other
+ , testCase "decode envelope extension is parsed" $ do
+ -- Manually construct ping + extension TLV
+ let pingPayload = mconcat [encodeU16 10, encodeU16 0] -- numPong=10, len=0
+ extTlv = mconcat [encodeBigSize 200, encodeBigSize 3, "abc"]
+ envelope = encodeU16 18 <> pingPayload <> extTlv -- type 18 = ping
+ case decodeEnvelope envelope of
+ Right (Just (MsgPingVal ping), Just (TlvStream [r])) -> do
+ pingNumPongBytes ping @?= 10
+ tlvType r @?= 200
+ tlvValue r @?= "abc"
+ other -> assertFailure $ "unexpected: " ++ show other
+ , testCase "decode envelope with invalid extension fails" $ do
+ -- Ping + invalid TLV (non-strictly-increasing)
+ let pingPayload = mconcat [encodeU16 10, encodeU16 0]
+ badTlv = mconcat [
+ encodeBigSize 100, encodeBigSize 1, "a"
+ , encodeBigSize 50, encodeBigSize 1, "b" -- 50 < 100, invalid
+ ]
+ envelope = encodeU16 18 <> pingPayload <> badTlv
+ case decodeEnvelope envelope of
+ Left (DecodeInvalidExtension TlvNotStrictlyIncreasing) -> pure ()
+ other -> assertFailure $ "expected invalid extension: " ++ show other
+ ]
+
+-- Bounds checking tests -------------------------------------------------------
+
+bounds_tests :: TestTree
+bounds_tests = testGroup "Bounds checking" [
+ testCase "encode ping with oversized ignored fails" $ do
+ let msg = Ping 10 (BS.replicate 70000 0x00) -- > 65535
+ case encodeMessage (MsgPingVal msg) of
+ Left EncodeLengthOverflow -> pure ()
+ other -> assertFailure $ "expected overflow: " ++ show other
+ , testCase "encode pong with oversized ignored fails" $ do
+ let msg = Pong (BS.replicate 70000 0x00)
+ case encodeMessage (MsgPongVal msg) of
+ Left EncodeLengthOverflow -> pure ()
+ other -> assertFailure $ "expected overflow: " ++ show other
+ , testCase "encode error with oversized data fails" $ do
+ let msg = Error (BS.replicate 32 0x00) (BS.replicate 70000 0x00)
+ case encodeMessage (MsgErrorVal msg) of
+ Left EncodeLengthOverflow -> pure ()
+ other -> assertFailure $ "expected overflow: " ++ show other
+ , testCase "encode init with oversized features fails" $ do
+ let msg = Init "" (BS.replicate 70000 0x00) []
+ case encodeMessage (MsgInitVal msg) of
+ Left EncodeLengthOverflow -> pure ()
+ other -> assertFailure $ "expected overflow: " ++ show other
+ , testCase "encode peer_storage with oversized blob fails" $ do
+ let msg = PeerStorage (BS.replicate 70000 0x00)
+ case encodeMessage (MsgPeerStorageVal msg) of
+ Left EncodeLengthOverflow -> pure ()
+ other -> assertFailure $ "expected overflow: " ++ show other
+ ]
+
-- Property tests --------------------------------------------------------------
property_tests :: TestTree
@@ -296,36 +414,62 @@ property_tests = testGroup "Properties" [
, testProperty "U64 roundtrip" $ \w ->
decodeU64 (encodeU64 w) == Just (w, "")
, testProperty "Ping roundtrip" $ \(NonNegative num) bs ->
- let msg = Ping (fromIntegral (num `mod` 65536 :: Integer))
- (BS.pack bs)
- encoded = encodeMessage (MsgPingVal msg)
- in case decodeMessage MsgPing encoded of
- Right (MsgPingVal decoded) -> decoded == msg
- _ -> False
+ let ignored = BS.pack (take 1000 bs) -- limit size
+ msg = Ping (fromIntegral (num `mod` 65536 :: Integer)) ignored
+ in case encodeMessage (MsgPingVal msg) of
+ Left _ -> False
+ Right encoded -> case decodeMessage MsgPing encoded of
+ Right (MsgPingVal decoded, rest) ->
+ decoded == msg && BS.null rest
+ _ -> False
, testProperty "Pong roundtrip" $ \bs ->
- let msg = Pong (BS.pack bs)
- encoded = encodeMessage (MsgPongVal msg)
- in case decodeMessage MsgPong encoded of
- Right (MsgPongVal decoded) -> decoded == msg
- _ -> False
+ let ignored = BS.pack (take 1000 bs)
+ msg = Pong ignored
+ in case encodeMessage (MsgPongVal msg) of
+ Left _ -> False
+ Right encoded -> case decodeMessage MsgPong encoded of
+ Right (MsgPongVal decoded, rest) ->
+ decoded == msg && BS.null rest
+ _ -> False
, testProperty "PeerStorage roundtrip" $ \bs ->
- let msg = PeerStorage (BS.pack bs)
- encoded = encodeMessage (MsgPeerStorageVal msg)
- in case decodeMessage MsgPeerStorage encoded of
- Right (MsgPeerStorageVal decoded) -> decoded == msg
- _ -> False
+ let blob = BS.pack (take 1000 bs)
+ msg = PeerStorage blob
+ in case encodeMessage (MsgPeerStorageVal msg) of
+ Left _ -> False
+ Right encoded -> case decodeMessage MsgPeerStorage encoded of
+ Right (MsgPeerStorageVal decoded, rest) ->
+ decoded == msg && BS.null rest
+ _ -> False
, testProperty "Error roundtrip" $ \bs ->
let cid = BS.replicate 32 0x00
- msg = Error cid (BS.pack bs)
- encoded = encodeMessage (MsgErrorVal msg)
- in case decodeMessage MsgError encoded of
- Right (MsgErrorVal decoded) -> decoded == msg
- _ -> False
+ dat = BS.pack (take 1000 bs)
+ msg = Error cid dat
+ in case encodeMessage (MsgErrorVal msg) of
+ Left _ -> False
+ Right encoded -> case decodeMessage MsgError encoded of
+ Right (MsgErrorVal decoded, rest) ->
+ decoded == msg && BS.null rest
+ _ -> False
+ , testProperty "Envelope with extension roundtrip" $ \bs ->
+ let msg = MsgPingVal (Ping 42 "")
+ extData = BS.pack (take 100 bs)
+ ext = TlvStream [TlvRecord 101 extData]
+ in case encodeEnvelope msg (Just ext) of
+ Left _ -> False
+ Right encoded -> case decodeEnvelope encoded of
+ Right (Just decoded, Just (TlvStream [r])) ->
+ decoded == msg && tlvType r == 101 && tlvValue r == extData
+ _ -> False
]
-- Helpers ---------------------------------------------------------------------
+-- | Decode hex string. Fails the test on invalid hex.
unhex :: BS.ByteString -> BS.ByteString
unhex bs = case B16.decode bs of
Just r -> r
- Nothing -> error $ "invalid hex: " ++ show bs
+ Nothing -> assertFailure' $ "invalid hex: " ++ show bs
+
+-- | assertFailure that returns any type (for use in pure contexts)
+assertFailure' :: String -> a
+assertFailure' msg = error msg