TLV.hs (7956B)
{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}

-- |
-- Module: Lightning.Protocol.BOLT1.TLV
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- TLV (Type-Length-Value) format for BOLT #1: record and stream types,
-- encoders, decoders, and the @init@ message TLVs.

module Lightning.Protocol.BOLT1.TLV
  ( -- * TLV types
    TlvRecord(..)
  , TlvStream
  , unTlvStream
  , tlvStream
  , unsafeTlvStream
  , TlvError(..)

    -- * TLV encoding
  , encodeTlvRecord
  , encodeTlvStream

    -- * TLV decoding
  , decodeTlvStream
  , decodeTlvStreamWith
  , decodeTlvStreamRaw

    -- * Init TLV types
  , InitTlv(..)
  , parseInitTlvs
  , encodeInitTlvs

    -- * Re-exports
  , ChainHash
  , chainHash
  , unChainHash
  ) where

import Control.DeepSeq (NFData)
import Control.Monad (when)
import qualified Data.ByteString as BS
import Data.Word (Word64)
import GHC.Generics (Generic)
import Lightning.Protocol.BOLT1.Prim

-- TLV types -------------------------------------------------------------------

-- | One TLV record: a numeric type tag paired with its raw value bytes.
data TlvRecord = TlvRecord
  { tlvType  :: {-# UNPACK #-} !Word64
    -- ^ Record type.
  , tlvValue :: !BS.ByteString
    -- ^ Record value, without the type or length prefix.
  }
  deriving stock (Eq, Show, Generic)

instance NFData TlvRecord

-- | An ordered sequence of TLV records.
--
-- Intended invariant: record types are strictly increasing. 'tlvStream'
-- checks this; 'unsafeTlvStream' does not.
newtype TlvStream = TlvStream
  { unTlvStream :: [TlvRecord]
    -- ^ The records, in stream order.
  }
  deriving stock (Eq, Show, Generic)

instance NFData TlvStream

-- | Smart constructor for 'TlvStream' that checks the records are
-- strictly increasing by type.
--
-- Yields 'Nothing' if any type is not greater than the one before it.
tlvStream :: [TlvRecord] -> Maybe TlvStream
tlvStream recs =
  if and (zipWith (<) types (drop 1 types))
    then Just (TlvStream recs)
    else Nothing
  where
    -- Each type is compared with its successor; empty and singleton
    -- lists have no pairs and are therefore trivially ordered.
    types = map tlvType recs

-- | Wrap records as a 'TlvStream' without checking their order.
--
-- The caller is responsible for ensuring the types are strictly
-- increasing; use 'tlvStream' when that is not already known.
unsafeTlvStream :: [TlvRecord] -> TlvStream
unsafeTlvStream recs = TlvStream recs

-- | Errors produced while decoding or interpreting a TLV stream.
data TlvError
  = TlvNonMinimalEncoding
    -- ^ A BigSize type or length field was rejected by the BigSize decoder.
  | TlvNotStrictlyIncreasing
    -- ^ A record's type was not greater than the preceding record's type.
  | TlvLengthExceedsBounds
    -- ^ A record's declared length ran past the end of the input.
  | TlvUnknownEvenType !Word64
    -- ^ A record type that was not understood and was rejected.
  | TlvInvalidKnownType !Word64
    -- ^ A recognised record type whose value is malformed.
  deriving stock (Eq, Show, Generic)

instance NFData TlvError

-- TLV encoding ----------------------------------------------------------------

-- | Serialise one record as @type || length || value@, with the type and
-- the value's byte length each written via 'encodeBigSize'.
encodeTlvRecord :: TlvRecord -> BS.ByteString
encodeTlvRecord r =
     encodeBigSize (tlvType r)
  <> encodeBigSize (fromIntegral (BS.length payload))
  <> payload
  where
    payload = tlvValue r

-- | Serialise every record of a stream back-to-back, in stream order.
encodeTlvStream :: TlvStream -> BS.ByteString
encodeTlvStream stream = BS.concat (map encodeTlvRecord (unTlvStream stream))

-- TLV decoding ----------------------------------------------------------------

-- | Decode a TLV stream without any known-type validation.
--
-- Only structural validity is enforced:
--
--   * types must be strictly increasing;
--   * declared lengths must not run past the end of the input.
--
-- Every record is returned, whatever its type. This does /not/ apply the
-- BOLT #1 unknown-even-type rule; use 'decodeTlvStreamWith' with a suitable
-- predicate for spec-compliant parsing.
decodeTlvStreamRaw :: BS.ByteString -> Either TlvError TlvStream
-- Treating every type as known keeps all records while still enforcing
-- the ordering and length-bounds checks of 'decodeTlvStreamWith'.
decodeTlvStreamRaw = decodeTlvStreamWith (const True)

-- | Decode a TLV stream with configurable known-type predicate.
--
-- Per BOLT #1:
--
--   * types must be strictly increasing;
--   * unknown even types cause failure;
--   * unknown odd types are skipped.
--
-- The predicate determines which types are \"known\" for the context.
-- Known records are returned in stream order. Skipped odd records still
-- take part in the strictly-increasing check.
--
-- Errors: 'TlvNonMinimalEncoding' if a type or length field fails to
-- decode, 'TlvNotStrictlyIncreasing', 'TlvLengthExceedsBounds', or
-- 'TlvUnknownEvenType'.
decodeTlvStreamWith
  :: (Word64 -> Bool) -- ^ Predicate: is this type known?
  -> BS.ByteString    -- ^ Encoded stream.
  -> Either TlvError TlvStream
decodeTlvStreamWith isKnown = go Nothing []
  where
    -- Lift a BigSize decode result into the error monad. Every decode
    -- failure is reported as 'TlvNonMinimalEncoding'.
    bigSize :: Maybe a -> Either TlvError a
    bigSize = maybe (Left TlvNonMinimalEncoding) Right

    go :: Maybe Word64 -> [TlvRecord] -> BS.ByteString
       -> Either TlvError TlvStream
    go !_ !acc !bs
      | BS.null bs = Right (unsafeTlvStream (reverse acc))
    go !mPrevType !acc !bs = do
      (typ, rest1) <- bigSize (decodeBigSize bs)
      -- Strictly increasing check (applies to skipped records too).
      case mPrevType of
        Just prevType -> when (typ <= prevType) $
          Left TlvNotStrictlyIncreasing
        Nothing -> pure ()
      (len, rest2) <- bigSize (decodeBigSize rest1)
      -- Length bounds check, done in Word64. Converting 'len' to Int first
      -- would wrap huge lengths (>= 2^63, or >= 2^31 on 32-bit) to negative
      -- numbers that pass the check; 'BS.take' would then yield an empty
      -- value and parsing would continue from the bytes after the length.
      when (len > fromIntegral (BS.length rest2)) $
        Left TlvLengthExceedsBounds
      -- len is now at most the remaining input length, so it fits in Int.
      let !n     = fromIntegral len :: Int
          !val   = BS.take n rest2
          !rest3 = BS.drop n rest2
      -- Unknown type handling: even = fail, odd = skip.
      if isKnown typ
        then go (Just typ) (TlvRecord typ val : acc) rest3
        else if even typ
          then Left (TlvUnknownEvenType typ)
          else go (Just typ) acc rest3

-- | Decode a TLV stream with BOLT #1 init_tlvs validation.
--
-- This uses the default known types for init messages (1 and 3).
-- For other contexts, use 'decodeTlvStreamWith' with an appropriate
-- predicate.
decodeTlvStream :: BS.ByteString -> Either TlvError TlvStream
decodeTlvStream = decodeTlvStreamWith isInitTlvType
  where
    isInitTlvType :: Word64 -> Bool
    isInitTlvType 1 = True  -- networks
    isInitTlvType 3 = True  -- remote_addr
    isInitTlvType _ = False

-- Init TLV types --------------------------------------------------------------

-- | TLV records for init message.
data InitTlv
  = InitNetworks ![ChainHash]      -- ^ Type 1: chain hashes (32 bytes each)
  | InitRemoteAddr !BS.ByteString  -- ^ Type 3: remote address
  deriving stock (Eq, Show, Generic)

instance NFData InitTlv

-- | Parse init TLVs from a TLV stream.
--
--   * Type 1 (@networks@): the value must be a whole number of 32-byte
--     chain hashes, otherwise @'TlvInvalidKnownType' 1@.
--   * Type 3 (@remote_addr@): the value is kept as raw bytes.
--   * Any other even type: @'TlvUnknownEvenType' t@.
--   * Any other odd type: ignored, per the BOLT #1 rule that unknown odd
--     types may be skipped. Streams from 'decodeTlvStream' never contain
--     such records, but streams from 'decodeTlvStreamRaw' or 'tlvStream'
--     may.
parseInitTlvs :: TlvStream -> Either TlvError [InitTlv]
parseInitTlvs (TlvStream recs) = concat <$> traverse parseOne recs
  where
    -- Each record yields either nothing (skipped) or exactly one TLV.
    parseOne :: TlvRecord -> Either TlvError [InitTlv]
    parseOne (TlvRecord 1 val)
      -- 'traverse chainHash' keeps this total: if any chunk were rejected
      -- we report a malformed networks record instead of crashing.
      | BS.length val `mod` 32 == 0
      , Just chains <- traverse chainHash (chunksOf 32 val)
      = Right [InitNetworks chains]
      | otherwise = Left (TlvInvalidKnownType 1)
    parseOne (TlvRecord 3 val) = Right [InitRemoteAddr val]
    parseOne (TlvRecord t _)
      | even t    = Left (TlvUnknownEvenType t)
      | otherwise = Right []

-- | Split a bytestring into consecutive chunks of the given size; the
-- final chunk is shorter when the length is not a multiple of the size.
--
-- NOTE(review): a non-positive size on non-empty input produces an
-- infinite list of empty chunks; it is only called here with 32.
chunksOf :: Int -> BS.ByteString -> [BS.ByteString]
chunksOf !n !bs
  | BS.null bs = []
  | otherwise =
      let (!chunk, !rest) = BS.splitAt n bs
      in  chunk : chunksOf n rest

-- | Encode init TLVs to a TLV stream.
--
-- Records are always emitted in increasing type order (@networks@ = 1,
-- then @remote_addr@ = 3), whatever the order of the input list, so the
-- result satisfies the 'TlvStream' ordering invariant. A TLV stream cannot
-- repeat a type, so if a constructor occurs more than once only its first
-- occurrence is encoded.
encodeInitTlvs :: [InitTlv] -> TlvStream
encodeInitTlvs tlvs = unsafeTlvStream (networks ++ remoteAddr)
  where
    -- 'take 1' keeps at most one record per type; emitting the type-1 list
    -- before the type-3 list guarantees strictly increasing types.
    networks, remoteAddr :: [TlvRecord]
    networks =
      take 1 [ TlvRecord 1 (mconcat (map unChainHash chains))
             | InitNetworks chains <- tlvs ]
    remoteAddr =
      take 1 [ TlvRecord 3 addr | InitRemoteAddr addr <- tlvs ]