bolt1

Base Lightning protocol, per BOLT #1 (docs.ppad.tech/bolt1).
git clone git://git.ppad.tech/bolt1.git
Log | Files | Refs | README | LICENSE

TLV.hs (7956B)


      1 {-# OPTIONS_HADDOCK prune #-}
      2 {-# LANGUAGE BangPatterns #-}
      3 {-# LANGUAGE DeriveGeneric #-}
      4 {-# LANGUAGE DerivingStrategies #-}
      5 
      6 -- |
      7 -- Module: Lightning.Protocol.BOLT1.TLV
      8 -- Copyright: (c) 2025 Jared Tobin
      9 -- License: MIT
     10 -- Maintainer: Jared Tobin <jared@ppad.tech>
     11 --
     12 -- TLV (Type-Length-Value) format for BOLT #1.
     13 
     14 module Lightning.Protocol.BOLT1.TLV (
     15   -- * TLV types
     16     TlvRecord(..)
     17   , TlvStream
     18   , unTlvStream
     19   , tlvStream
     20   , unsafeTlvStream
     21   , TlvError(..)
     22 
     23   -- * TLV encoding
     24   , encodeTlvRecord
     25   , encodeTlvStream
     26 
     27   -- * TLV decoding
     28   , decodeTlvStream
     29   , decodeTlvStreamWith
     30   , decodeTlvStreamRaw
     31 
     32   -- * Init TLV types
     33   , InitTlv(..)
     34   , parseInitTlvs
     35   , encodeInitTlvs
     36 
     37   -- * Re-exports
     38   , ChainHash
     39   , chainHash
     40   , unChainHash
     41   ) where
     42 
     43 import Control.DeepSeq (NFData)
     44 import Control.Monad (when)
     45 import qualified Data.ByteString as BS
     46 import Data.Word (Word64)
     47 import GHC.Generics (Generic)
     48 import Lightning.Protocol.BOLT1.Prim
     49 
     50 -- TLV types -------------------------------------------------------------------
     51 
     52 -- | A single TLV record.
     53 data TlvRecord = TlvRecord
     54   { tlvType   :: {-# UNPACK #-} !Word64
     55   , tlvValue  :: !BS.ByteString
     56   } deriving stock (Eq, Show, Generic)
     57 
     58 instance NFData TlvRecord
     59 
     60 -- | A TLV stream (series of TLV records).
     61 newtype TlvStream = TlvStream { unTlvStream :: [TlvRecord] }
     62   deriving stock (Eq, Show, Generic)
     63 
     64 instance NFData TlvStream
     65 
     66 -- | Smart constructor for 'TlvStream' that validates records are
     67 -- strictly increasing by type.
     68 --
     69 -- Returns 'Nothing' if types are not strictly increasing.
     70 tlvStream :: [TlvRecord] -> Maybe TlvStream
     71 tlvStream recs
     72   | isStrictlyIncreasing (map tlvType recs) = Just (TlvStream recs)
     73   | otherwise = Nothing
     74   where
     75     isStrictlyIncreasing :: [Word64] -> Bool
     76     isStrictlyIncreasing [] = True
     77     isStrictlyIncreasing [_] = True
     78     isStrictlyIncreasing (x:y:rest) = x < y && isStrictlyIncreasing (y:rest)
     79 
     80 -- | Unsafe constructor for 'TlvStream' that skips validation.
     81 --
     82 -- Use only when ordering is already guaranteed (e.g., in decode functions).
     83 unsafeTlvStream :: [TlvRecord] -> TlvStream
     84 unsafeTlvStream = TlvStream
     85 
     86 -- | TLV decoding errors.
     87 data TlvError
     88   = TlvNonMinimalEncoding
     89   | TlvNotStrictlyIncreasing
     90   | TlvLengthExceedsBounds
     91   | TlvUnknownEvenType !Word64
     92   | TlvInvalidKnownType !Word64
     93   deriving stock (Eq, Show, Generic)
     94 
     95 instance NFData TlvError
     96 
     97 -- TLV encoding ----------------------------------------------------------------
     98 
     99 -- | Encode a TLV record.
    100 encodeTlvRecord :: TlvRecord -> BS.ByteString
    101 encodeTlvRecord (TlvRecord typ val) = mconcat
    102   [ encodeBigSize typ
    103   , encodeBigSize (fromIntegral (BS.length val))
    104   , val
    105   ]
    106 
    107 -- | Encode a TLV stream.
    108 encodeTlvStream :: TlvStream -> BS.ByteString
    109 encodeTlvStream (TlvStream recs) = mconcat (map encodeTlvRecord recs)
    110 
    111 -- TLV decoding ----------------------------------------------------------------
    112 
    113 -- | Decode a TLV stream without any known-type validation.
    114 --
    115 -- This decoder only enforces structural validity:
    116 -- - Types must be strictly increasing
    117 -- - Lengths must not exceed bounds
    118 --
    119 -- All records are returned regardless of type. Note: this does NOT
    120 -- enforce the BOLT #1 unknown-even-type rule. Use 'decodeTlvStreamWith'
    121 -- with an appropriate predicate for spec-compliant parsing.
    122 decodeTlvStreamRaw :: BS.ByteString -> Either TlvError TlvStream
    123 decodeTlvStreamRaw = go Nothing []
    124   where
    125     go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString
    126        -> Either TlvError TlvStream
    127     go !_ !acc !bs
    128       | BS.null bs = Right (unsafeTlvStream (reverse acc))
    129     go !mPrevType !acc !bs = do
    130       (typ, rest1) <- maybe (Left TlvNonMinimalEncoding) Right
    131                         (decodeBigSize bs)
    132       -- Strictly increasing check
    133       case mPrevType of
    134         Just prevType -> when (typ <= prevType) $
    135           Left TlvNotStrictlyIncreasing
    136         Nothing -> pure ()
    137       (len, rest2) <- maybe (Left TlvNonMinimalEncoding) Right
    138                         (decodeBigSize rest1)
    139       -- Length bounds check
    140       when (fromIntegral len > BS.length rest2) $
    141         Left TlvLengthExceedsBounds
    142       let !val = BS.take (fromIntegral len) rest2
    143           !rest3 = BS.drop (fromIntegral len) rest2
    144           !rec = TlvRecord typ val
    145       go (Just typ) (rec : acc) rest3
    146 
    147 -- | Decode a TLV stream with configurable known-type predicate.
    148 --
    149 -- Per BOLT #1:
    150 -- - Types must be strictly increasing
    151 -- - Unknown even types cause failure
    152 -- - Unknown odd types are skipped
    153 --
    154 -- The predicate determines which types are "known" for the context.
    155 decodeTlvStreamWith
    156   :: (Word64 -> Bool)  -- ^ Predicate: is this type known?
    157   -> BS.ByteString
    158   -> Either TlvError TlvStream
    159 decodeTlvStreamWith isKnown = go Nothing []
    160   where
    161     go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString
    162        -> Either TlvError TlvStream
    163     go !_ !acc !bs
    164       | BS.null bs = Right (unsafeTlvStream (reverse acc))
    165     go !mPrevType !acc !bs = do
    166       (typ, rest1) <- maybe (Left TlvNonMinimalEncoding) Right
    167                         (decodeBigSize bs)
    168       -- Strictly increasing check
    169       case mPrevType of
    170         Just prevType -> when (typ <= prevType) $
    171           Left TlvNotStrictlyIncreasing
    172         Nothing -> pure ()
    173       (len, rest2) <- maybe (Left TlvNonMinimalEncoding) Right
    174                         (decodeBigSize rest1)
    175       -- Length bounds check
    176       when (fromIntegral len > BS.length rest2) $
    177         Left TlvLengthExceedsBounds
    178       let !val = BS.take (fromIntegral len) rest2
    179           !rest3 = BS.drop (fromIntegral len) rest2
    180           !rec = TlvRecord typ val
    181       -- Unknown type handling: even = fail, odd = skip
    182       if isKnown typ
    183         then go (Just typ) (rec : acc) rest3
    184         else if even typ
    185           then Left (TlvUnknownEvenType typ)
    186           else go (Just typ) acc rest3  -- skip unknown odd
    187 
    188 -- | Decode a TLV stream with BOLT #1 init_tlvs validation.
    189 --
    190 -- This uses the default known types for init messages (1 and 3).
    191 -- For other contexts, use 'decodeTlvStreamWith' with an appropriate
    192 -- predicate.
    193 decodeTlvStream :: BS.ByteString -> Either TlvError TlvStream
    194 decodeTlvStream = decodeTlvStreamWith isInitTlvType
    195   where
    196     isInitTlvType :: Word64 -> Bool
    197     isInitTlvType 1 = True  -- networks
    198     isInitTlvType 3 = True  -- remote_addr
    199     isInitTlvType _ = False
    200 
    201 -- Init TLV types --------------------------------------------------------------
    202 
    203 -- | TLV records for init message.
    204 data InitTlv
    205   = InitNetworks ![ChainHash]      -- ^ Type 1: chain hashes (32 bytes each)
    206   | InitRemoteAddr !BS.ByteString  -- ^ Type 3: remote address
    207   deriving stock (Eq, Show, Generic)
    208 
    209 instance NFData InitTlv
    210 
    211 -- | Parse init TLVs from a TLV stream.
    212 parseInitTlvs :: TlvStream -> Either TlvError [InitTlv]
    213 parseInitTlvs (TlvStream recs) = traverse parseOne recs
    214   where
    215     parseOne (TlvRecord 1 val)
    216       | BS.length val `mod` 32 == 0 =
    217           Right (InitNetworks (map mkChainHash (chunksOf 32 val)))
    218       | otherwise = Left (TlvInvalidKnownType 1)
    219     parseOne (TlvRecord 3 val) = Right (InitRemoteAddr val)
    220     parseOne (TlvRecord t _) = Left (TlvUnknownEvenType t)
    221 
    222     -- Each chunk is exactly 32 bytes from chunksOf, so chainHash always
    223     -- succeeds. We use a partial pattern match as the Nothing case is
    224     -- unreachable given our chunksOf guarantee.
    225     mkChainHash bs = case chainHash bs of
    226       Just ch -> ch
    227       Nothing -> error "parseInitTlvs: impossible - chunk is not 32 bytes"
    228 
    229 -- | Split bytestring into chunks of given size.
    230 chunksOf :: Int -> BS.ByteString -> [BS.ByteString]
    231 chunksOf !n !bs
    232   | BS.null bs = []
    233   | otherwise =
    234       let (!chunk, !rest) = BS.splitAt n bs
    235       in  chunk : chunksOf n rest
    236 
    237 -- | Encode init TLVs to a TLV stream.
    238 encodeInitTlvs :: [InitTlv] -> TlvStream
    239 encodeInitTlvs = unsafeTlvStream . map toRecord
    240   where
    241     toRecord (InitNetworks chains) =
    242       TlvRecord 1 (mconcat (map unChainHash chains))
    243     toRecord (InitRemoteAddr addr) =
    244       TlvRecord 3 addr