bolt1

Base Lightning protocol, per BOLT #1.
git clone git://git.ppad.tech/bolt1.git
Log | Files | Refs | README | LICENSE

commit 56d27f6673b4a125218f55ff1ee0d2d74889183b
parent 80d0966d9fc9cf42d98d8f3d1d470defbdad6a01
Author: Jared Tobin <jared@jtobin.io>
Date:   Sun, 25 Jan 2026 10:24:13 +0400

Merge branch 'impl/review-fixes'

Address review findings for BOLT #1 message codec (REVIEW-80d0966):

High severity fixes:
- decodeEnvelope now parses and returns extension TLVs
- Add decodeTlvStreamWith for configurable known-type predicates
- Add decodeTlvStreamRaw for raw TLV parsing without type validation

Medium severity fixes:
- Add bounds checking for u16 length fields (EncodeLengthOverflow error)
- Encoding functions now return Either EncodeError ByteString
- decodeMessage returns DecodeUnknownOddType for unknown odd types

API changes:
- encodeMessage/encodeEnvelope return Either EncodeError ByteString
- decodeMessage returns (Message, ByteString) tuple for remaining bytes
- decodeEnvelope returns (Maybe Message, Maybe TlvStream)
- New exports: decodeTlvStreamWith, decodeTlvStreamRaw, EncodeError

Tests expanded from 55 to 75 cases covering extension handling,
bounds checking, and proper error types.

Diffstat:
Mlib/Lightning/Protocol/BOLT1.hs | 305++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------------
Mtest/Main.hs | 288+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------
2 files changed, 419 insertions(+), 174 deletions(-)

