bolt1

Base Lightning protocol, per BOLT #1.
git clone git://git.ppad.tech/bolt1.git
Log | Files | Refs | README | LICENSE

commit 98b9bbba12f72665f6bed7575a571fd4a9b37a40
parent b8731a656b99e3bc13cb719ee0f4bf85a94ab075
Author: Jared Tobin <jared@jtobin.io>
Date:   Sun, 25 Jan 2026 11:07:36 +0400

feat: implement BOLT #1 stabilization (IMPL3)

Phase 1 - Fundamental types:
- Add signed integers: encodeS8/S16/S32/S64, decodeS8/S16/S32/S64
- Add truncated unsigned: encodeTu16/Tu32/Tu64, decodeTu16/Tu32/Tu64
- Add minimal signed encoding per Appendix D test vectors
- Add 68 new tests including all Appendix D vectors

Phase 2 - Validation hardening:
- Add EncodeMessageTooLarge error type
- Enforce message size limit (type + payload + extension <= 65535)

Phase 3 - Extension TLV policy:
- Verified unknown even TLVs rejected in extensions
- Unknown odd types properly skipped

Phase 4 - Module split:
- Lightning.Protocol.BOLT1.Prim: primitive encoding/decoding
- Lightning.Protocol.BOLT1.TLV: TLV types and stream handling
- Lightning.Protocol.BOLT1.Message: message types and ADTs
- Lightning.Protocol.BOLT1.Codec: message encoding/decoding
- Lightning.Protocol.BOLT1: re-exports preserving public API

All 145 tests pass.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Diffstat:
Mlib/Lightning/Protocol/BOLT1.hs | 711+++----------------------------------------------------------------------------
Alib/Lightning/Protocol/BOLT1/Codec.hs | 320+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Alib/Lightning/Protocol/BOLT1/Message.hs | 170+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Alib/Lightning/Protocol/BOLT1/Prim.hs | 496+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Alib/Lightning/Protocol/BOLT1/TLV.hs | 209+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mppad-bolt1.cabal | 4++++
Mtest/Main.hs | 171+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
7 files changed, 1391 insertions(+), 690 deletions(-)

