Validate.hs (10181B)
{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}

-- |
-- Module: Lightning.Protocol.BOLT9.Validate
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- Validation for BOLT #9 feature vectors.

module Lightning.Protocol.BOLT9.Validate (
    -- * Error types
    ValidationError(..)

    -- * Local validation
  , validateLocal

    -- * Remote validation
  , validateRemote

    -- * Validated construction
  , setFeatureForContext
  , validateNoBothBits

    -- * Helpers
  , highestSetBit
  , setBits
  ) where

import Control.DeepSeq (NFData)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.Bits as B
import Data.Word (Word16)
import GHC.Generics (Generic)
import Lightning.Protocol.BOLT9.Codec
  (isFeatureSet, setFeature, testBit)
import Lightning.Protocol.BOLT9.Features
import Lightning.Protocol.BOLT9.Types

-- | Validation errors for feature vectors.
data ValidationError
  = BothBitsSet {-# UNPACK #-} !Word16 !String
    -- ^ Both optional and required bits are set for a feature.
    --   Arguments: base bit index, feature name.
  | MissingDependency !String !String
    -- ^ A feature's dependency is not set.
    --   Arguments: feature name, missing dependency name.
  | ContextNotAllowed !String !Context
    -- ^ A feature is not allowed in the given context.
    --   Arguments: feature name, context.
  | UnknownRequiredBit {-# UNPACK #-} !Word16
    -- ^ An unknown required (even) bit is set (remote validation only).
    --   Argument: bit index.
  | InvalidParity {-# UNPACK #-} !Word16 !Context
    -- ^ A bit has invalid parity for a channel context.
    --   Arguments: bit index, context (ChanAnnOdd or ChanAnnEven).
  deriving (Eq, Show, Generic)

instance NFData ValidationError

-- Local validation -----------------------------------------------------------

-- | Validate a feature vector for local use (vectors we create/send).
--
--   Every check is run and all failures are reported together; 'Right'
--   is returned only when no check produced an error.
--
--   Checks:
--
--   * No feature has both optional and required bits set
--   * All set features are valid for the given context
--   * All dependencies of set features are also set
--   * C- context forces odd bits only, C+ forces even bits only
--
--   >>> import Data.Maybe (fromJust)
--   >>> import Lightning.Protocol.BOLT9.Codec (setFeature)
--   >>> let mpp = fromJust (featureByName "basic_mpp")
--   >>> let ps = fromJust (featureByName "payment_secret")
--   >>> validateLocal Init (setFeature mpp Optional empty)
--   Left [MissingDependency "basic_mpp" "payment_secret"]
--   >>> validateLocal Init (setFeature mpp Optional (setFeature ps Optional empty))
--   Right ()
validateLocal :: Context -> FeatureVector -> Either [ValidationError] ()
validateLocal !ctx !fv = case problems of
    [] -> Right ()
    _  -> Left problems
  where
    -- Errors are reported grouped by check, in this fixed order.
    problems = concat
      [ bothBitsErrors fv
      , contextErrors ctx fv
      , dependencyErrors fv
      , parityErrors ctx fv
      ]

-- | Check for features with both bits set.
--
--   For every known feature, the bit at its base index and the bit
--   directly above it are tested; a 'BothBitsSet' error is produced when
--   both are present.
bothBitsErrors :: FeatureVector -> [ValidationError]
bothBitsErrors !fv =
  [ BothBitsSet base (featureName f)
  | f <- knownFeatures
  , let base = featureBaseBit f
  , testBit base fv
  , testBit (base + 1) fv
  ]

-- | Check for features not allowed in the given context.
contextErrors :: Context -> FeatureVector -> [ValidationError]
contextErrors ctx fv =
  [ ContextNotAllowed (featureName f) ctx
  | f <- knownFeatures
  , isFeatureSet f fv
  , let allowed = featureContexts f
  , not (null allowed)               -- an empty list places no restriction
  , not (contextAllowed ctx allowed)
  ]

-- | Decide whether a context is acceptable given a feature's list of
--   allowed contexts.
--
--   A context is accepted when it appears in the list, or when it is a
--   channel context and the list contains any channel context.
contextAllowed :: Context -> [Context] -> Bool
contextAllowed ctx allowed
  | ctx `elem` allowed   = True
  | isChannelContext ctx = any isChannelContext allowed
  | otherwise            = False

-- | Report every dependency of a set feature that is not itself set.
--
--   Dependency names that do not resolve via 'featureByName' are skipped.
dependencyErrors :: FeatureVector -> [ValidationError]
dependencyErrors fv = concatMap missingFor enabled
  where
    enabled = filter (`isFeatureSet` fv) knownFeatures

    missingFor f = concatMap (missing f) (featureDependencies f)

    missing f depName = case featureByName depName of
      Just dep
        | not (isFeatureSet dep fv) ->
            [MissingDependency (featureName f) depName]
      _ -> []

-- | Report set bits whose parity does not match what the context
--   demands.
--
--   Contexts for which 'channelParity' yields 'Nothing' impose no
--   parity constraint and produce no errors.
parityErrors :: Context -> FeatureVector -> [ValidationError]
parityErrors ctx fv = case channelParity ctx of
  Nothing       -> []
  Just wantEven ->
    [ InvalidParity b ctx
    | b <- setBits fv
    , even b /= wantEven
    ]

-- Validated construction -------------------------------------------------------

-- | Set a feature in a vector, validating that the feature is
-- allowed in the given context and has correct parity.
--
-- Checks:
--
-- * The feature's context list includes the given context
--   (or is empty, meaning all contexts are allowed)
-- * For 'ChanAnnOdd', only 'Optional' (odd bit) is allowed
-- * For 'ChanAnnEven', only 'Required' (even bit) is allowed
--
-- The context check is performed first; parity is only examined once
-- the context has been accepted.
--
-- >>> import Data.Maybe (fromJust)
-- >>> let pm = fromJust (featureByName "option_payment_metadata")
-- >>> setFeatureForContext Invoice pm Optional empty
-- Right ...
-- >>> setFeatureForContext Init pm Optional empty
-- Left (ContextNotAllowed "option_payment_metadata" Init)
setFeatureForContext
  :: Context
  -> Feature
  -> FeatureLevel
  -> FeatureVector
  -> Either ValidationError FeatureVector
setFeatureForContext ctx f level fv
  | contextRejected = Left (ContextNotAllowed (featureName f) ctx)
  | parityRejected  = Left (InvalidParity requestedBit ctx)
  | otherwise       = Right (setFeature f level fv)
  where
    allowed = featureContexts f

    -- An empty context list places no restriction on the feature.
    contextRejected = not (null allowed || contextAllowed ctx allowed)

    -- 'Just True' rejects the odd (optional) bit, 'Just False' rejects
    -- the even (required) bit; 'Nothing' imposes no parity constraint.
    parityRejected = case channelParity ctx of
      Just True  -> level == Optional
      Just False -> level == Required
      Nothing    -> False

    -- The bit that would be set: the base bit when required, the bit
    -- directly above it when optional.
    requestedBit = case level of
      Required -> featureBaseBit f
      Optional -> featureBaseBit f + 1

-- | Validate that no feature has both its required and optional
-- bits set simultaneously.
--
-- Returns the input vector unchanged on success.
--
-- >>> validateNoBothBits empty
-- Right ...
validateNoBothBits
  :: FeatureVector -> Either ValidationError FeatureVector
validateNoBothBits fv = case conflicts of
    (firstConflict : _) -> Left firstConflict
    []                  -> Right fv
  where
    -- Built lazily, so only the first conflict (in 'knownFeatures'
    -- order) is ever computed.
    conflicts =
      [ BothBitsSet base (featureName f)
      | f <- knownFeatures
      , let base = featureBaseBit f
      , testBit base fv
      , testBit (base + 1) fv
      ]

-- Remote validation ----------------------------------------------------------

-- | Validate a feature vector received from a remote peer.
--
--   All failures are collected and returned together.
--
--   Checks:
--
--   * Unknown odd (optional) bits are acceptable (ignored)
--   * Unknown even (required) bits are errors
--   * If both bits of a pair are set, treat as required (not an error)
--   * Context restrictions still apply for known features
--   * Parity restrictions of the context still apply to every set bit
--
--   >>> import Lightning.Protocol.BOLT9.Codec (setBit)
--   >>> validateRemote Init (setBit 999 empty)  -- unknown odd bit: ok
--   Right ()
--   >>> validateRemote Init (setBit 998 empty)  -- unknown even bit: error
--   Left [UnknownRequiredBit 998]
validateRemote :: Context -> FeatureVector -> Either [ValidationError] ()
validateRemote ctx fv = case problems of
    [] -> Right ()
    _  -> Left problems
  where
    -- Errors are reported grouped by check, in this fixed order.
    problems = concat
      [ unknownRequiredErrors fv
      , contextErrors ctx fv
      , parityErrors ctx fv
      ]

-- | Report every set even bit that 'featureByBit' does not recognise.
--
--   Odd bits are never reported, known or not.
unknownRequiredErrors :: FeatureVector -> [ValidationError]
unknownRequiredErrors fv =
  [ UnknownRequiredBit b
  | b <- setBits fv
  , even b
  , Nothing <- [featureByBit b]
  ]

-- Helpers --------------------------------------------------------------------

-- | Find the highest set bit in a feature vector.
--
-- Returns 'Nothing' if the vector is empty or has no bits set.
highestSetBit :: FeatureVector -> Maybe Word16
highestSetBit !fv =
  let !bs = unFeatureVector fv
  in  if BS.null bs
        then Nothing
        else findHighestBit bs

-- | Find the highest set bit in a ByteString.
--
--   The first byte of the string holds the highest-numbered bits and the
--   last byte holds bits 0-7. Leading zero bytes are skipped; the result
--   is derived from the first non-zero byte. Returns 'Nothing' when every
--   byte is zero (or the string is empty).
findHighestBit :: ByteString -> Maybe Word16
findHighestBit !bs = scan 0
  where
    !len = BS.length bs

    scan !i
      | i >= len  = Nothing
      | otherwise = case BS.index bs i of
          0    -> scan (i + 1)
          byte ->
            let !bytePos = len - 1 - i                   -- 0 for the last byte
                !highBit = 7 - B.countLeadingZeros byte  -- top set bit in byte
            in  Just (fromIntegral bytePos * 8 + fromIntegral highBit)

-- | Collect all set bits in a feature vector.
--
--   Returns a list of bit indices in ascending order.
setBits :: FeatureVector -> [Word16]
setBits !fv =
  let !bs  = unFeatureVector fv
      !len = BS.length bs
  in  collectBits bs len 0 []

-- | Collect the set bits of a ByteString into a list, in ascending
--   order, in front of the supplied accumulator.
--
--   Arguments: the bytes, their length, the byte index to start from
--   (0 to scan everything), and the accumulator.
--
--   The string is walked from its first byte (highest bit indices) to
--   its last (bits 0-7). Each byte's bits are consed in front of what has
--   been gathered so far, so the lowest indices end up at the head of the
--   result.
collectBits :: ByteString -> Int -> Int -> [Word16] -> [Word16]
collectBits !bs !len !i !acc
  | i >= len  = acc
  | otherwise =
      let !byte    = BS.index bs i
          !baseIdx = fromIntegral (len - 1 - i) * 8
          !acc'    = collectByteBits byte baseIdx acc
      in  collectBits bs len (i + 1) acc'

-- | Collect set bits from a single byte.
--
--   Bits 7 down to 0 are examined in that order and each set bit's index
--   (offset by @baseIdx@) is consed onto the accumulator, so the byte's
--   own indices come out in ascending order ahead of the accumulator.
collectByteBits :: B.Bits a => a -> Word16 -> [Word16] -> [Word16]
collectByteBits !byte !baseIdx = go 7
  where
    go :: Int -> [Word16] -> [Word16]
    go !bit !acc
      | bit < 0           = acc
      | B.testBit byte bit = go (bit - 1) ((baseIdx + fromIntegral bit) : acc)
      | otherwise         = go (bit - 1) acc