diff --git a/lib/Lightning/Protocol/BOLT1.hs b/lib/Lightning/Protocol/BOLT1.hs @@ -39,6 +39,8 @@ module Lightning.Protocol.BOLT1 ( , TlvError(..) , encodeTlvStream , decodeTlvStream + , decodeTlvStreamWith + , decodeTlvStreamRaw -- ** Init TLVs , InitTlv(..) @@ -47,6 +49,7 @@ module Lightning.Protocol.BOLT1 ( , Envelope(..) -- * Encoding + , EncodeError(..) , encodeMessage , encodeEnvelope @@ -121,6 +124,15 @@ encodeBigSize !x | otherwise = BS.cons 0xff (encodeU64 x) {-# INLINE encodeBigSize #-} +-- | Encode a length as u16, checking bounds. +-- +-- Returns Nothing if the length exceeds 65535. +encodeLength :: BS.ByteString -> Maybe BS.ByteString +encodeLength !bs + | BS.length bs > 65535 = Nothing + | otherwise = Just (encodeU16 (fromIntegral (BS.length bs))) +{-# INLINE encodeLength #-} + -- Primitive decoding ---------------------------------------------------------- -- | Decode a 16-bit unsigned integer (big-endian). @@ -232,13 +244,52 @@ data TlvError instance NFData TlvError --- | Decode a TLV stream with BOLT #1 validation. +-- | Decode a TLV stream without any known-type validation. +-- +-- This decoder only enforces structural validity: +-- - Types must be strictly increasing +-- - Lengths must not exceed bounds +-- +-- All records are returned regardless of type. Use this for parsing +-- extension TLVs or when you want to handle type validation separately. +decodeTlvStreamRaw :: BS.ByteString -> Either TlvError TlvStream +decodeTlvStreamRaw = go Nothing [] + where + go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString + -> Either TlvError TlvStream + go !_ !acc !bs + | BS.null bs = Right (TlvStream (reverse acc)) + go !mPrevType !acc !bs = do + (typ, rest1) <- maybe (Left TlvNonMinimalEncoding) Right + (decodeBigSize bs) + -- Strictly increasing check + case mPrevType of + Just prevType -> when (typ <= prevType) $ + Left TlvNotStrictlyIncreasing + Nothing -> pure () + (len, rest2) <- maybe (Left TlvNonMinimalEncoding) Right + (decodeBigSize rest1) + -- Length bounds check + when (fromIntegral len > BS.length rest2) $ + Left TlvLengthExceedsBounds + let !val = BS.take (fromIntegral len) rest2 + !rest3 = BS.drop (fromIntegral len) rest2 + !rec = TlvRecord typ val + go (Just typ) (rec : acc) rest3 + +-- | Decode a TLV stream with configurable known-type predicate. -- +-- Per BOLT #1: -- - Types must be strictly increasing -- - Unknown even types cause failure -- - Unknown odd types are skipped -decodeTlvStream :: BS.ByteString -> Either TlvError TlvStream -decodeTlvStream = go Nothing [] +-- +-- The predicate determines which types are "known" for the context. +decodeTlvStreamWith + :: (Word64 -> Bool) -- ^ Predicate: is this type known? + -> BS.ByteString + -> Either TlvError TlvStream +decodeTlvStreamWith isKnown = go Nothing [] where go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString -> Either TlvError TlvStream @@ -261,18 +312,24 @@ decodeTlvStream = go Nothing [] !rest3 = BS.drop (fromIntegral len) rest2 !rec = TlvRecord typ val -- Unknown type handling: even = fail, odd = skip - if isKnownTlvType typ + if isKnown typ then go (Just typ) (rec : acc) rest3 else if even typ then Left (TlvUnknownEvenType typ) else go (Just typ) acc rest3 -- skip unknown odd --- | Check if a TLV type is known (for init_tlvs). --- Types 1 (networks) and 3 (remote_addr) are known. -isKnownTlvType :: Word64 -> Bool -isKnownTlvType 1 = True -- networks -isKnownTlvType 3 = True -- remote_addr -isKnownTlvType _ = False +-- | Decode a TLV stream with BOLT #1 init_tlvs validation. +-- +-- This uses the default known types for init messages (1 and 3). +-- For other contexts, use 'decodeTlvStreamWith' with an appropriate +-- predicate. +decodeTlvStream :: BS.ByteString -> Either TlvError TlvStream +decodeTlvStream = decodeTlvStreamWith isInitTlvType + where + isInitTlvType :: Word64 -> Bool + isInitTlvType 1 = True -- networks + isInitTlvType 3 = True -- remote_addr + isInitTlvType _ = False -- Init TLV types -------------------------------------------------------------- @@ -432,89 +489,93 @@ instance NFData Envelope -- Message encoding ------------------------------------------------------------ +-- | Encoding errors. +data EncodeError + = EncodeLengthOverflow -- ^ Payload exceeds u16 max (65535 bytes) + deriving stock (Eq, Show, Generic) + +instance NFData EncodeError + -- | Encode an Init message payload. -encodeInit :: Init -> BS.ByteString -encodeInit (Init gf feat tlvs) = mconcat - [ encodeU16 (fromIntegral (BS.length gf)) - , gf - , encodeU16 (fromIntegral (BS.length feat)) - , feat - , encodeTlvStream (encodeInitTlvs tlvs) - ] +encodeInit :: Init -> Either EncodeError BS.ByteString +encodeInit (Init gf feat tlvs) = do + gfLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength gf) + featLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength feat) + Right $ mconcat + [ gfLen + , gf + , featLen + , feat + , encodeTlvStream (encodeInitTlvs tlvs) + ] -- | Encode an Error message payload. -encodeError :: Error -> BS.ByteString -encodeError (Error cid dat) = mconcat - [ cid -- 32 bytes - , encodeU16 (fromIntegral (BS.length dat)) - , dat - ] +encodeError :: Error -> Either EncodeError BS.ByteString +encodeError (Error cid dat) = do + datLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength dat) + Right $ mconcat [cid, datLen, dat] -- | Encode a Warning message payload. -encodeWarning :: Warning -> BS.ByteString -encodeWarning (Warning cid dat) = mconcat - [ cid -- 32 bytes - , encodeU16 (fromIntegral (BS.length dat)) - , dat - ] +encodeWarning :: Warning -> Either EncodeError BS.ByteString +encodeWarning (Warning cid dat) = do + datLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength dat) + Right $ mconcat [cid, datLen, dat] -- | Encode a Ping message payload. -encodePing :: Ping -> BS.ByteString -encodePing (Ping numPong ignored) = mconcat - [ encodeU16 numPong - , encodeU16 (fromIntegral (BS.length ignored)) - , ignored - ] +encodePing :: Ping -> Either EncodeError BS.ByteString +encodePing (Ping numPong ignored) = do + ignoredLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength ignored) + Right $ mconcat [encodeU16 numPong, ignoredLen, ignored] -- | Encode a Pong message payload. -encodePong :: Pong -> BS.ByteString -encodePong (Pong ignored) = mconcat - [ encodeU16 (fromIntegral (BS.length ignored)) - , ignored - ] +encodePong :: Pong -> Either EncodeError BS.ByteString +encodePong (Pong ignored) = do + ignoredLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength ignored) + Right $ mconcat [ignoredLen, ignored] -- | Encode a PeerStorage message payload. -encodePeerStorage :: PeerStorage -> BS.ByteString -encodePeerStorage (PeerStorage blob) = mconcat - [ encodeU16 (fromIntegral (BS.length blob)) - , blob - ] +encodePeerStorage :: PeerStorage -> Either EncodeError BS.ByteString +encodePeerStorage (PeerStorage blob) = do + blobLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength blob) + Right $ mconcat [blobLen, blob] -- | Encode a PeerStorageRetrieval message payload. -encodePeerStorageRetrieval :: PeerStorageRetrieval -> BS.ByteString -encodePeerStorageRetrieval (PeerStorageRetrieval blob) = mconcat - [ encodeU16 (fromIntegral (BS.length blob)) - , blob - ] +encodePeerStorageRetrieval + :: PeerStorageRetrieval -> Either EncodeError BS.ByteString +encodePeerStorageRetrieval (PeerStorageRetrieval blob) = do + blobLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength blob) + Right $ mconcat [blobLen, blob] -- | Encode a message to its payload bytes. -encodeMessage :: Message -> BS.ByteString +encodeMessage :: Message -> Either EncodeError BS.ByteString encodeMessage = \case - MsgInitVal m -> encodeInit m - MsgErrorVal m -> encodeError m - MsgWarningVal m -> encodeWarning m - MsgPingVal m -> encodePing m - MsgPongVal m -> encodePong m - MsgPeerStorageVal m -> encodePeerStorage m + MsgInitVal m -> encodeInit m + MsgErrorVal m -> encodeError m + MsgWarningVal m -> encodeWarning m + MsgPingVal m -> encodePing m + MsgPongVal m -> encodePong m + MsgPeerStorageVal m -> encodePeerStorage m MsgPeerStorageRetrievalVal m -> encodePeerStorageRetrieval m -- | Get the message type for a message. messageType :: Message -> MsgType messageType = \case - MsgInitVal _ -> MsgInit - MsgErrorVal _ -> MsgError - MsgWarningVal _ -> MsgWarning - MsgPingVal _ -> MsgPing - MsgPongVal _ -> MsgPong - MsgPeerStorageVal _ -> MsgPeerStorage + MsgInitVal _ -> MsgInit + MsgErrorVal _ -> MsgError + MsgWarningVal _ -> MsgWarning + MsgPingVal _ -> MsgPing + MsgPongVal _ -> MsgPong + MsgPeerStorageVal _ -> MsgPeerStorage MsgPeerStorageRetrievalVal _ -> MsgPeerStorageRet --- | Encode a message as a complete envelope (type + payload). -encodeEnvelope :: Message -> Maybe TlvStream -> BS.ByteString -encodeEnvelope msg mext = mconcat $ - [ encodeU16 (msgTypeWord (messageType msg)) - , encodeMessage msg - ] ++ maybe [] (\ext -> [encodeTlvStream ext]) mext +-- | Encode a message as a complete envelope (type + payload + extension). +encodeEnvelope :: Message -> Maybe TlvStream -> Either EncodeError BS.ByteString +encodeEnvelope msg mext = do + payload <- encodeMessage msg + Right $ mconcat $ + [ encodeU16 (msgTypeWord (messageType msg)) + , payload + ] ++ maybe [] (\ext -> [encodeTlvStream ext]) mext -- Message decoding ------------------------------------------------------------ @@ -523,14 +584,18 @@ data DecodeError = DecodeInsufficientBytes | DecodeInvalidLength | DecodeUnknownEvenType !Word16 + | DecodeUnknownOddType !Word16 | DecodeTlvError !TlvError | DecodeInvalidChannelId + | DecodeInvalidExtension !TlvError deriving stock (Eq, Show, Generic) instance NFData DecodeError -- | Decode an Init message from payload bytes. -decodeInit :: BS.ByteString -> Either DecodeError Init +-- +-- Returns the decoded message and any remaining bytes. +decodeInit :: BS.ByteString -> Either DecodeError (Init, BS.ByteString) decodeInit !bs = do (gfLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right (decodeU16 bs) @@ -544,16 +609,17 @@ decodeInit !bs = do Left DecodeInsufficientBytes let !feat = BS.take (fromIntegral fLen) rest3 !rest4 = BS.drop (fromIntegral fLen) rest3 - -- Parse optional TLV stream + -- Parse optional TLV stream (consumes all remaining bytes for init) tlvStream <- if BS.null rest4 then Right (TlvStream []) else either (Left . DecodeTlvError) Right (decodeTlvStream rest4) initTlvList <- either (Left . DecodeTlvError) Right (parseInitTlvs tlvStream) - Right (Init gf feat initTlvList) + -- Init consumes all bytes (TLVs are part of init, not extensions) + Right (Init gf feat initTlvList, BS.empty) -- | Decode an Error message from payload bytes. -decodeError :: BS.ByteString -> Either DecodeError Error +decodeError :: BS.ByteString -> Either DecodeError (Error, BS.ByteString) decodeError !bs = do unless (BS.length bs >= 32) $ Left DecodeInsufficientBytes let !cid = BS.take 32 bs @@ -563,10 +629,11 @@ decodeError !bs = do unless (BS.length rest2 >= fromIntegral dLen) $ Left DecodeInsufficientBytes let !dat = BS.take (fromIntegral dLen) rest2 - Right (Error cid dat) + !rest3 = BS.drop (fromIntegral dLen) rest2 + Right (Error cid dat, rest3) -- | Decode a Warning message from payload bytes. -decodeWarning :: BS.ByteString -> Either DecodeError Warning +decodeWarning :: BS.ByteString -> Either DecodeError (Warning, BS.ByteString) decodeWarning !bs = do unless (BS.length bs >= 32) $ Left DecodeInsufficientBytes let !cid = BS.take 32 bs @@ -576,10 +643,11 @@ decodeWarning !bs = do unless (BS.length rest2 >= fromIntegral dLen) $ Left DecodeInsufficientBytes let !dat = BS.take (fromIntegral dLen) rest2 - Right (Warning cid dat) + !rest3 = BS.drop (fromIntegral dLen) rest2 + Right (Warning cid dat, rest3) -- | Decode a Ping message from payload bytes. -decodePing :: BS.ByteString -> Either DecodeError Ping +decodePing :: BS.ByteString -> Either DecodeError (Ping, BS.ByteString) decodePing !bs = do (numPong, rest1) <- maybe (Left DecodeInsufficientBytes) Right (decodeU16 bs) @@ -588,60 +656,87 @@ decodePing !bs = do unless (BS.length rest2 >= fromIntegral bLen) $ Left DecodeInsufficientBytes let !ignored = BS.take (fromIntegral bLen) rest2 - Right (Ping numPong ignored) + !rest3 = BS.drop (fromIntegral bLen) rest2 + Right (Ping numPong ignored, rest3) -- | Decode a Pong message from payload bytes. -decodePong :: BS.ByteString -> Either DecodeError Pong +decodePong :: BS.ByteString -> Either DecodeError (Pong, BS.ByteString) decodePong !bs = do (bLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right (decodeU16 bs) unless (BS.length rest1 >= fromIntegral bLen) $ Left DecodeInsufficientBytes let !ignored = BS.take (fromIntegral bLen) rest1 - Right (Pong ignored) + !rest2 = BS.drop (fromIntegral bLen) rest1 + Right (Pong ignored, rest2) -- | Decode a PeerStorage message from payload bytes. -decodePeerStorage :: BS.ByteString -> Either DecodeError PeerStorage +decodePeerStorage + :: BS.ByteString -> Either DecodeError (PeerStorage, BS.ByteString) decodePeerStorage !bs = do (bLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right (decodeU16 bs) unless (BS.length rest1 >= fromIntegral bLen) $ Left DecodeInsufficientBytes let !blob = BS.take (fromIntegral bLen) rest1 - Right (PeerStorage blob) + !rest2 = BS.drop (fromIntegral bLen) rest1 + Right (PeerStorage blob, rest2) -- | Decode a PeerStorageRetrieval message from payload bytes. -decodePeerStorageRetrieval :: BS.ByteString - -> Either DecodeError PeerStorageRetrieval +decodePeerStorageRetrieval + :: BS.ByteString + -> Either DecodeError (PeerStorageRetrieval, BS.ByteString) decodePeerStorageRetrieval !bs = do (bLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right (decodeU16 bs) unless (BS.length rest1 >= fromIntegral bLen) $ Left DecodeInsufficientBytes let !blob = BS.take (fromIntegral bLen) rest1 - Right (PeerStorageRetrieval blob) + !rest2 = BS.drop (fromIntegral bLen) rest1 + Right (PeerStorageRetrieval blob, rest2) -- | Decode a message from its type and payload. -decodeMessage :: MsgType -> BS.ByteString -> Either DecodeError Message -decodeMessage MsgInit bs = MsgInitVal <$> decodeInit bs -decodeMessage MsgError bs = MsgErrorVal <$> decodeError bs -decodeMessage MsgWarning bs = MsgWarningVal <$> decodeWarning bs -decodeMessage MsgPing bs = MsgPingVal <$> decodePing bs -decodeMessage MsgPong bs = MsgPongVal <$> decodePong bs -decodeMessage MsgPeerStorage bs = MsgPeerStorageVal <$> decodePeerStorage bs -decodeMessage MsgPeerStorageRet bs = - MsgPeerStorageRetrievalVal <$> decodePeerStorageRetrieval bs +-- +-- Returns the decoded message and any remaining bytes (for extensions). +-- For unknown types, returns an appropriate error. +decodeMessage + :: MsgType -> BS.ByteString -> Either DecodeError (Message, BS.ByteString) +decodeMessage MsgInit bs = do + (m, rest) <- decodeInit bs + Right (MsgInitVal m, rest) +decodeMessage MsgError bs = do + (m, rest) <- decodeError bs + Right (MsgErrorVal m, rest) +decodeMessage MsgWarning bs = do + (m, rest) <- decodeWarning bs + Right (MsgWarningVal m, rest) +decodeMessage MsgPing bs = do + (m, rest) <- decodePing bs + Right (MsgPingVal m, rest) +decodeMessage MsgPong bs = do + (m, rest) <- decodePong bs + Right (MsgPongVal m, rest) +decodeMessage MsgPeerStorage bs = do + (m, rest) <- decodePeerStorage bs + Right (MsgPeerStorageVal m, rest) +decodeMessage MsgPeerStorageRet bs = do + (m, rest) <- decodePeerStorageRetrieval bs + Right (MsgPeerStorageRetrievalVal m, rest) decodeMessage (MsgUnknown w) _ | even w = Left (DecodeUnknownEvenType w) - | otherwise = Left DecodeInsufficientBytes + | otherwise = Left (DecodeUnknownOddType w) -- | Decode a complete envelope (type + payload + optional extension). -- -- Per BOLT #1: --- - Unknown odd message types are ignored (returns Nothing) +-- - Unknown odd message types are ignored (returns Nothing for message) -- - Unknown even message types cause connection close (returns error) --- - Invalid extension TLV causes connection close -decodeEnvelope :: BS.ByteString -> Either DecodeError (Maybe Message) +-- - Invalid extension TLV causes connection close (returns error) +-- +-- Returns the decoded message (if known) and any extension TLVs. +decodeEnvelope + :: BS.ByteString + -> Either DecodeError (Maybe Message, Maybe TlvStream) decodeEnvelope !bs = do (typeWord, rest1) <- maybe (Left DecodeInsufficientBytes) Right (decodeU16 bs) @@ -649,7 +744,13 @@ decodeEnvelope !bs = do case msgType of MsgUnknown w | even w -> Left (DecodeUnknownEvenType w) - | otherwise -> Right Nothing -- Ignore unknown odd types + | otherwise -> Right (Nothing, Nothing) -- Ignore unknown odd types _ -> do - msg <- decodeMessage msgType rest1 - Right (Just msg) + (msg, rest2) <- decodeMessage msgType rest1 + -- Parse any remaining bytes as extension TLV + ext <- if BS.null rest2 + then Right Nothing + else case decodeTlvStreamRaw rest2 of + Left e -> Left (DecodeInvalidExtension e) + Right s -> Right (Just s) + Right (Just msg, ext) diff --git a/test/Main.hs b/test/Main.hs @@ -16,6 +16,8 @@ main = defaultMain $ testGroup "ppad-bolt1" [ , tlv_tests , message_tests , envelope_tests + , extension_tests + , bounds_tests , property_tests ] @@ -152,6 +154,23 @@ tlv_tests = testGroup "TLV" [ Left TlvLengthExceedsBounds -> pure () other -> assertFailure $ "expected TlvLengthExceedsBounds: " ++ show other + , testCase "decodeTlvStreamWith custom predicate" $ do + -- Use a predicate that only knows type 5 + let isKnown t = t == 5 + bs = mconcat [ + encodeBigSize 5, encodeBigSize 2, "hi" + ] + case decodeTlvStreamWith isKnown bs of + Right (TlvStream [r]) -> tlvType r @?= 5 + other -> assertFailure $ "unexpected: " ++ show other + , testCase "decodeTlvStreamRaw returns all records" $ do + let bs = mconcat [ + encodeBigSize 2, encodeBigSize 1, "a" -- even type + , encodeBigSize 5, encodeBigSize 1, "b" -- odd type + ] + case decodeTlvStreamRaw bs of + Right (TlvStream recs) -> length recs @?= 2 + Left e -> assertFailure $ "unexpected error: " ++ show e ] -- Message encode/decode tests ------------------------------------------------- @@ -161,32 +180,36 @@ message_tests = testGroup "Messages" [ testGroup "Init" [ testCase "encode/decode minimal init" $ do let msg = Init "" "" [] - encoded = encodeMessage (MsgInitVal msg) - case decodeMessage MsgInit encoded of - Right (MsgInitVal decoded) -> decoded @?= msg - other -> assertFailure $ "unexpected: " ++ show other + case encodeMessage (MsgInitVal msg) of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeMessage MsgInit encoded of + Right (MsgInitVal decoded, _) -> decoded @?= msg + other -> assertFailure $ "unexpected: " ++ show other , testCase "encode/decode init with features" $ do let msg = Init (BS.pack [0x01]) (BS.pack [0x02, 0x0a]) [] - encoded = encodeMessage (MsgInitVal msg) - case decodeMessage MsgInit encoded of - Right (MsgInitVal decoded) -> decoded @?= msg - other -> assertFailure $ "unexpected: " ++ show other + case encodeMessage (MsgInitVal msg) of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeMessage MsgInit encoded of + Right (MsgInitVal decoded, _) -> decoded @?= msg + other -> assertFailure $ "unexpected: " ++ show other , testCase "encode/decode init with networks TLV" $ do let chainHash = BS.replicate 32 0xab msg = Init "" "" [InitNetworks [chainHash]] - encoded = encodeMessage (MsgInitVal msg) - case decodeMessage MsgInit encoded of - Right (MsgInitVal decoded) -> decoded @?= msg - other -> assertFailure $ "unexpected: " ++ show other + case encodeMessage (MsgInitVal msg) of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeMessage MsgInit encoded of + Right (MsgInitVal decoded, _) -> decoded @?= msg + other -> assertFailure $ "unexpected: " ++ show other ] , testGroup "Error" [ testCase "encode/decode error" $ do let cid = BS.replicate 32 0xff msg = Error cid "something went wrong" - encoded = encodeMessage (MsgErrorVal msg) - case decodeMessage MsgError encoded of - Right (MsgErrorVal decoded) -> decoded @?= msg - other -> assertFailure $ "unexpected: " ++ show other + case encodeMessage (MsgErrorVal msg) of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeMessage MsgError encoded of + Right (MsgErrorVal decoded, _) -> decoded @?= msg + other -> assertFailure $ "unexpected: " ++ show other , testCase "error insufficient channel_id" $ do case decodeMessage MsgError (BS.replicate 31 0x00) of Left DecodeInsufficientBytes -> pure () @@ -196,48 +219,64 @@ message_tests = testGroup "Messages" [ testCase "encode/decode warning" $ do let cid = BS.replicate 32 0x00 msg = Warning cid "be careful" - encoded = encodeMessage (MsgWarningVal msg) - case decodeMessage MsgWarning encoded of - Right (MsgWarningVal decoded) -> decoded @?= msg - other -> assertFailure $ "unexpected: " ++ show other + case encodeMessage (MsgWarningVal msg) of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeMessage MsgWarning encoded of + Right (MsgWarningVal decoded, _) -> decoded @?= msg + other -> assertFailure $ "unexpected: " ++ show other ] , testGroup "Ping" [ testCase "encode/decode ping" $ do let msg = Ping 100 (BS.replicate 10 0x00) - encoded = encodeMessage (MsgPingVal msg) - case decodeMessage MsgPing encoded of - Right (MsgPingVal decoded) -> decoded @?= msg - other -> assertFailure $ "unexpected: " ++ show other + case encodeMessage (MsgPingVal msg) of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeMessage MsgPing encoded of + Right (MsgPingVal decoded, _) -> decoded @?= msg + other -> assertFailure $ "unexpected: " ++ show other , testCase "ping with zero ignored" $ do let msg = Ping 50 "" - encoded = encodeMessage (MsgPingVal msg) - case decodeMessage MsgPing encoded of - Right (MsgPingVal decoded) -> decoded @?= msg - other -> assertFailure $ "unexpected: " ++ show other + case encodeMessage (MsgPingVal msg) of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeMessage MsgPing encoded of + Right (MsgPingVal decoded, _) -> decoded @?= msg + other -> assertFailure $ "unexpected: " ++ show other ] , testGroup "Pong" [ testCase "encode/decode pong" $ do let msg = Pong (BS.replicate 100 0x00) - encoded = encodeMessage (MsgPongVal msg) - case decodeMessage MsgPong encoded of - Right (MsgPongVal decoded) -> decoded @?= msg - other -> assertFailure $ "unexpected: " ++ show other + case encodeMessage (MsgPongVal msg) of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeMessage MsgPong encoded of + Right (MsgPongVal decoded, _) -> decoded @?= msg + other -> assertFailure $ "unexpected: " ++ show other ] , testGroup "PeerStorage" [ testCase "encode/decode peer_storage" $ do let msg = PeerStorage "encrypted blob data" - encoded = encodeMessage (MsgPeerStorageVal msg) - case decodeMessage MsgPeerStorage encoded of - Right (MsgPeerStorageVal decoded) -> decoded @?= msg - other -> assertFailure $ "unexpected: " ++ show other + case encodeMessage (MsgPeerStorageVal msg) of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeMessage MsgPeerStorage encoded of + Right (MsgPeerStorageVal decoded, _) -> decoded @?= msg + other -> assertFailure $ "unexpected: " ++ show other ] , testGroup "PeerStorageRetrieval" [ testCase "encode/decode peer_storage_retrieval" $ do let msg = PeerStorageRetrieval "retrieved blob" - encoded = encodeMessage (MsgPeerStorageRetrievalVal msg) - case decodeMessage MsgPeerStorageRet encoded of - Right (MsgPeerStorageRetrievalVal decoded) -> decoded @?= msg - other -> assertFailure $ "unexpected: " ++ show other + case encodeMessage (MsgPeerStorageRetrievalVal msg) of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeMessage MsgPeerStorageRet encoded of + Right (MsgPeerStorageRetrievalVal decoded, _) -> decoded @?= msg + other -> assertFailure $ "unexpected: " ++ show other + ] + , testGroup "Unknown types" [ + testCase "decodeMessage unknown even type" $ do + case decodeMessage (MsgUnknown 100) "payload" of + Left (DecodeUnknownEvenType 100) -> pure () + other -> assertFailure $ "expected unknown even: " ++ show other + , testCase "decodeMessage unknown odd type" $ do + case decodeMessage (MsgUnknown 101) "payload" of + Left (DecodeUnknownOddType 101) -> pure () + other -> assertFailure $ "expected unknown odd: " ++ show other ] ] @@ -247,16 +286,18 @@ envelope_tests :: TestTree envelope_tests = testGroup "Envelope" [ testCase "encode/decode init envelope" $ do let msg = MsgInitVal (Init "" "" []) - encoded = encodeEnvelope msg Nothing - case decodeEnvelope encoded of - Right (Just decoded) -> decoded @?= msg - other -> assertFailure $ "unexpected: " ++ show other + case encodeEnvelope msg Nothing of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeEnvelope encoded of + Right (Just decoded, _) -> decoded @?= msg + other -> assertFailure $ "unexpected: " ++ show other , testCase "encode/decode ping envelope" $ do let msg = MsgPingVal (Ping 10 "") - encoded = encodeEnvelope msg Nothing - case decodeEnvelope encoded of - Right (Just decoded) -> decoded @?= msg - other -> assertFailure $ "unexpected: " ++ show other + case encodeEnvelope msg Nothing of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeEnvelope encoded of + Right (Just decoded, _) -> decoded @?= msg + other -> assertFailure $ "unexpected: " ++ show other , testCase "unknown even type fails" $ do let bs = encodeU16 100 <> "payload" -- 100 is even, unknown case decodeEnvelope bs of @@ -265,8 +306,8 @@ envelope_tests = testGroup "Envelope" [ , testCase "unknown odd type ignored" $ do let bs = encodeU16 101 <> "payload" -- 101 is odd, unknown case decodeEnvelope bs of - Right Nothing -> pure () -- ignored - other -> assertFailure $ "expected Nothing: " ++ show other + Right (Nothing, Nothing) -> pure () -- ignored + other -> assertFailure $ "expected (Nothing, Nothing): " ++ show other , testCase "insufficient bytes for type" $ do case decodeEnvelope (BS.pack [0x00]) of Left DecodeInsufficientBytes -> pure () @@ -281,6 +322,83 @@ envelope_tests = testGroup "Envelope" [ msgTypeWord MsgPeerStorageRet @?= 9 ] +-- Extension TLV tests --------------------------------------------------------- + +extension_tests :: TestTree +extension_tests = testGroup "Extension TLV" [ + testCase "encode envelope with extension" $ do + let msg = MsgPingVal (Ping 10 "") + ext = TlvStream [TlvRecord 100 "extension data"] + case encodeEnvelope msg (Just ext) of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> do + -- Should contain message + extension + assertBool "encoded should be longer" (BS.length encoded > 6) + , testCase "decode envelope with extension roundtrip" $ do + let msg = MsgPingVal (Ping 10 "") + ext = TlvStream [TlvRecord 101 "ext"] + case encodeEnvelope msg (Just ext) of + Left e -> assertFailure $ "encode failed: " ++ show e + Right encoded -> case decodeEnvelope encoded of + Right (Just decoded, Just decodedExt) -> do + decoded @?= msg + length (unTlvStream decodedExt) @?= 1 + other -> assertFailure $ "unexpected: " ++ show other + , testCase "decode envelope extension is parsed" $ do + -- Manually construct ping + extension TLV + let pingPayload = mconcat [encodeU16 10, encodeU16 0] -- numPong=10, len=0 + extTlv = mconcat [encodeBigSize 200, encodeBigSize 3, "abc"] + envelope = encodeU16 18 <> pingPayload <> extTlv -- type 18 = ping + case decodeEnvelope envelope of + Right (Just (MsgPingVal ping), Just (TlvStream [r])) -> do + pingNumPongBytes ping @?= 10 + tlvType r @?= 200 + tlvValue r @?= "abc" + other -> assertFailure $ "unexpected: " ++ show other + , testCase "decode envelope with invalid extension fails" $ do + -- Ping + invalid TLV (non-strictly-increasing) + let pingPayload = mconcat [encodeU16 10, encodeU16 0] + badTlv = mconcat [ + encodeBigSize 100, encodeBigSize 1, "a" + , encodeBigSize 50, encodeBigSize 1, "b" -- 50 < 100, invalid + ] + envelope = encodeU16 18 <> pingPayload <> badTlv + case decodeEnvelope envelope of + Left (DecodeInvalidExtension TlvNotStrictlyIncreasing) -> pure () + other -> assertFailure $ "expected invalid extension: " ++ show other + ] + +-- Bounds checking tests ------------------------------------------------------- + +bounds_tests :: TestTree +bounds_tests = testGroup "Bounds checking" [ + testCase "encode ping with oversized ignored fails" $ do + let msg = Ping 10 (BS.replicate 70000 0x00) -- > 65535 + case encodeMessage (MsgPingVal msg) of + Left EncodeLengthOverflow -> pure () + other -> assertFailure $ "expected overflow: " ++ show other + , testCase "encode pong with oversized ignored fails" $ do + let msg = Pong (BS.replicate 70000 0x00) + case encodeMessage (MsgPongVal msg) of + Left EncodeLengthOverflow -> pure () + other -> assertFailure $ "expected overflow: " ++ show other + , testCase "encode error with oversized data fails" $ do + let msg = Error (BS.replicate 32 0x00) (BS.replicate 70000 0x00) + case encodeMessage (MsgErrorVal msg) of + Left EncodeLengthOverflow -> pure () + other -> assertFailure $ "expected overflow: " ++ show other + , testCase "encode init with oversized features fails" $ do + let msg = Init "" (BS.replicate 70000 0x00) [] + case encodeMessage (MsgInitVal msg) of + Left EncodeLengthOverflow -> pure () + other -> assertFailure $ "expected overflow: " ++ show other + , testCase "encode peer_storage with oversized blob fails" $ do + let msg = PeerStorage (BS.replicate 70000 0x00) + case encodeMessage (MsgPeerStorageVal msg) of + Left EncodeLengthOverflow -> pure () + other -> assertFailure $ "expected overflow: " ++ show other + ] + -- Property tests -------------------------------------------------------------- property_tests :: TestTree @@ -296,36 +414,62 @@ property_tests = testGroup "Properties" [ , testProperty "U64 roundtrip" $ \w -> decodeU64 (encodeU64 w) == Just (w, "") , testProperty "Ping roundtrip" $ \(NonNegative num) bs -> - let msg = Ping (fromIntegral (num `mod` 65536 :: Integer)) - (BS.pack bs) - encoded = encodeMessage (MsgPingVal msg) - in case decodeMessage MsgPing encoded of - Right (MsgPingVal decoded) -> decoded == msg - _ -> False + let ignored = BS.pack (take 1000 bs) -- limit size + msg = Ping (fromIntegral (num `mod` 65536 :: Integer)) ignored + in case encodeMessage (MsgPingVal msg) of + Left _ -> False + Right encoded -> case decodeMessage MsgPing encoded of + Right (MsgPingVal decoded, rest) -> + decoded == msg && BS.null rest + _ -> False , testProperty "Pong roundtrip" $ \bs -> - let msg = Pong (BS.pack bs) - encoded = encodeMessage (MsgPongVal msg) - in case decodeMessage MsgPong encoded of - Right (MsgPongVal decoded) -> decoded == msg - _ -> False + let ignored = BS.pack (take 1000 bs) + msg = Pong ignored + in case encodeMessage (MsgPongVal msg) of + Left _ -> False + Right encoded -> case decodeMessage MsgPong encoded of + Right (MsgPongVal decoded, rest) -> + decoded == msg && BS.null rest + _ -> False , testProperty "PeerStorage roundtrip" $ \bs -> - let msg = PeerStorage (BS.pack bs) - encoded = encodeMessage (MsgPeerStorageVal msg) - in case decodeMessage MsgPeerStorage encoded of - Right (MsgPeerStorageVal decoded) -> decoded == msg - _ -> False + let blob = BS.pack (take 1000 bs) + msg = PeerStorage blob + in case encodeMessage (MsgPeerStorageVal msg) of + Left _ -> False + Right encoded -> case decodeMessage MsgPeerStorage encoded of + Right (MsgPeerStorageVal decoded, rest) -> + decoded == msg && BS.null rest + _ -> False , testProperty "Error roundtrip" $ \bs -> let cid = BS.replicate 32 0x00 - msg = Error cid (BS.pack bs) - encoded = encodeMessage (MsgErrorVal msg) - in case decodeMessage MsgError encoded of - Right (MsgErrorVal decoded) -> decoded == msg - _ -> False + dat = BS.pack (take 1000 bs) + msg = Error cid dat + in case encodeMessage (MsgErrorVal msg) of + Left _ -> False + Right encoded -> case decodeMessage MsgError encoded of + Right (MsgErrorVal decoded, rest) -> + decoded == msg && BS.null rest + _ -> False + , testProperty "Envelope with extension roundtrip" $ \bs -> + let msg = MsgPingVal (Ping 42 "") + extData = BS.pack (take 100 bs) + ext = TlvStream [TlvRecord 101 extData] + in case encodeEnvelope msg (Just ext) of + Left _ -> False + Right encoded -> case decodeEnvelope encoded of + Right (Just decoded, Just (TlvStream [r])) -> + decoded == msg && tlvType r == 101 && tlvValue r == extData + _ -> False ] -- Helpers --------------------------------------------------------------------- +-- | Decode hex string. Fails the test on invalid hex. unhex :: BS.ByteString -> BS.ByteString unhex bs = case B16.decode bs of Just r -> r - Nothing -> error $ "invalid hex: " ++ show bs + Nothing -> assertFailure' $ "invalid hex: " ++ show bs + +-- | assertFailure that returns any type (for use in pure contexts) +assertFailure' :: String -> a +assertFailure' msg = error msg