diff --git a/lib/Lightning/Protocol/BOLT1.hs b/lib/Lightning/Protocol/BOLT1.hs @@ -1,9 +1,4 @@ {-# OPTIONS_HADDOCK prune #-} -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} -- | -- Module: Lightning.Protocol.BOLT1 @@ -62,697 +57,33 @@ module Lightning.Protocol.BOLT1 ( , encodeU16 , encodeU32 , encodeU64 + , encodeS8 + , encodeS16 + , encodeS32 + , encodeS64 + , encodeTu16 + , encodeTu32 + , encodeTu64 + , encodeMinSigned , encodeBigSize -- * Primitive decoding , decodeU16 , decodeU32 , decodeU64 + , decodeS8 + , decodeS16 + , decodeS32 + , decodeS64 + , decodeTu16 + , decodeTu32 + , decodeTu64 + , decodeMinSigned , decodeBigSize ) where -import Control.DeepSeq (NFData) -import Control.Monad (when, unless) -import Data.Bits (unsafeShiftL, (.|.)) -import qualified Data.ByteString as BS -import qualified Data.ByteString.Builder as BSB -import qualified Data.ByteString.Lazy as BSL -import Data.Word (Word16, Word32, Word64) -import GHC.Generics (Generic) - --- Primitive encoding ---------------------------------------------------------- - --- | Encode a 16-bit unsigned integer (big-endian). --- --- >>> encodeU16 0x0102 --- "\SOH\STX" -encodeU16 :: Word16 -> BS.ByteString -encodeU16 = BSL.toStrict . BSB.toLazyByteString . BSB.word16BE -{-# INLINE encodeU16 #-} - --- | Encode a 32-bit unsigned integer (big-endian). --- --- >>> encodeU32 0x01020304 --- "\SOH\STX\ETX\EOT" -encodeU32 :: Word32 -> BS.ByteString -encodeU32 = BSL.toStrict . BSB.toLazyByteString . BSB.word32BE -{-# INLINE encodeU32 #-} - --- | Encode a 64-bit unsigned integer (big-endian). --- --- >>> encodeU64 0x0102030405060708 --- "\SOH\STX\ETX\EOT\ENQ\ACK\a\b" -encodeU64 :: Word64 -> BS.ByteString -encodeU64 = BSL.toStrict . BSB.toLazyByteString . BSB.word64BE -{-# INLINE encodeU64 #-} - --- | Encode a BigSize value (variable-length unsigned integer). --- --- >>> encodeBigSize 0 --- "\NUL" --- >>> encodeBigSize 252 --- "\252" --- >>> encodeBigSize 253 --- "\253\NUL\253" --- >>> encodeBigSize 65536 --- "\254\NUL\SOH\NUL\NUL" -encodeBigSize :: Word64 -> BS.ByteString -encodeBigSize !x - | x < 0xfd = BS.singleton (fromIntegral x) - | x < 0x10000 = BS.cons 0xfd (encodeU16 (fromIntegral x)) - | x < 0x100000000 = BS.cons 0xfe (encodeU32 (fromIntegral x)) - | otherwise = BS.cons 0xff (encodeU64 x) -{-# INLINE encodeBigSize #-} - --- | Encode a length as u16, checking bounds. --- --- Returns Nothing if the length exceeds 65535. -encodeLength :: BS.ByteString -> Maybe BS.ByteString -encodeLength !bs - | BS.length bs > 65535 = Nothing - | otherwise = Just (encodeU16 (fromIntegral (BS.length bs))) -{-# INLINE encodeLength #-} - --- Primitive decoding ---------------------------------------------------------- - --- | Decode a 16-bit unsigned integer (big-endian). -decodeU16 :: BS.ByteString -> Maybe (Word16, BS.ByteString) -decodeU16 !bs - | BS.length bs < 2 = Nothing - | otherwise = - let !b0 = fromIntegral (BS.index bs 0) - !b1 = fromIntegral (BS.index bs 1) - !val = (b0 `unsafeShiftL` 8) .|. b1 - in Just (val, BS.drop 2 bs) -{-# INLINE decodeU16 #-} - --- | Decode a 32-bit unsigned integer (big-endian). -decodeU32 :: BS.ByteString -> Maybe (Word32, BS.ByteString) -decodeU32 !bs - | BS.length bs < 4 = Nothing - | otherwise = - let !b0 = fromIntegral (BS.index bs 0) - !b1 = fromIntegral (BS.index bs 1) - !b2 = fromIntegral (BS.index bs 2) - !b3 = fromIntegral (BS.index bs 3) - !val = (b0 `unsafeShiftL` 24) .|. (b1 `unsafeShiftL` 16) - .|. (b2 `unsafeShiftL` 8) .|. b3 - in Just (val, BS.drop 4 bs) -{-# INLINE decodeU32 #-} - --- | Decode a 64-bit unsigned integer (big-endian). -decodeU64 :: BS.ByteString -> Maybe (Word64, BS.ByteString) -decodeU64 !bs - | BS.length bs < 8 = Nothing - | otherwise = - let !b0 = fromIntegral (BS.index bs 0) - !b1 = fromIntegral (BS.index bs 1) - !b2 = fromIntegral (BS.index bs 2) - !b3 = fromIntegral (BS.index bs 3) - !b4 = fromIntegral (BS.index bs 4) - !b5 = fromIntegral (BS.index bs 5) - !b6 = fromIntegral (BS.index bs 6) - !b7 = fromIntegral (BS.index bs 7) - !val = (b0 `unsafeShiftL` 56) .|. (b1 `unsafeShiftL` 48) - .|. (b2 `unsafeShiftL` 40) .|. (b3 `unsafeShiftL` 32) - .|. (b4 `unsafeShiftL` 24) .|. (b5 `unsafeShiftL` 16) - .|. (b6 `unsafeShiftL` 8) .|. b7 - in Just (val, BS.drop 8 bs) -{-# INLINE decodeU64 #-} - --- | Decode a BigSize value with minimality check. -decodeBigSize :: BS.ByteString -> Maybe (Word64, BS.ByteString) -decodeBigSize !bs - | BS.null bs = Nothing - | otherwise = case BS.index bs 0 of - 0xff -> do - (val, rest) <- decodeU64 (BS.drop 1 bs) - -- Must be >= 0x100000000 for minimal encoding - if val >= 0x100000000 - then Just (val, rest) - else Nothing - 0xfe -> do - (val, rest) <- decodeU32 (BS.drop 1 bs) - -- Must be >= 0x10000 for minimal encoding - if val >= 0x10000 - then Just (fromIntegral val, rest) - else Nothing - 0xfd -> do - (val, rest) <- decodeU16 (BS.drop 1 bs) - -- Must be >= 0xfd for minimal encoding - if val >= 0xfd - then Just (fromIntegral val, rest) - else Nothing - b -> Just (fromIntegral b, BS.drop 1 bs) - --- TLV types ------------------------------------------------------------------- - --- | A single TLV record. -data TlvRecord = TlvRecord - { tlvType :: {-# UNPACK #-} !Word64 - , tlvValue :: !BS.ByteString - } deriving stock (Eq, Show, Generic) - -instance NFData TlvRecord - --- | A TLV stream (series of TLV records). -newtype TlvStream = TlvStream { unTlvStream :: [TlvRecord] } - deriving stock (Eq, Show, Generic) - -instance NFData TlvStream - --- | Encode a TLV record. -encodeTlvRecord :: TlvRecord -> BS.ByteString -encodeTlvRecord (TlvRecord typ val) = mconcat - [ encodeBigSize typ - , encodeBigSize (fromIntegral (BS.length val)) - , val - ] - --- | Encode a TLV stream. -encodeTlvStream :: TlvStream -> BS.ByteString -encodeTlvStream (TlvStream recs) = mconcat (map encodeTlvRecord recs) - --- | TLV decoding errors. -data TlvError - = TlvNonMinimalEncoding - | TlvNotStrictlyIncreasing - | TlvLengthExceedsBounds - | TlvUnknownEvenType !Word64 - | TlvInvalidKnownType !Word64 - deriving stock (Eq, Show, Generic) - -instance NFData TlvError - --- | Decode a TLV stream without any known-type validation. --- --- This decoder only enforces structural validity: --- - Types must be strictly increasing --- - Lengths must not exceed bounds --- --- All records are returned regardless of type. Note: this does NOT --- enforce the BOLT #1 unknown-even-type rule. Use 'decodeTlvStreamWith' --- with an appropriate predicate for spec-compliant parsing. -decodeTlvStreamRaw :: BS.ByteString -> Either TlvError TlvStream -decodeTlvStreamRaw = go Nothing [] - where - go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString - -> Either TlvError TlvStream - go !_ !acc !bs - | BS.null bs = Right (TlvStream (reverse acc)) - go !mPrevType !acc !bs = do - (typ, rest1) <- maybe (Left TlvNonMinimalEncoding) Right - (decodeBigSize bs) - -- Strictly increasing check - case mPrevType of - Just prevType -> when (typ <= prevType) $ - Left TlvNotStrictlyIncreasing - Nothing -> pure () - (len, rest2) <- maybe (Left TlvNonMinimalEncoding) Right - (decodeBigSize rest1) - -- Length bounds check - when (fromIntegral len > BS.length rest2) $ - Left TlvLengthExceedsBounds - let !val = BS.take (fromIntegral len) rest2 - !rest3 = BS.drop (fromIntegral len) rest2 - !rec = TlvRecord typ val - go (Just typ) (rec : acc) rest3 - --- | Decode a TLV stream with configurable known-type predicate. --- --- Per BOLT #1: --- - Types must be strictly increasing --- - Unknown even types cause failure --- - Unknown odd types are skipped --- --- The predicate determines which types are "known" for the context. -decodeTlvStreamWith - :: (Word64 -> Bool) -- ^ Predicate: is this type known? - -> BS.ByteString - -> Either TlvError TlvStream -decodeTlvStreamWith isKnown = go Nothing [] - where - go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString - -> Either TlvError TlvStream - go !_ !acc !bs - | BS.null bs = Right (TlvStream (reverse acc)) - go !mPrevType !acc !bs = do - (typ, rest1) <- maybe (Left TlvNonMinimalEncoding) Right - (decodeBigSize bs) - -- Strictly increasing check - case mPrevType of - Just prevType -> when (typ <= prevType) $ - Left TlvNotStrictlyIncreasing - Nothing -> pure () - (len, rest2) <- maybe (Left TlvNonMinimalEncoding) Right - (decodeBigSize rest1) - -- Length bounds check - when (fromIntegral len > BS.length rest2) $ - Left TlvLengthExceedsBounds - let !val = BS.take (fromIntegral len) rest2 - !rest3 = BS.drop (fromIntegral len) rest2 - !rec = TlvRecord typ val - -- Unknown type handling: even = fail, odd = skip - if isKnown typ - then go (Just typ) (rec : acc) rest3 - else if even typ - then Left (TlvUnknownEvenType typ) - else go (Just typ) acc rest3 -- skip unknown odd - --- | Decode a TLV stream with BOLT #1 init_tlvs validation. --- --- This uses the default known types for init messages (1 and 3). --- For other contexts, use 'decodeTlvStreamWith' with an appropriate --- predicate. -decodeTlvStream :: BS.ByteString -> Either TlvError TlvStream -decodeTlvStream = decodeTlvStreamWith isInitTlvType - where - isInitTlvType :: Word64 -> Bool - isInitTlvType 1 = True -- networks - isInitTlvType 3 = True -- remote_addr - isInitTlvType _ = False - --- Init TLV types -------------------------------------------------------------- - --- | TLV records for init message. -data InitTlv - = InitNetworks ![BS.ByteString] -- ^ Type 1: chain hashes (32 bytes each) - | InitRemoteAddr !BS.ByteString -- ^ Type 3: remote address - deriving stock (Eq, Show, Generic) - -instance NFData InitTlv - --- | Parse init TLVs from a TLV stream. -parseInitTlvs :: TlvStream -> Either TlvError [InitTlv] -parseInitTlvs (TlvStream recs) = traverse parseOne recs - where - parseOne (TlvRecord 1 val) - | BS.length val `mod` 32 == 0 = - Right (InitNetworks (chunksOf 32 val)) - | otherwise = Left (TlvInvalidKnownType 1) - parseOne (TlvRecord 3 val) = Right (InitRemoteAddr val) - parseOne (TlvRecord t _) = Left (TlvUnknownEvenType t) - --- | Split bytestring into chunks of given size. -chunksOf :: Int -> BS.ByteString -> [BS.ByteString] -chunksOf !n !bs - | BS.null bs = [] - | otherwise = - let (!chunk, !rest) = BS.splitAt n bs - in chunk : chunksOf n rest - --- | Encode init TLVs to a TLV stream. -encodeInitTlvs :: [InitTlv] -> TlvStream -encodeInitTlvs = TlvStream . map toRecord - where - toRecord (InitNetworks chains) = - TlvRecord 1 (mconcat chains) - toRecord (InitRemoteAddr addr) = - TlvRecord 3 addr - --- Message types --------------------------------------------------------------- - --- | BOLT #1 message type codes. -data MsgType - = MsgInit -- ^ 16 - | MsgError -- ^ 17 - | MsgPing -- ^ 18 - | MsgPong -- ^ 19 - | MsgWarning -- ^ 1 - | MsgPeerStorage -- ^ 7 - | MsgPeerStorageRet -- ^ 9 - | MsgUnknown !Word16 -- ^ Unknown type - deriving stock (Eq, Show, Generic) - -instance NFData MsgType - --- | Get the numeric type code for a message type. -msgTypeWord :: MsgType -> Word16 -msgTypeWord MsgInit = 16 -msgTypeWord MsgError = 17 -msgTypeWord MsgPing = 18 -msgTypeWord MsgPong = 19 -msgTypeWord MsgWarning = 1 -msgTypeWord MsgPeerStorage = 7 -msgTypeWord MsgPeerStorageRet = 9 -msgTypeWord (MsgUnknown w) = w - --- | Parse a message type from a word. -parseMsgType :: Word16 -> MsgType -parseMsgType 16 = MsgInit -parseMsgType 17 = MsgError -parseMsgType 18 = MsgPing -parseMsgType 19 = MsgPong -parseMsgType 1 = MsgWarning -parseMsgType 7 = MsgPeerStorage -parseMsgType 9 = MsgPeerStorageRet -parseMsgType w = MsgUnknown w - --- Message ADTs ---------------------------------------------------------------- - --- | The init message (type 16). -data Init = Init - { initGlobalFeatures :: !BS.ByteString - , initFeatures :: !BS.ByteString - , initTlvs :: ![InitTlv] - } deriving stock (Eq, Show, Generic) - -instance NFData Init - --- | The error message (type 17). -data Error = Error - { errorChannelId :: !BS.ByteString -- ^ 32 bytes - , errorData :: !BS.ByteString - } deriving stock (Eq, Show, Generic) - -instance NFData Error - --- | The warning message (type 1). -data Warning = Warning - { warningChannelId :: !BS.ByteString -- ^ 32 bytes - , warningData :: !BS.ByteString - } deriving stock (Eq, Show, Generic) - -instance NFData Warning - --- | The ping message (type 18). -data Ping = Ping - { pingNumPongBytes :: {-# UNPACK #-} !Word16 - , pingIgnored :: !BS.ByteString - } deriving stock (Eq, Show, Generic) - -instance NFData Ping - --- | The pong message (type 19). -data Pong = Pong - { pongIgnored :: !BS.ByteString - } deriving stock (Eq, Show, Generic) - -instance NFData Pong - --- | The peer_storage message (type 7). -data PeerStorage = PeerStorage - { peerStorageBlob :: !BS.ByteString - } deriving stock (Eq, Show, Generic) - -instance NFData PeerStorage - --- | The peer_storage_retrieval message (type 9). -data PeerStorageRetrieval = PeerStorageRetrieval - { peerStorageRetrievalBlob :: !BS.ByteString - } deriving stock (Eq, Show, Generic) - -instance NFData PeerStorageRetrieval - --- | All BOLT #1 messages. -data Message - = MsgInitVal !Init - | MsgErrorVal !Error - | MsgWarningVal !Warning - | MsgPingVal !Ping - | MsgPongVal !Pong - | MsgPeerStorageVal !PeerStorage - | MsgPeerStorageRetrievalVal !PeerStorageRetrieval - deriving stock (Eq, Show, Generic) - -instance NFData Message - --- Message envelope ------------------------------------------------------------ - --- | A complete message envelope with type, payload, and optional extension. -data Envelope = Envelope - { envType :: !MsgType - , envPayload :: !BS.ByteString - , envExtension :: !(Maybe TlvStream) - } deriving stock (Eq, Show, Generic) - -instance NFData Envelope - --- Message encoding ------------------------------------------------------------ - --- | Encoding errors. -data EncodeError - = EncodeLengthOverflow -- ^ Payload exceeds u16 max (65535 bytes) - deriving stock (Eq, Show, Generic) - -instance NFData EncodeError - --- | Encode an Init message payload. -encodeInit :: Init -> Either EncodeError BS.ByteString -encodeInit (Init gf feat tlvs) = do - gfLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength gf) - featLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength feat) - Right $ mconcat - [ gfLen - , gf - , featLen - , feat - , encodeTlvStream (encodeInitTlvs tlvs) - ] - --- | Encode an Error message payload. -encodeError :: Error -> Either EncodeError BS.ByteString -encodeError (Error cid dat) = do - datLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength dat) - Right $ mconcat [cid, datLen, dat] - --- | Encode a Warning message payload. -encodeWarning :: Warning -> Either EncodeError BS.ByteString -encodeWarning (Warning cid dat) = do - datLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength dat) - Right $ mconcat [cid, datLen, dat] - --- | Encode a Ping message payload. -encodePing :: Ping -> Either EncodeError BS.ByteString -encodePing (Ping numPong ignored) = do - ignoredLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength ignored) - Right $ mconcat [encodeU16 numPong, ignoredLen, ignored] - --- | Encode a Pong message payload. -encodePong :: Pong -> Either EncodeError BS.ByteString -encodePong (Pong ignored) = do - ignoredLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength ignored) - Right $ mconcat [ignoredLen, ignored] - --- | Encode a PeerStorage message payload. -encodePeerStorage :: PeerStorage -> Either EncodeError BS.ByteString -encodePeerStorage (PeerStorage blob) = do - blobLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength blob) - Right $ mconcat [blobLen, blob] - --- | Encode a PeerStorageRetrieval message payload. -encodePeerStorageRetrieval - :: PeerStorageRetrieval -> Either EncodeError BS.ByteString -encodePeerStorageRetrieval (PeerStorageRetrieval blob) = do - blobLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength blob) - Right $ mconcat [blobLen, blob] - --- | Encode a message to its payload bytes. -encodeMessage :: Message -> Either EncodeError BS.ByteString -encodeMessage = \case - MsgInitVal m -> encodeInit m - MsgErrorVal m -> encodeError m - MsgWarningVal m -> encodeWarning m - MsgPingVal m -> encodePing m - MsgPongVal m -> encodePong m - MsgPeerStorageVal m -> encodePeerStorage m - MsgPeerStorageRetrievalVal m -> encodePeerStorageRetrieval m - --- | Get the message type for a message. -messageType :: Message -> MsgType -messageType = \case - MsgInitVal _ -> MsgInit - MsgErrorVal _ -> MsgError - MsgWarningVal _ -> MsgWarning - MsgPingVal _ -> MsgPing - MsgPongVal _ -> MsgPong - MsgPeerStorageVal _ -> MsgPeerStorage - MsgPeerStorageRetrievalVal _ -> MsgPeerStorageRet - --- | Encode a message as a complete envelope (type + payload + extension). -encodeEnvelope :: Message -> Maybe TlvStream -> Either EncodeError BS.ByteString -encodeEnvelope msg mext = do - payload <- encodeMessage msg - Right $ mconcat $ - [ encodeU16 (msgTypeWord (messageType msg)) - , payload - ] ++ maybe [] (\ext -> [encodeTlvStream ext]) mext - --- Message decoding ------------------------------------------------------------ - --- | Decoding errors. -data DecodeError - = DecodeInsufficientBytes - | DecodeInvalidLength - | DecodeUnknownEvenType !Word16 - | DecodeUnknownOddType !Word16 - | DecodeTlvError !TlvError - | DecodeInvalidChannelId - | DecodeInvalidExtension !TlvError - deriving stock (Eq, Show, Generic) - -instance NFData DecodeError - --- | Decode an Init message from payload bytes. --- --- Returns the decoded message and any remaining bytes. -decodeInit :: BS.ByteString -> Either DecodeError (Init, BS.ByteString) -decodeInit !bs = do - (gfLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right - (decodeU16 bs) - unless (BS.length rest1 >= fromIntegral gfLen) $ - Left DecodeInsufficientBytes - let !gf = BS.take (fromIntegral gfLen) rest1 - !rest2 = BS.drop (fromIntegral gfLen) rest1 - (fLen, rest3) <- maybe (Left DecodeInsufficientBytes) Right - (decodeU16 rest2) - unless (BS.length rest3 >= fromIntegral fLen) $ - Left DecodeInsufficientBytes - let !feat = BS.take (fromIntegral fLen) rest3 - !rest4 = BS.drop (fromIntegral fLen) rest3 - -- Parse optional TLV stream (consumes all remaining bytes for init) - tlvStream <- if BS.null rest4 - then Right (TlvStream []) - else either (Left . DecodeTlvError) Right (decodeTlvStream rest4) - initTlvList <- either (Left . DecodeTlvError) Right - (parseInitTlvs tlvStream) - -- Init consumes all bytes (TLVs are part of init, not extensions) - Right (Init gf feat initTlvList, BS.empty) - --- | Decode an Error message from payload bytes. -decodeError :: BS.ByteString -> Either DecodeError (Error, BS.ByteString) -decodeError !bs = do - unless (BS.length bs >= 32) $ Left DecodeInsufficientBytes - let !cid = BS.take 32 bs - !rest1 = BS.drop 32 bs - (dLen, rest2) <- maybe (Left DecodeInsufficientBytes) Right - (decodeU16 rest1) - unless (BS.length rest2 >= fromIntegral dLen) $ - Left DecodeInsufficientBytes - let !dat = BS.take (fromIntegral dLen) rest2 - !rest3 = BS.drop (fromIntegral dLen) rest2 - Right (Error cid dat, rest3) - --- | Decode a Warning message from payload bytes. -decodeWarning :: BS.ByteString -> Either DecodeError (Warning, BS.ByteString) -decodeWarning !bs = do - unless (BS.length bs >= 32) $ Left DecodeInsufficientBytes - let !cid = BS.take 32 bs - !rest1 = BS.drop 32 bs - (dLen, rest2) <- maybe (Left DecodeInsufficientBytes) Right - (decodeU16 rest1) - unless (BS.length rest2 >= fromIntegral dLen) $ - Left DecodeInsufficientBytes - let !dat = BS.take (fromIntegral dLen) rest2 - !rest3 = BS.drop (fromIntegral dLen) rest2 - Right (Warning cid dat, rest3) - --- | Decode a Ping message from payload bytes. -decodePing :: BS.ByteString -> Either DecodeError (Ping, BS.ByteString) -decodePing !bs = do - (numPong, rest1) <- maybe (Left DecodeInsufficientBytes) Right - (decodeU16 bs) - (bLen, rest2) <- maybe (Left DecodeInsufficientBytes) Right - (decodeU16 rest1) - unless (BS.length rest2 >= fromIntegral bLen) $ - Left DecodeInsufficientBytes - let !ignored = BS.take (fromIntegral bLen) rest2 - !rest3 = BS.drop (fromIntegral bLen) rest2 - Right (Ping numPong ignored, rest3) - --- | Decode a Pong message from payload bytes. -decodePong :: BS.ByteString -> Either DecodeError (Pong, BS.ByteString) -decodePong !bs = do - (bLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right - (decodeU16 bs) - unless (BS.length rest1 >= fromIntegral bLen) $ - Left DecodeInsufficientBytes - let !ignored = BS.take (fromIntegral bLen) rest1 - !rest2 = BS.drop (fromIntegral bLen) rest1 - Right (Pong ignored, rest2) - --- | Decode a PeerStorage message from payload bytes. -decodePeerStorage - :: BS.ByteString -> Either DecodeError (PeerStorage, BS.ByteString) -decodePeerStorage !bs = do - (bLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right - (decodeU16 bs) - unless (BS.length rest1 >= fromIntegral bLen) $ - Left DecodeInsufficientBytes - let !blob = BS.take (fromIntegral bLen) rest1 - !rest2 = BS.drop (fromIntegral bLen) rest1 - Right (PeerStorage blob, rest2) - --- | Decode a PeerStorageRetrieval message from payload bytes. -decodePeerStorageRetrieval - :: BS.ByteString - -> Either DecodeError (PeerStorageRetrieval, BS.ByteString) -decodePeerStorageRetrieval !bs = do - (bLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right - (decodeU16 bs) - unless (BS.length rest1 >= fromIntegral bLen) $ - Left DecodeInsufficientBytes - let !blob = BS.take (fromIntegral bLen) rest1 - !rest2 = BS.drop (fromIntegral bLen) rest1 - Right (PeerStorageRetrieval blob, rest2) - --- | Decode a message from its type and payload. --- --- Returns the decoded message and any remaining bytes (for extensions). --- For unknown types, returns an appropriate error. -decodeMessage - :: MsgType -> BS.ByteString -> Either DecodeError (Message, BS.ByteString) -decodeMessage MsgInit bs = do - (m, rest) <- decodeInit bs - Right (MsgInitVal m, rest) -decodeMessage MsgError bs = do - (m, rest) <- decodeError bs - Right (MsgErrorVal m, rest) -decodeMessage MsgWarning bs = do - (m, rest) <- decodeWarning bs - Right (MsgWarningVal m, rest) -decodeMessage MsgPing bs = do - (m, rest) <- decodePing bs - Right (MsgPingVal m, rest) -decodeMessage MsgPong bs = do - (m, rest) <- decodePong bs - Right (MsgPongVal m, rest) -decodeMessage MsgPeerStorage bs = do - (m, rest) <- decodePeerStorage bs - Right (MsgPeerStorageVal m, rest) -decodeMessage MsgPeerStorageRet bs = do - (m, rest) <- decodePeerStorageRetrieval bs - Right (MsgPeerStorageRetrievalVal m, rest) -decodeMessage (MsgUnknown w) _ - | even w = Left (DecodeUnknownEvenType w) - | otherwise = Left (DecodeUnknownOddType w) - --- | Decode a complete envelope (type + payload + optional extension). --- --- Per BOLT #1: --- - Unknown odd message types are ignored (returns Nothing for message) --- - Unknown even message types cause connection close (returns error) --- - Invalid extension TLV causes connection close (returns error) --- --- Returns the decoded message (if known) and any extension TLVs. -decodeEnvelope - :: BS.ByteString - -> Either DecodeError (Maybe Message, Maybe TlvStream) -decodeEnvelope !bs = do - (typeWord, rest1) <- maybe (Left DecodeInsufficientBytes) Right - (decodeU16 bs) - let !msgType = parseMsgType typeWord - case msgType of - MsgUnknown w - | even w -> Left (DecodeUnknownEvenType w) - | otherwise -> Right (Nothing, Nothing) -- Ignore unknown odd types - _ -> do - (msg, rest2) <- decodeMessage msgType rest1 - -- Parse any remaining bytes as extension TLV - -- Per BOLT #1: unknown even types must fail, unknown odd are ignored - ext <- if BS.null rest2 - then Right Nothing - else case decodeTlvStreamWith (const False) rest2 of - Left e -> Left (DecodeInvalidExtension e) - Right s -> Right (Just s) - Right (Just msg, ext) +-- Re-export from sub-modules +import Lightning.Protocol.BOLT1.Prim +import Lightning.Protocol.BOLT1.TLV +import Lightning.Protocol.BOLT1.Message +import Lightning.Protocol.BOLT1.Codec diff --git a/lib/Lightning/Protocol/BOLT1/Codec.hs b/lib/Lightning/Protocol/BOLT1/Codec.hs @@ -0,0 +1,320 @@ +{-# OPTIONS_HADDOCK prune #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE LambdaCase #-} + +-- | +-- Module: Lightning.Protocol.BOLT1.Codec +-- Copyright: (c) 2025 Jared Tobin +-- License: MIT +-- Maintainer: Jared Tobin <jared@ppad.tech> +-- +-- Message encoding and decoding for BOLT #1. + +module Lightning.Protocol.BOLT1.Codec ( + -- * Encoding errors + EncodeError(..) + + -- * Message encoding + , encodeInit + , encodeError + , encodeWarning + , encodePing + , encodePong + , encodePeerStorage + , encodePeerStorageRetrieval + , encodeMessage + , encodeEnvelope + + -- * Decoding errors + , DecodeError(..) + + -- * Message decoding + , decodeInit + , decodeError + , decodeWarning + , decodePing + , decodePong + , decodePeerStorage + , decodePeerStorageRetrieval + , decodeMessage + , decodeEnvelope + ) where + +import Control.DeepSeq (NFData) +import Control.Monad (when, unless) +import qualified Data.ByteString as BS +import Data.Word (Word16) +import GHC.Generics (Generic) +import Lightning.Protocol.BOLT1.Prim +import Lightning.Protocol.BOLT1.TLV +import Lightning.Protocol.BOLT1.Message + +-- Encoding errors ------------------------------------------------------------- + +-- | Encoding errors. +data EncodeError + = EncodeLengthOverflow -- ^ Field length exceeds u16 max (65535 bytes) + | EncodeMessageTooLarge -- ^ Total message size exceeds 65535 bytes + deriving stock (Eq, Show, Generic) + +instance NFData EncodeError + +-- Message encoding ------------------------------------------------------------ + +-- | Encode an Init message payload. +encodeInit :: Init -> Either EncodeError BS.ByteString +encodeInit (Init gf feat tlvs) = do + gfLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength gf) + featLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength feat) + Right $ mconcat + [ gfLen + , gf + , featLen + , feat + , encodeTlvStream (encodeInitTlvs tlvs) + ] + +-- | Encode an Error message payload. +encodeError :: Error -> Either EncodeError BS.ByteString +encodeError (Error cid dat) = do + datLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength dat) + Right $ mconcat [cid, datLen, dat] + +-- | Encode a Warning message payload. +encodeWarning :: Warning -> Either EncodeError BS.ByteString +encodeWarning (Warning cid dat) = do + datLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength dat) + Right $ mconcat [cid, datLen, dat] + +-- | Encode a Ping message payload. +encodePing :: Ping -> Either EncodeError BS.ByteString +encodePing (Ping numPong ignored) = do + ignoredLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength ignored) + Right $ mconcat [encodeU16 numPong, ignoredLen, ignored] + +-- | Encode a Pong message payload. +encodePong :: Pong -> Either EncodeError BS.ByteString +encodePong (Pong ignored) = do + ignoredLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength ignored) + Right $ mconcat [ignoredLen, ignored] + +-- | Encode a PeerStorage message payload. +encodePeerStorage :: PeerStorage -> Either EncodeError BS.ByteString +encodePeerStorage (PeerStorage blob) = do + blobLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength blob) + Right $ mconcat [blobLen, blob] + +-- | Encode a PeerStorageRetrieval message payload. +encodePeerStorageRetrieval + :: PeerStorageRetrieval -> Either EncodeError BS.ByteString +encodePeerStorageRetrieval (PeerStorageRetrieval blob) = do + blobLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength blob) + Right $ mconcat [blobLen, blob] + +-- | Encode a message to its payload bytes. +encodeMessage :: Message -> Either EncodeError BS.ByteString +encodeMessage = \case + MsgInitVal m -> encodeInit m + MsgErrorVal m -> encodeError m + MsgWarningVal m -> encodeWarning m + MsgPingVal m -> encodePing m + MsgPongVal m -> encodePong m + MsgPeerStorageVal m -> encodePeerStorage m + MsgPeerStorageRetrievalVal m -> encodePeerStorageRetrieval m + +-- | Encode a message as a complete envelope (type + payload + extension). +-- +-- Per BOLT #1, the total message size must not exceed 65535 bytes. +encodeEnvelope :: Message -> Maybe TlvStream -> Either EncodeError BS.ByteString +encodeEnvelope msg mext = do + payload <- encodeMessage msg + let !typeBytes = encodeU16 (msgTypeWord (messageType msg)) + !extBytes = maybe BS.empty encodeTlvStream mext + !result = mconcat [typeBytes, payload, extBytes] + -- Per BOLT #1: message size must fit in 2 bytes (max 65535) + when (BS.length result > 65535) $ + Left EncodeMessageTooLarge + Right result + +-- Decoding errors ------------------------------------------------------------- + +-- | Decoding errors. +data DecodeError + = DecodeInsufficientBytes + | DecodeInvalidLength + | DecodeUnknownEvenType !Word16 + | DecodeUnknownOddType !Word16 + | DecodeTlvError !TlvError + | DecodeInvalidChannelId + | DecodeInvalidExtension !TlvError + deriving stock (Eq, Show, Generic) + +instance NFData DecodeError + +-- Message decoding ------------------------------------------------------------ + +-- | Decode an Init message from payload bytes. +-- +-- Returns the decoded message and any remaining bytes. +decodeInit :: BS.ByteString -> Either DecodeError (Init, BS.ByteString) +decodeInit !bs = do + (gfLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right + (decodeU16 bs) + unless (BS.length rest1 >= fromIntegral gfLen) $ + Left DecodeInsufficientBytes + let !gf = BS.take (fromIntegral gfLen) rest1 + !rest2 = BS.drop (fromIntegral gfLen) rest1 + (fLen, rest3) <- maybe (Left DecodeInsufficientBytes) Right + (decodeU16 rest2) + unless (BS.length rest3 >= fromIntegral fLen) $ + Left DecodeInsufficientBytes + let !feat = BS.take (fromIntegral fLen) rest3 + !rest4 = BS.drop (fromIntegral fLen) rest3 + -- Parse optional TLV stream (consumes all remaining bytes for init) + tlvStream <- if BS.null rest4 + then Right (TlvStream []) + else either (Left . DecodeTlvError) Right (decodeTlvStream rest4) + initTlvList <- either (Left . DecodeTlvError) Right + (parseInitTlvs tlvStream) + -- Init consumes all bytes (TLVs are part of init, not extensions) + Right (Init gf feat initTlvList, BS.empty) + +-- | Decode an Error message from payload bytes. +decodeError :: BS.ByteString -> Either DecodeError (Error, BS.ByteString) +decodeError !bs = do + unless (BS.length bs >= 32) $ Left DecodeInsufficientBytes + let !cid = BS.take 32 bs + !rest1 = BS.drop 32 bs + (dLen, rest2) <- maybe (Left DecodeInsufficientBytes) Right + (decodeU16 rest1) + unless (BS.length rest2 >= fromIntegral dLen) $ + Left DecodeInsufficientBytes + let !dat = BS.take (fromIntegral dLen) rest2 + !rest3 = BS.drop (fromIntegral dLen) rest2 + Right (Error cid dat, rest3) + +-- | Decode a Warning message from payload bytes. +decodeWarning :: BS.ByteString -> Either DecodeError (Warning, BS.ByteString) +decodeWarning !bs = do + unless (BS.length bs >= 32) $ Left DecodeInsufficientBytes + let !cid = BS.take 32 bs + !rest1 = BS.drop 32 bs + (dLen, rest2) <- maybe (Left DecodeInsufficientBytes) Right + (decodeU16 rest1) + unless (BS.length rest2 >= fromIntegral dLen) $ + Left DecodeInsufficientBytes + let !dat = BS.take (fromIntegral dLen) rest2 + !rest3 = BS.drop (fromIntegral dLen) rest2 + Right (Warning cid dat, rest3) + +-- | Decode a Ping message from payload bytes. +decodePing :: BS.ByteString -> Either DecodeError (Ping, BS.ByteString) +decodePing !bs = do + (numPong, rest1) <- maybe (Left DecodeInsufficientBytes) Right + (decodeU16 bs) + (bLen, rest2) <- maybe (Left DecodeInsufficientBytes) Right + (decodeU16 rest1) + unless (BS.length rest2 >= fromIntegral bLen) $ + Left DecodeInsufficientBytes + let !ignored = BS.take (fromIntegral bLen) rest2 + !rest3 = BS.drop (fromIntegral bLen) rest2 + Right (Ping numPong ignored, rest3) + +-- | Decode a Pong message from payload bytes. +decodePong :: BS.ByteString -> Either DecodeError (Pong, BS.ByteString) +decodePong !bs = do + (bLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right + (decodeU16 bs) + unless (BS.length rest1 >= fromIntegral bLen) $ + Left DecodeInsufficientBytes + let !ignored = BS.take (fromIntegral bLen) rest1 + !rest2 = BS.drop (fromIntegral bLen) rest1 + Right (Pong ignored, rest2) + +-- | Decode a PeerStorage message from payload bytes. +decodePeerStorage + :: BS.ByteString -> Either DecodeError (PeerStorage, BS.ByteString) +decodePeerStorage !bs = do + (bLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right + (decodeU16 bs) + unless (BS.length rest1 >= fromIntegral bLen) $ + Left DecodeInsufficientBytes + let !blob = BS.take (fromIntegral bLen) rest1 + !rest2 = BS.drop (fromIntegral bLen) rest1 + Right (PeerStorage blob, rest2) + +-- | Decode a PeerStorageRetrieval message from payload bytes. +decodePeerStorageRetrieval + :: BS.ByteString + -> Either DecodeError (PeerStorageRetrieval, BS.ByteString) +decodePeerStorageRetrieval !bs = do + (bLen, rest1) <- maybe (Left DecodeInsufficientBytes) Right + (decodeU16 bs) + unless (BS.length rest1 >= fromIntegral bLen) $ + Left DecodeInsufficientBytes + let !blob = BS.take (fromIntegral bLen) rest1 + !rest2 = BS.drop (fromIntegral bLen) rest1 + Right (PeerStorageRetrieval blob, rest2) + +-- | Decode a message from its type and payload. +-- +-- Returns the decoded message and any remaining bytes (for extensions). +-- For unknown types, returns an appropriate error. +decodeMessage + :: MsgType -> BS.ByteString -> Either DecodeError (Message, BS.ByteString) +decodeMessage MsgInit bs = do + (m, rest) <- decodeInit bs + Right (MsgInitVal m, rest) +decodeMessage MsgError bs = do + (m, rest) <- decodeError bs + Right (MsgErrorVal m, rest) +decodeMessage MsgWarning bs = do + (m, rest) <- decodeWarning bs + Right (MsgWarningVal m, rest) +decodeMessage MsgPing bs = do + (m, rest) <- decodePing bs + Right (MsgPingVal m, rest) +decodeMessage MsgPong bs = do + (m, rest) <- decodePong bs + Right (MsgPongVal m, rest) +decodeMessage MsgPeerStorage bs = do + (m, rest) <- decodePeerStorage bs + Right (MsgPeerStorageVal m, rest) +decodeMessage MsgPeerStorageRet bs = do + (m, rest) <- decodePeerStorageRetrieval bs + Right (MsgPeerStorageRetrievalVal m, rest) +decodeMessage (MsgUnknown w) _ + | even w = Left (DecodeUnknownEvenType w) + | otherwise = Left (DecodeUnknownOddType w) + +-- | Decode a complete envelope (type + payload + optional extension). +-- +-- Per BOLT #1: +-- - Unknown odd message types are ignored (returns Nothing for message) +-- - Unknown even message types cause connection close (returns error) +-- - Invalid extension TLV causes connection close (returns error) +-- +-- Returns the decoded message (if known) and any extension TLVs. +decodeEnvelope + :: BS.ByteString + -> Either DecodeError (Maybe Message, Maybe TlvStream) +decodeEnvelope !bs = do + (typeWord, rest1) <- maybe (Left DecodeInsufficientBytes) Right + (decodeU16 bs) + let !msgType = parseMsgType typeWord + case msgType of + MsgUnknown w + | even w -> Left (DecodeUnknownEvenType w) + | otherwise -> Right (Nothing, Nothing) -- Ignore unknown odd types + _ -> do + (msg, rest2) <- decodeMessage msgType rest1 + -- Parse any remaining bytes as extension TLV + -- Per BOLT #1: unknown even types must fail, unknown odd are ignored + ext <- if BS.null rest2 + then Right Nothing + else case decodeTlvStreamWith (const False) rest2 of + Left e -> Left (DecodeInvalidExtension e) + Right s -> Right (Just s) + Right (Just msg, ext) diff --git a/lib/Lightning/Protocol/BOLT1/Message.hs b/lib/Lightning/Protocol/BOLT1/Message.hs @@ -0,0 +1,170 @@ +{-# OPTIONS_HADDOCK prune #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} + +-- | +-- Module: Lightning.Protocol.BOLT1.Message +-- Copyright: (c) 2025 Jared Tobin +-- License: MIT +-- Maintainer: Jared Tobin <jared@ppad.tech> +-- +-- Message types for BOLT #1. + +module Lightning.Protocol.BOLT1.Message ( + -- * Message types + MsgType(..) + , msgTypeWord + , parseMsgType + + -- * Setup messages + , Init(..) + , Error(..) + , Warning(..) + + -- * Control messages + , Ping(..) + , Pong(..) + + -- * Peer storage messages + , PeerStorage(..) + , PeerStorageRetrieval(..) + + -- * Message envelope + , Message(..) + , messageType + , Envelope(..) + ) where + +import Control.DeepSeq (NFData) +import qualified Data.ByteString as BS +import Data.Word (Word16) +import GHC.Generics (Generic) +import Lightning.Protocol.BOLT1.TLV + +-- Message types --------------------------------------------------------------- + +-- | BOLT #1 message type codes. +data MsgType + = MsgInit -- ^ 16 + | MsgError -- ^ 17 + | MsgPing -- ^ 18 + | MsgPong -- ^ 19 + | MsgWarning -- ^ 1 + | MsgPeerStorage -- ^ 7 + | MsgPeerStorageRet -- ^ 9 + | MsgUnknown !Word16 -- ^ Unknown type + deriving stock (Eq, Show, Generic) + +instance NFData MsgType + +-- | Get the numeric type code for a message type. +msgTypeWord :: MsgType -> Word16 +msgTypeWord MsgInit = 16 +msgTypeWord MsgError = 17 +msgTypeWord MsgPing = 18 +msgTypeWord MsgPong = 19 +msgTypeWord MsgWarning = 1 +msgTypeWord MsgPeerStorage = 7 +msgTypeWord MsgPeerStorageRet = 9 +msgTypeWord (MsgUnknown w) = w + +-- | Parse a message type from a word. +parseMsgType :: Word16 -> MsgType +parseMsgType 16 = MsgInit +parseMsgType 17 = MsgError +parseMsgType 18 = MsgPing +parseMsgType 19 = MsgPong +parseMsgType 1 = MsgWarning +parseMsgType 7 = MsgPeerStorage +parseMsgType 9 = MsgPeerStorageRet +parseMsgType w = MsgUnknown w + +-- Message ADTs ---------------------------------------------------------------- + +-- | The init message (type 16). +data Init = Init + { initGlobalFeatures :: !BS.ByteString + , initFeatures :: !BS.ByteString + , initTlvs :: ![InitTlv] + } deriving stock (Eq, Show, Generic) + +instance NFData Init + +-- | The error message (type 17). +data Error = Error + { errorChannelId :: !BS.ByteString -- ^ 32 bytes + , errorData :: !BS.ByteString + } deriving stock (Eq, Show, Generic) + +instance NFData Error + +-- | The warning message (type 1). +data Warning = Warning + { warningChannelId :: !BS.ByteString -- ^ 32 bytes + , warningData :: !BS.ByteString + } deriving stock (Eq, Show, Generic) + +instance NFData Warning + +-- | The ping message (type 18). +data Ping = Ping + { pingNumPongBytes :: {-# UNPACK #-} !Word16 + , pingIgnored :: !BS.ByteString + } deriving stock (Eq, Show, Generic) + +instance NFData Ping + +-- | The pong message (type 19). +data Pong = Pong + { pongIgnored :: !BS.ByteString + } deriving stock (Eq, Show, Generic) + +instance NFData Pong + +-- | The peer_storage message (type 7). +data PeerStorage = PeerStorage + { peerStorageBlob :: !BS.ByteString + } deriving stock (Eq, Show, Generic) + +instance NFData PeerStorage + +-- | The peer_storage_retrieval message (type 9). +data PeerStorageRetrieval = PeerStorageRetrieval + { peerStorageRetrievalBlob :: !BS.ByteString + } deriving stock (Eq, Show, Generic) + +instance NFData PeerStorageRetrieval + +-- | All BOLT #1 messages. +data Message + = MsgInitVal !Init + | MsgErrorVal !Error + | MsgWarningVal !Warning + | MsgPingVal !Ping + | MsgPongVal !Pong + | MsgPeerStorageVal !PeerStorage + | MsgPeerStorageRetrievalVal !PeerStorageRetrieval + deriving stock (Eq, Show, Generic) + +instance NFData Message + +-- | Get the message type for a message. +messageType :: Message -> MsgType +messageType (MsgInitVal _) = MsgInit +messageType (MsgErrorVal _) = MsgError +messageType (MsgWarningVal _) = MsgWarning +messageType (MsgPingVal _) = MsgPing +messageType (MsgPongVal _) = MsgPong +messageType (MsgPeerStorageVal _) = MsgPeerStorage +messageType (MsgPeerStorageRetrievalVal _) = MsgPeerStorageRet + +-- Message envelope ------------------------------------------------------------ + +-- | A complete message envelope with type, payload, and optional extension. +data Envelope = Envelope + { envType :: !MsgType + , envPayload :: !BS.ByteString + , envExtension :: !(Maybe TlvStream) + } deriving stock (Eq, Show, Generic) + +instance NFData Envelope diff --git a/lib/Lightning/Protocol/BOLT1/Prim.hs b/lib/Lightning/Protocol/BOLT1/Prim.hs @@ -0,0 +1,496 @@ +{-# OPTIONS_HADDOCK prune #-} +{-# LANGUAGE BangPatterns #-} + +-- | +-- Module: Lightning.Protocol.BOLT1.Prim +-- Copyright: (c) 2025 Jared Tobin +-- License: MIT +-- Maintainer: Jared Tobin <jared@ppad.tech> +-- +-- Primitive type encoding and decoding for BOLT #1. + +module Lightning.Protocol.BOLT1.Prim ( + -- * Unsigned integer encoding + encodeU16 + , encodeU32 + , encodeU64 + + -- * Signed integer encoding + , encodeS8 + , encodeS16 + , encodeS32 + , encodeS64 + + -- * Truncated unsigned integer encoding + , encodeTu16 + , encodeTu32 + , encodeTu64 + + -- * Minimal signed integer encoding + , encodeMinSigned + + -- * BigSize encoding + , encodeBigSize + + -- * Unsigned integer decoding + , decodeU16 + , decodeU32 + , decodeU64 + + -- * Signed integer decoding + , decodeS8 + , decodeS16 + , decodeS32 + , decodeS64 + + -- * Truncated unsigned integer decoding + , decodeTu16 + , decodeTu32 + , decodeTu64 + + -- * Minimal signed integer decoding + , decodeMinSigned + + -- * BigSize decoding + , decodeBigSize + + -- * Internal helpers + , encodeLength + ) where + +import Data.Bits (unsafeShiftL, unsafeShiftR, (.|.)) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Builder as BSB +import qualified Data.ByteString.Lazy as BSL +import Data.Int (Int8, Int16, Int32, Int64) +import Data.Word (Word16, Word32, Word64) + +-- Unsigned integer encoding --------------------------------------------------- + +-- | Encode a 16-bit unsigned integer (big-endian). +-- +-- >>> encodeU16 0x0102 +-- "\SOH\STX" +encodeU16 :: Word16 -> BS.ByteString +encodeU16 = BSL.toStrict . BSB.toLazyByteString . BSB.word16BE +{-# INLINE encodeU16 #-} + +-- | Encode a 32-bit unsigned integer (big-endian). +-- +-- >>> encodeU32 0x01020304 +-- "\SOH\STX\ETX\EOT" +encodeU32 :: Word32 -> BS.ByteString +encodeU32 = BSL.toStrict . BSB.toLazyByteString . BSB.word32BE +{-# INLINE encodeU32 #-} + +-- | Encode a 64-bit unsigned integer (big-endian). +-- +-- >>> encodeU64 0x0102030405060708 +-- "\SOH\STX\ETX\EOT\ENQ\ACK\a\b" +encodeU64 :: Word64 -> BS.ByteString +encodeU64 = BSL.toStrict . BSB.toLazyByteString . BSB.word64BE +{-# INLINE encodeU64 #-} + +-- Signed integer encoding ----------------------------------------------------- + +-- | Encode an 8-bit signed integer. +-- +-- >>> encodeS8 42 +-- "*" +-- >>> encodeS8 (-42) +-- "\214" +encodeS8 :: Int8 -> BS.ByteString +encodeS8 = BS.singleton . fromIntegral +{-# INLINE encodeS8 #-} + +-- | Encode a 16-bit signed integer (big-endian two's complement). +-- +-- >>> encodeS16 0x0102 +-- "\SOH\STX" +-- >>> encodeS16 (-1) +-- "\255\255" +encodeS16 :: Int16 -> BS.ByteString +encodeS16 = BSL.toStrict . BSB.toLazyByteString . BSB.int16BE +{-# INLINE encodeS16 #-} + +-- | Encode a 32-bit signed integer (big-endian two's complement). +-- +-- >>> encodeS32 0x01020304 +-- "\SOH\STX\ETX\EOT" +-- >>> encodeS32 (-1) +-- "\255\255\255\255" +encodeS32 :: Int32 -> BS.ByteString +encodeS32 = BSL.toStrict . BSB.toLazyByteString . BSB.int32BE +{-# INLINE encodeS32 #-} + +-- | Encode a 64-bit signed integer (big-endian two's complement). +-- +-- >>> encodeS64 0x0102030405060708 +-- "\SOH\STX\ETX\EOT\ENQ\ACK\a\b" +-- >>> encodeS64 (-1) +-- "\255\255\255\255\255\255\255\255" +encodeS64 :: Int64 -> BS.ByteString +encodeS64 = BSL.toStrict . BSB.toLazyByteString . BSB.int64BE +{-# INLINE encodeS64 #-} + +-- Truncated unsigned integer encoding ----------------------------------------- + +-- | Encode a truncated 16-bit unsigned integer (0-2 bytes). +-- +-- Leading zeros are omitted per BOLT #1. Zero encodes to empty. +-- +-- >>> encodeTu16 0 +-- "" +-- >>> encodeTu16 1 +-- "\SOH" +-- >>> encodeTu16 256 +-- "\SOH\NUL" +encodeTu16 :: Word16 -> BS.ByteString +encodeTu16 0 = BS.empty +encodeTu16 !x + | x < 0x100 = BS.singleton (fromIntegral x) + | otherwise = encodeU16 x +{-# INLINE encodeTu16 #-} + +-- | Encode a truncated 32-bit unsigned integer (0-4 bytes). +-- +-- Leading zeros are omitted per BOLT #1. Zero encodes to empty. +-- +-- >>> encodeTu32 0 +-- "" +-- >>> encodeTu32 1 +-- "\SOH" +-- >>> encodeTu32 0x010000 +-- "\SOH\NUL\NUL" +encodeTu32 :: Word32 -> BS.ByteString +encodeTu32 0 = BS.empty +encodeTu32 !x + | x < 0x100 = BS.singleton (fromIntegral x) + | x < 0x10000 = encodeU16 (fromIntegral x) + | x < 0x1000000 = BS.pack [ fromIntegral (x `unsafeShiftR` 16) + , fromIntegral (x `unsafeShiftR` 8) + , fromIntegral x + ] + | otherwise = encodeU32 x +{-# INLINE encodeTu32 #-} + +-- | Encode a truncated 64-bit unsigned integer (0-8 bytes). +-- +-- Leading zeros are omitted per BOLT #1. Zero encodes to empty. +-- +-- >>> encodeTu64 0 +-- "" +-- >>> encodeTu64 1 +-- "\SOH" +-- >>> encodeTu64 0x0100000000 +-- "\SOH\NUL\NUL\NUL\NUL" +encodeTu64 :: Word64 -> BS.ByteString +encodeTu64 0 = BS.empty +encodeTu64 !x + | x < 0x100 = BS.singleton (fromIntegral x) + | x < 0x10000 = encodeU16 (fromIntegral x) + | x < 0x1000000 = BS.pack [ fromIntegral (x `unsafeShiftR` 16) + , fromIntegral (x `unsafeShiftR` 8) + , fromIntegral x + ] + | x < 0x100000000 = encodeU32 (fromIntegral x) + | x < 0x10000000000 = BS.pack [ fromIntegral (x `unsafeShiftR` 32) + , fromIntegral (x `unsafeShiftR` 24) + , fromIntegral (x `unsafeShiftR` 16) + , fromIntegral (x `unsafeShiftR` 8) + , fromIntegral x + ] + | x < 0x1000000000000 = BS.pack [ fromIntegral (x `unsafeShiftR` 40) + , fromIntegral (x `unsafeShiftR` 32) + , fromIntegral (x `unsafeShiftR` 24) + , fromIntegral (x `unsafeShiftR` 16) + , fromIntegral (x `unsafeShiftR` 8) + , fromIntegral x + ] + | x < 0x100000000000000 = BS.pack [ fromIntegral (x `unsafeShiftR` 48) + , fromIntegral (x `unsafeShiftR` 40) + , fromIntegral (x `unsafeShiftR` 32) + , fromIntegral (x `unsafeShiftR` 24) + , fromIntegral (x `unsafeShiftR` 16) + , fromIntegral (x `unsafeShiftR` 8) + , fromIntegral x + ] + | otherwise = encodeU64 x +{-# INLINE encodeTu64 #-} + +-- Minimal signed integer encoding --------------------------------------------- + +-- | Encode a signed 64-bit integer using minimal bytes. +-- +-- Uses the smallest number of bytes that can represent the value +-- in two's complement. Per BOLT #1 Appendix D test vectors. +-- +-- >>> encodeMinSigned 0 +-- "\NUL" +-- >>> encodeMinSigned 127 +-- "\DEL" +-- >>> encodeMinSigned 128 +-- "\NUL\128" +-- >>> encodeMinSigned (-1) +-- "\255" +-- >>> encodeMinSigned (-128) +-- "\128" +-- >>> encodeMinSigned (-129) +-- "\255\DEL" +encodeMinSigned :: Int64 -> BS.ByteString +encodeMinSigned !x + | x >= -128 && x <= 127 = + -- Fits in 1 byte + BS.singleton (fromIntegral x) + | x >= -32768 && x <= 32767 = + -- Fits in 2 bytes + encodeS16 (fromIntegral x) + | x >= -2147483648 && x <= 2147483647 = + -- Fits in 4 bytes + encodeS32 (fromIntegral x) + | otherwise = + -- Need 8 bytes + encodeS64 x +{-# INLINE encodeMinSigned #-} + +-- BigSize encoding ------------------------------------------------------------ + +-- | Encode a BigSize value (variable-length unsigned integer). +-- +-- >>> encodeBigSize 0 +-- "\NUL" +-- >>> encodeBigSize 252 +-- "\252" +-- >>> encodeBigSize 253 +-- "\253\NUL\253" +-- >>> encodeBigSize 65536 +-- "\254\NUL\SOH\NUL\NUL" +encodeBigSize :: Word64 -> BS.ByteString +encodeBigSize !x + | x < 0xfd = BS.singleton (fromIntegral x) + | x < 0x10000 = BS.cons 0xfd (encodeU16 (fromIntegral x)) + | x < 0x100000000 = BS.cons 0xfe (encodeU32 (fromIntegral x)) + | otherwise = BS.cons 0xff (encodeU64 x) +{-# INLINE encodeBigSize #-} + +-- Length encoding ------------------------------------------------------------- + +-- | Encode a length as u16, checking bounds. +-- +-- Returns Nothing if the length exceeds 65535. +encodeLength :: BS.ByteString -> Maybe BS.ByteString +encodeLength !bs + | BS.length bs > 65535 = Nothing + | otherwise = Just (encodeU16 (fromIntegral (BS.length bs))) +{-# INLINE encodeLength #-} + +-- Unsigned integer decoding --------------------------------------------------- + +-- | Decode a 16-bit unsigned integer (big-endian). +decodeU16 :: BS.ByteString -> Maybe (Word16, BS.ByteString) +decodeU16 !bs + | BS.length bs < 2 = Nothing + | otherwise = + let !b0 = fromIntegral (BS.index bs 0) + !b1 = fromIntegral (BS.index bs 1) + !val = (b0 `unsafeShiftL` 8) .|. b1 + in Just (val, BS.drop 2 bs) +{-# INLINE decodeU16 #-} + +-- | Decode a 32-bit unsigned integer (big-endian). +decodeU32 :: BS.ByteString -> Maybe (Word32, BS.ByteString) +decodeU32 !bs + | BS.length bs < 4 = Nothing + | otherwise = + let !b0 = fromIntegral (BS.index bs 0) + !b1 = fromIntegral (BS.index bs 1) + !b2 = fromIntegral (BS.index bs 2) + !b3 = fromIntegral (BS.index bs 3) + !val = (b0 `unsafeShiftL` 24) .|. (b1 `unsafeShiftL` 16) + .|. (b2 `unsafeShiftL` 8) .|. b3 + in Just (val, BS.drop 4 bs) +{-# INLINE decodeU32 #-} + +-- | Decode a 64-bit unsigned integer (big-endian). +decodeU64 :: BS.ByteString -> Maybe (Word64, BS.ByteString) +decodeU64 !bs + | BS.length bs < 8 = Nothing + | otherwise = + let !b0 = fromIntegral (BS.index bs 0) + !b1 = fromIntegral (BS.index bs 1) + !b2 = fromIntegral (BS.index bs 2) + !b3 = fromIntegral (BS.index bs 3) + !b4 = fromIntegral (BS.index bs 4) + !b5 = fromIntegral (BS.index bs 5) + !b6 = fromIntegral (BS.index bs 6) + !b7 = fromIntegral (BS.index bs 7) + !val = (b0 `unsafeShiftL` 56) .|. (b1 `unsafeShiftL` 48) + .|. (b2 `unsafeShiftL` 40) .|. (b3 `unsafeShiftL` 32) + .|. (b4 `unsafeShiftL` 24) .|. (b5 `unsafeShiftL` 16) + .|. (b6 `unsafeShiftL` 8) .|. b7 + in Just (val, BS.drop 8 bs) +{-# INLINE decodeU64 #-} + +-- Signed integer decoding ----------------------------------------------------- + +-- | Decode an 8-bit signed integer. +decodeS8 :: BS.ByteString -> Maybe (Int8, BS.ByteString) +decodeS8 !bs + | BS.null bs = Nothing + | otherwise = Just (fromIntegral (BS.index bs 0), BS.drop 1 bs) +{-# INLINE decodeS8 #-} + +-- | Decode a 16-bit signed integer (big-endian two's complement). +decodeS16 :: BS.ByteString -> Maybe (Int16, BS.ByteString) +decodeS16 !bs = do + (w, rest) <- decodeU16 bs + Just (fromIntegral w, rest) +{-# INLINE decodeS16 #-} + +-- | Decode a 32-bit signed integer (big-endian two's complement). +decodeS32 :: BS.ByteString -> Maybe (Int32, BS.ByteString) +decodeS32 !bs = do + (w, rest) <- decodeU32 bs + Just (fromIntegral w, rest) +{-# INLINE decodeS32 #-} + +-- | Decode a 64-bit signed integer (big-endian two's complement). +decodeS64 :: BS.ByteString -> Maybe (Int64, BS.ByteString) +decodeS64 !bs = do + (w, rest) <- decodeU64 bs + Just (fromIntegral w, rest) +{-# INLINE decodeS64 #-} + +-- Truncated unsigned integer decoding ----------------------------------------- + +-- | Decode a truncated 16-bit unsigned integer (0-2 bytes). +-- +-- Returns Nothing if the encoding is non-minimal (has leading zeros). +decodeTu16 :: Int -> BS.ByteString -> Maybe (Word16, BS.ByteString) +decodeTu16 !len !bs + | len < 0 || len > 2 = Nothing + | BS.length bs < len = Nothing + | len == 0 = Just (0, bs) + | otherwise = + let !bytes = BS.take len bs + !rest = BS.drop len bs + in if BS.index bytes 0 == 0 + then Nothing -- non-minimal: leading zero + else Just (decodeBeWord16 bytes, rest) + where + decodeBeWord16 :: BS.ByteString -> Word16 + decodeBeWord16 b = case BS.length b of + 1 -> fromIntegral (BS.index b 0) + 2 -> (fromIntegral (BS.index b 0) `unsafeShiftL` 8) + .|. fromIntegral (BS.index b 1) + _ -> 0 +{-# INLINE decodeTu16 #-} + +-- | Decode a truncated 32-bit unsigned integer (0-4 bytes). +-- +-- Returns Nothing if the encoding is non-minimal (has leading zeros). +decodeTu32 :: Int -> BS.ByteString -> Maybe (Word32, BS.ByteString) +decodeTu32 !len !bs + | len < 0 || len > 4 = Nothing + | BS.length bs < len = Nothing + | len == 0 = Just (0, bs) + | otherwise = + let !bytes = BS.take len bs + !rest = BS.drop len bs + in if BS.index bytes 0 == 0 + then Nothing -- non-minimal: leading zero + else Just (decodeBeWord32 len bytes, rest) + where + decodeBeWord32 :: Int -> BS.ByteString -> Word32 + decodeBeWord32 n b = go 0 0 + where + go !acc !i + | i >= n = acc + | otherwise = go ((acc `unsafeShiftL` 8) + .|. fromIntegral (BS.index b i)) (i + 1) +{-# INLINE decodeTu32 #-} + +-- | Decode a truncated 64-bit unsigned integer (0-8 bytes). +-- +-- Returns Nothing if the encoding is non-minimal (has leading zeros). +decodeTu64 :: Int -> BS.ByteString -> Maybe (Word64, BS.ByteString) +decodeTu64 !len !bs + | len < 0 || len > 8 = Nothing + | BS.length bs < len = Nothing + | len == 0 = Just (0, bs) + | otherwise = + let !bytes = BS.take len bs + !rest = BS.drop len bs + in if BS.index bytes 0 == 0 + then Nothing -- non-minimal: leading zero + else Just (decodeBeWord64 len bytes, rest) + where + decodeBeWord64 :: Int -> BS.ByteString -> Word64 + decodeBeWord64 n b = go 0 0 + where + go !acc !i + | i >= n = acc + | otherwise = go ((acc `unsafeShiftL` 8) + .|. fromIntegral (BS.index b i)) (i + 1) +{-# INLINE decodeTu64 #-} + +-- Minimal signed integer decoding --------------------------------------------- + +-- | Decode a minimal signed integer (1, 2, 4, or 8 bytes). +-- +-- Validates that the encoding is minimal: the value could not be +-- represented in fewer bytes. Per BOLT #1 Appendix D test vectors. +decodeMinSigned :: Int -> BS.ByteString -> Maybe (Int64, BS.ByteString) +decodeMinSigned !len !bs + | BS.length bs < len = Nothing + | otherwise = case len of + 1 -> do + (v, rest) <- decodeS8 bs + Just (fromIntegral v, rest) + 2 -> do + (v, rest) <- decodeS16 bs + -- Must not fit in 1 byte + if v >= -128 && v <= 127 + then Nothing + else Just (fromIntegral v, rest) + 4 -> do + (v, rest) <- decodeS32 bs + -- Must not fit in 2 bytes + if v >= -32768 && v <= 32767 + then Nothing + else Just (fromIntegral v, rest) + 8 -> do + (v, rest) <- decodeS64 bs + -- Must not fit in 4 bytes + if v >= -2147483648 && v <= 2147483647 + then Nothing + else Just (v, rest) + _ -> Nothing +{-# INLINE decodeMinSigned #-} + +-- BigSize decoding ------------------------------------------------------------ + +-- | Decode a BigSize value with minimality check. +decodeBigSize :: BS.ByteString -> Maybe (Word64, BS.ByteString) +decodeBigSize !bs + | BS.null bs = Nothing + | otherwise = case BS.index bs 0 of + 0xff -> do + (val, rest) <- decodeU64 (BS.drop 1 bs) + -- Must be >= 0x100000000 for minimal encoding + if val >= 0x100000000 + then Just (val, rest) + else Nothing + 0xfe -> do + (val, rest) <- decodeU32 (BS.drop 1 bs) + -- Must be >= 0x10000 for minimal encoding + if val >= 0x10000 + then Just (fromIntegral val, rest) + else Nothing + 0xfd -> do + (val, rest) <- decodeU16 (BS.drop 1 bs) + -- Must be >= 0xfd for minimal encoding + if val >= 0xfd + then Just (fromIntegral val, rest) + else Nothing + b -> Just (fromIntegral b, BS.drop 1 bs) diff --git a/lib/Lightning/Protocol/BOLT1/TLV.hs b/lib/Lightning/Protocol/BOLT1/TLV.hs @@ -0,0 +1,209 @@ +{-# OPTIONS_HADDOCK prune #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} + +-- | +-- Module: Lightning.Protocol.BOLT1.TLV +-- Copyright: (c) 2025 Jared Tobin +-- License: MIT +-- Maintainer: Jared Tobin <jared@ppad.tech> +-- +-- TLV (Type-Length-Value) format for BOLT #1. + +module Lightning.Protocol.BOLT1.TLV ( + -- * TLV types + TlvRecord(..) + , TlvStream(..) + , TlvError(..) + + -- * TLV encoding + , encodeTlvRecord + , encodeTlvStream + + -- * TLV decoding + , decodeTlvStream + , decodeTlvStreamWith + , decodeTlvStreamRaw + + -- * Init TLV types + , InitTlv(..) + , parseInitTlvs + , encodeInitTlvs + ) where + +import Control.DeepSeq (NFData) +import Control.Monad (when) +import qualified Data.ByteString as BS +import Data.Word (Word64) +import GHC.Generics (Generic) +import Lightning.Protocol.BOLT1.Prim + +-- TLV types ------------------------------------------------------------------- + +-- | A single TLV record. +data TlvRecord = TlvRecord + { tlvType :: {-# UNPACK #-} !Word64 + , tlvValue :: !BS.ByteString + } deriving stock (Eq, Show, Generic) + +instance NFData TlvRecord + +-- | A TLV stream (series of TLV records). +newtype TlvStream = TlvStream { unTlvStream :: [TlvRecord] } + deriving stock (Eq, Show, Generic) + +instance NFData TlvStream + +-- | TLV decoding errors. +data TlvError + = TlvNonMinimalEncoding + | TlvNotStrictlyIncreasing + | TlvLengthExceedsBounds + | TlvUnknownEvenType !Word64 + | TlvInvalidKnownType !Word64 + deriving stock (Eq, Show, Generic) + +instance NFData TlvError + +-- TLV encoding ---------------------------------------------------------------- + +-- | Encode a TLV record. +encodeTlvRecord :: TlvRecord -> BS.ByteString +encodeTlvRecord (TlvRecord typ val) = mconcat + [ encodeBigSize typ + , encodeBigSize (fromIntegral (BS.length val)) + , val + ] + +-- | Encode a TLV stream. +encodeTlvStream :: TlvStream -> BS.ByteString +encodeTlvStream (TlvStream recs) = mconcat (map encodeTlvRecord recs) + +-- TLV decoding ---------------------------------------------------------------- + +-- | Decode a TLV stream without any known-type validation. +-- +-- This decoder only enforces structural validity: +-- - Types must be strictly increasing +-- - Lengths must not exceed bounds +-- +-- All records are returned regardless of type. Note: this does NOT +-- enforce the BOLT #1 unknown-even-type rule. Use 'decodeTlvStreamWith' +-- with an appropriate predicate for spec-compliant parsing. +decodeTlvStreamRaw :: BS.ByteString -> Either TlvError TlvStream +decodeTlvStreamRaw = go Nothing [] + where + go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString + -> Either TlvError TlvStream + go !_ !acc !bs + | BS.null bs = Right (TlvStream (reverse acc)) + go !mPrevType !acc !bs = do + (typ, rest1) <- maybe (Left TlvNonMinimalEncoding) Right + (decodeBigSize bs) + -- Strictly increasing check + case mPrevType of + Just prevType -> when (typ <= prevType) $ + Left TlvNotStrictlyIncreasing + Nothing -> pure () + (len, rest2) <- maybe (Left TlvNonMinimalEncoding) Right + (decodeBigSize rest1) + -- Length bounds check + when (fromIntegral len > BS.length rest2) $ + Left TlvLengthExceedsBounds + let !val = BS.take (fromIntegral len) rest2 + !rest3 = BS.drop (fromIntegral len) rest2 + !rec = TlvRecord typ val + go (Just typ) (rec : acc) rest3 + +-- | Decode a TLV stream with configurable known-type predicate. +-- +-- Per BOLT #1: +-- - Types must be strictly increasing +-- - Unknown even types cause failure +-- - Unknown odd types are skipped +-- +-- The predicate determines which types are "known" for the context. +decodeTlvStreamWith + :: (Word64 -> Bool) -- ^ Predicate: is this type known? + -> BS.ByteString + -> Either TlvError TlvStream +decodeTlvStreamWith isKnown = go Nothing [] + where + go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString + -> Either TlvError TlvStream + go !_ !acc !bs + | BS.null bs = Right (TlvStream (reverse acc)) + go !mPrevType !acc !bs = do + (typ, rest1) <- maybe (Left TlvNonMinimalEncoding) Right + (decodeBigSize bs) + -- Strictly increasing check + case mPrevType of + Just prevType -> when (typ <= prevType) $ + Left TlvNotStrictlyIncreasing + Nothing -> pure () + (len, rest2) <- maybe (Left TlvNonMinimalEncoding) Right + (decodeBigSize rest1) + -- Length bounds check + when (fromIntegral len > BS.length rest2) $ + Left TlvLengthExceedsBounds + let !val = BS.take (fromIntegral len) rest2 + !rest3 = BS.drop (fromIntegral len) rest2 + !rec = TlvRecord typ val + -- Unknown type handling: even = fail, odd = skip + if isKnown typ + then go (Just typ) (rec : acc) rest3 + else if even typ + then Left (TlvUnknownEvenType typ) + else go (Just typ) acc rest3 -- skip unknown odd + +-- | Decode a TLV stream with BOLT #1 init_tlvs validation. +-- +-- This uses the default known types for init messages (1 and 3). +-- For other contexts, use 'decodeTlvStreamWith' with an appropriate +-- predicate. +decodeTlvStream :: BS.ByteString -> Either TlvError TlvStream +decodeTlvStream = decodeTlvStreamWith isInitTlvType + where + isInitTlvType :: Word64 -> Bool + isInitTlvType 1 = True -- networks + isInitTlvType 3 = True -- remote_addr + isInitTlvType _ = False + +-- Init TLV types -------------------------------------------------------------- + +-- | TLV records for init message. +data InitTlv + = InitNetworks ![BS.ByteString] -- ^ Type 1: chain hashes (32 bytes each) + | InitRemoteAddr !BS.ByteString -- ^ Type 3: remote address + deriving stock (Eq, Show, Generic) + +instance NFData InitTlv + +-- | Parse init TLVs from a TLV stream. +parseInitTlvs :: TlvStream -> Either TlvError [InitTlv] +parseInitTlvs (TlvStream recs) = traverse parseOne recs + where + parseOne (TlvRecord 1 val) + | BS.length val `mod` 32 == 0 = + Right (InitNetworks (chunksOf 32 val)) + | otherwise = Left (TlvInvalidKnownType 1) + parseOne (TlvRecord 3 val) = Right (InitRemoteAddr val) + parseOne (TlvRecord t _) = Left (TlvUnknownEvenType t) + +-- | Split bytestring into chunks of given size. +chunksOf :: Int -> BS.ByteString -> [BS.ByteString] +chunksOf !n !bs + | BS.null bs = [] + | otherwise = + let (!chunk, !rest) = BS.splitAt n bs + in chunk : chunksOf n rest + +-- | Encode init TLVs to a TLV stream. +encodeInitTlvs :: [InitTlv] -> TlvStream +encodeInitTlvs = TlvStream . map toRecord + where + toRecord (InitNetworks chains) = + TlvRecord 1 (mconcat chains) + toRecord (InitRemoteAddr addr) = + TlvRecord 3 addr diff --git a/ppad-bolt1.cabal b/ppad-bolt1.cabal @@ -25,6 +25,10 @@ library -Wall exposed-modules: Lightning.Protocol.BOLT1 + Lightning.Protocol.BOLT1.Codec + Lightning.Protocol.BOLT1.Message + Lightning.Protocol.BOLT1.Prim + Lightning.Protocol.BOLT1.TLV build-depends: base >= 4.9 && < 5 , bytestring >= 0.9 && < 0.13 diff --git a/test/Main.hs b/test/Main.hs @@ -13,6 +13,9 @@ main :: IO () main = defaultMain $ testGroup "ppad-bolt1" [ bigsize_tests , primitive_tests + , signed_tests + , truncated_tests + , minsigned_tests , tlv_tests , message_tests , envelope_tests @@ -92,6 +95,166 @@ primitive_tests = testGroup "Primitives" [ decodeU64 (BS.pack [0x01, 0x02, 0x03, 0x04]) @?= Nothing ] +-- Signed integer tests --------------------------------------------------------- + +signed_tests :: TestTree +signed_tests = testGroup "Signed integers" [ + testCase "encodeS8 42" $ + encodeS8 42 @?= BS.pack [0x2a] + , testCase "encodeS8 -42" $ + encodeS8 (-42) @?= BS.pack [0xd6] + , testCase "encodeS8 127" $ + encodeS8 127 @?= BS.pack [0x7f] + , testCase "encodeS8 -128" $ + encodeS8 (-128) @?= BS.pack [0x80] + , testCase "decodeS8 42" $ + decodeS8 (BS.pack [0x2a]) @?= Just (42, "") + , testCase "decodeS8 -42" $ + decodeS8 (BS.pack [0xd6]) @?= Just (-42, "") + , testCase "encodeS16 -1" $ + encodeS16 (-1) @?= BS.pack [0xff, 0xff] + , testCase "encodeS16 32767" $ + encodeS16 32767 @?= BS.pack [0x7f, 0xff] + , testCase "encodeS16 -32768" $ + encodeS16 (-32768) @?= BS.pack [0x80, 0x00] + , testCase "decodeS16 -1" $ + decodeS16 (BS.pack [0xff, 0xff]) @?= Just (-1, "") + , testCase "encodeS32 -1" $ + encodeS32 (-1) @?= BS.pack [0xff, 0xff, 0xff, 0xff] + , testCase "encodeS32 2147483647" $ + encodeS32 2147483647 @?= BS.pack [0x7f, 0xff, 0xff, 0xff] + , testCase "encodeS32 -2147483648" $ + encodeS32 (-2147483648) @?= BS.pack [0x80, 0x00, 0x00, 0x00] + , testCase "decodeS32 -1" $ + decodeS32 (BS.pack [0xff, 0xff, 0xff, 0xff]) @?= Just (-1, "") + , testCase "encodeS64 -1" $ + encodeS64 (-1) @?= + BS.pack [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff] + , testCase "decodeS64 -1" $ + decodeS64 (BS.pack [0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]) @?= + Just (-1, "") + ] + +-- Truncated unsigned integer tests --------------------------------------------- + +truncated_tests :: TestTree +truncated_tests = testGroup "Truncated unsigned integers" [ + testCase "encodeTu16 0" $ + encodeTu16 0 @?= "" + , testCase "encodeTu16 1" $ + encodeTu16 1 @?= BS.pack [0x01] + , testCase "encodeTu16 255" $ + encodeTu16 255 @?= BS.pack [0xff] + , testCase "encodeTu16 256" $ + encodeTu16 256 @?= BS.pack [0x01, 0x00] + , testCase "encodeTu16 65535" $ + encodeTu16 65535 @?= BS.pack [0xff, 0xff] + , testCase "decodeTu16 0 bytes" $ + decodeTu16 0 "" @?= Just (0, "") + , testCase "decodeTu16 1 byte" $ + decodeTu16 1 (BS.pack [0x01]) @?= Just (1, "") + , testCase "decodeTu16 2 bytes" $ + decodeTu16 2 (BS.pack [0x01, 0x00]) @?= Just (256, "") + , testCase "decodeTu16 non-minimal fails" $ + decodeTu16 2 (BS.pack [0x00, 0x01]) @?= Nothing + , testCase "encodeTu32 0" $ + encodeTu32 0 @?= "" + , testCase "encodeTu32 1" $ + encodeTu32 1 @?= BS.pack [0x01] + , testCase "encodeTu32 0x010000" $ + encodeTu32 0x010000 @?= BS.pack [0x01, 0x00, 0x00] + , testCase "encodeTu32 0x01000000" $ + encodeTu32 0x01000000 @?= BS.pack [0x01, 0x00, 0x00, 0x00] + , testCase "decodeTu32 0 bytes" $ + decodeTu32 0 "" @?= Just (0, "") + , testCase "decodeTu32 3 bytes" $ + decodeTu32 3 (BS.pack [0x01, 0x00, 0x00]) @?= Just (0x010000, "") + , testCase "decodeTu32 non-minimal fails" $ + decodeTu32 3 (BS.pack [0x00, 0x01, 0x00]) @?= Nothing + , testCase "encodeTu64 0" $ + encodeTu64 0 @?= "" + , testCase "encodeTu64 0x0100000000" $ + encodeTu64 0x0100000000 @?= BS.pack [0x01, 0x00, 0x00, 0x00, 0x00] + , testCase "decodeTu64 5 bytes" $ + decodeTu64 5 (BS.pack [0x01, 0x00, 0x00, 0x00, 0x00]) @?= + Just (0x0100000000, "") + , testCase "decodeTu64 non-minimal fails" $ + decodeTu64 5 (BS.pack [0x00, 0x01, 0x00, 0x00, 0x00]) @?= Nothing + ] + +-- Minimal signed integer tests (Appendix D) ------------------------------------ + +minsigned_tests :: TestTree +minsigned_tests = testGroup "Minimal signed (Appendix D)" [ + -- Test vectors from BOLT #1 Appendix D + testCase "encode 0" $ + encodeMinSigned 0 @?= unhex "00" + , testCase "encode 42" $ + encodeMinSigned 42 @?= unhex "2a" + , testCase "encode -42" $ + encodeMinSigned (-42) @?= unhex "d6" + , testCase "encode 127" $ + encodeMinSigned 127 @?= unhex "7f" + , testCase "encode -128" $ + encodeMinSigned (-128) @?= unhex "80" + , testCase "encode 128" $ + encodeMinSigned 128 @?= unhex "0080" + , testCase "encode -129" $ + encodeMinSigned (-129) @?= unhex "ff7f" + , testCase "encode 15000" $ + encodeMinSigned 15000 @?= unhex "3a98" + , testCase "encode -15000" $ + encodeMinSigned (-15000) @?= unhex "c568" + , testCase "encode 32767" $ + encodeMinSigned 32767 @?= unhex "7fff" + , testCase "encode -32768" $ + encodeMinSigned (-32768) @?= unhex "8000" + , testCase "encode 32768" $ + encodeMinSigned 32768 @?= unhex "00008000" + , testCase "encode -32769" $ + encodeMinSigned (-32769) @?= unhex "ffff7fff" + , testCase "encode 21000000" $ + encodeMinSigned 21000000 @?= unhex "01406f40" + , testCase "encode -21000000" $ + encodeMinSigned (-21000000) @?= unhex "febf90c0" + , testCase "encode 2147483647" $ + encodeMinSigned 2147483647 @?= unhex "7fffffff" + , testCase "encode -2147483648" $ + encodeMinSigned (-2147483648) @?= unhex "80000000" + , testCase "encode 2147483648" $ + encodeMinSigned 2147483648 @?= unhex "0000000080000000" + , testCase "encode -2147483649" $ + encodeMinSigned (-2147483649) @?= unhex "ffffffff7fffffff" + , testCase "encode 500000000000" $ + encodeMinSigned 500000000000 @?= unhex "000000746a528800" + , testCase "encode -500000000000" $ + encodeMinSigned (-500000000000) @?= unhex "ffffff8b95ad7800" + , testCase "encode max int64" $ + encodeMinSigned 9223372036854775807 @?= unhex "7fffffffffffffff" + , testCase "encode min int64" $ + encodeMinSigned (-9223372036854775808) @?= unhex "8000000000000000" + -- Decode tests + , testCase "decode 1-byte 42" $ + decodeMinSigned 1 (unhex "2a") @?= Just (42, "") + , testCase "decode 1-byte -42" $ + decodeMinSigned 1 (unhex "d6") @?= Just (-42, "") + , testCase "decode 2-byte 128" $ + decodeMinSigned 2 (unhex "0080") @?= Just (128, "") + , testCase "decode 2-byte -129" $ + decodeMinSigned 2 (unhex "ff7f") @?= Just (-129, "") + , testCase "decode 4-byte 32768" $ + decodeMinSigned 4 (unhex "00008000") @?= Just (32768, "") + , testCase "decode 8-byte 2147483648" $ + decodeMinSigned 8 (unhex "0000000080000000") @?= Just (2147483648, "") + -- Minimality rejection + , testCase "decode 2-byte for 1-byte value fails" $ + decodeMinSigned 2 (unhex "0042") @?= Nothing -- 42 fits in 1 byte + , testCase "decode 4-byte for 2-byte value fails" $ + decodeMinSigned 4 (unhex "00000080") @?= Nothing -- 128 fits in 2 bytes + , testCase "decode 8-byte for 4-byte value fails" $ + decodeMinSigned 8 (unhex "0000000000008000") @?= Nothing -- 32768 fits in 4 + ] + -- TLV tests ------------------------------------------------------------------- tlv_tests :: TestTree @@ -406,6 +569,14 @@ bounds_tests = testGroup "Bounds checking" [ case encodeMessage (MsgPeerStorageVal msg) of Left EncodeLengthOverflow -> pure () other -> assertFailure $ "expected overflow: " ++ show other + , testCase "encode envelope exceeding 65535 bytes fails" $ do + -- Create a message that fits in encodeMessage but combined with + -- extension exceeds 65535 bytes total + let msg = MsgPongVal (Pong (BS.replicate 60000 0x00)) + ext = TlvStream [TlvRecord 101 (BS.replicate 10000 0x00)] + case encodeEnvelope msg (Just ext) of + Left EncodeMessageTooLarge -> pure () + other -> assertFailure $ "expected message too large: " ++ show other ] -- Property tests --------------------------------------------------------------