bolt9

Lightning feature flags, per BOLT #9 (docs.ppad.tech/bolt9).
git clone git://git.ppad.tech/bolt9.git
Log | Files | Refs | README | LICENSE

Validate.hs (10181B)


      1 {-# OPTIONS_HADDOCK prune #-}
      2 {-# LANGUAGE BangPatterns #-}
      3 {-# LANGUAGE DeriveGeneric #-}
      4 
      5 -- |
      6 -- Module: Lightning.Protocol.BOLT9.Validate
      7 -- Copyright: (c) 2025 Jared Tobin
      8 -- License: MIT
      9 -- Maintainer: Jared Tobin <jared@ppad.tech>
     10 --
     11 -- Validation for BOLT #9 feature vectors.
     12 
     13 module Lightning.Protocol.BOLT9.Validate (
     14     -- * Error types
     15     ValidationError(..)
     16 
     17     -- * Local validation
     18   , validateLocal
     19 
     20     -- * Remote validation
     21   , validateRemote
     22 
     23     -- * Validated construction
     24   , setFeatureForContext
     25   , validateNoBothBits
     26 
     27     -- * Helpers
     28   , highestSetBit
     29   , setBits
     30   ) where
     31 
     32 import Control.DeepSeq (NFData)
     33 import Data.ByteString (ByteString)
     34 import qualified Data.ByteString as BS
     35 import qualified Data.Bits as B
     36 import Data.Word (Word16)
     37 import GHC.Generics (Generic)
     38 import Lightning.Protocol.BOLT9.Codec
     39   (isFeatureSet, setFeature, testBit)
     40 import Lightning.Protocol.BOLT9.Features
     41 import Lightning.Protocol.BOLT9.Types
     42 
     43 -- | Validation errors for feature vectors.
     44 data ValidationError
     45   = BothBitsSet {-# UNPACK #-} !Word16 !String
     46     -- ^ Both optional and required bits are set for a feature.
     47     --   Arguments: base bit index, feature name.
     48   | MissingDependency !String !String
     49     -- ^ A feature's dependency is not set.
     50     --   Arguments: feature name, missing dependency name.
     51   | ContextNotAllowed !String !Context
     52     -- ^ A feature is not allowed in the given context.
     53     --   Arguments: feature name, context.
     54   | UnknownRequiredBit {-# UNPACK #-} !Word16
     55     -- ^ An unknown required (even) bit is set (remote validation only).
     56     --   Argument: bit index.
     57   | InvalidParity {-# UNPACK #-} !Word16 !Context
     58     -- ^ A bit has invalid parity for a channel context.
     59     --   Arguments: bit index, context (ChanAnnOdd or ChanAnnEven).
     60   deriving (Eq, Show, Generic)
     61 
     62 instance NFData ValidationError
     63 
     64 -- Local validation -----------------------------------------------------------
     65 
     66 -- | Validate a feature vector for local use (vectors we create/send).
     67 --
     68 --   Checks:
     69 --
     70 --   * No feature has both optional and required bits set
     71 --   * All set features are valid for the given context
     72 --   * All dependencies of set features are also set
     73 --   * C- context forces odd bits only, C+ forces even bits only
     74 --
     75 --   >>> import Data.Maybe (fromJust)
     76 --   >>> import Lightning.Protocol.BOLT9.Codec (setFeature)
     77 --   >>> let mpp = fromJust (featureByName "basic_mpp")
     78 --   >>> let ps = fromJust (featureByName "payment_secret")
     79 --   >>> validateLocal Init (setFeature mpp False empty)
     80 --   Left [MissingDependency "basic_mpp" "payment_secret"]
     81 --   >>> validateLocal Init (setFeature mpp False (setFeature ps False empty))
     82 --   Right ()
     83 validateLocal :: Context -> FeatureVector -> Either [ValidationError] ()
     84 validateLocal !ctx !fv =
     85   let errs = bothBitsErrors fv
     86           ++ contextErrors ctx fv
     87           ++ dependencyErrors fv
     88           ++ parityErrors ctx fv
     89   in  if null errs
     90       then Right ()
     91       else Left errs
     92 
     93 -- | Check for features with both bits set.
     94 bothBitsErrors :: FeatureVector -> [ValidationError]
     95 bothBitsErrors !fv = foldr check [] knownFeatures
     96   where
     97     check !f !acc =
     98       let !baseBit = featureBaseBit f
     99       in  if testBit baseBit fv && testBit (baseBit + 1) fv
    100           then BothBitsSet baseBit (featureName f) : acc
    101           else acc
    102 
    103 -- | Check for features not allowed in the given context.
    104 contextErrors :: Context -> FeatureVector -> [ValidationError]
    105 contextErrors !ctx !fv = foldr check [] knownFeatures
    106   where
    107     check !f !acc =
    108       let !contexts = featureContexts f
    109       in  if   isFeatureSet f fv
    110             && not (null contexts)
    111             && not (contextAllowed ctx contexts)
    112           then ContextNotAllowed (featureName f) ctx : acc
    113           else acc
    114 
    115 -- | Check if a context is allowed given a list of allowed contexts.
    116 contextAllowed :: Context -> [Context] -> Bool
    117 contextAllowed !ctx !allowed = ctx `elem` allowed || channelMatch
    118   where
    119     channelMatch = isChannelContext ctx && any isChannelContext allowed
    120 
    121 -- | Check for missing dependencies.
    122 dependencyErrors :: FeatureVector -> [ValidationError]
    123 dependencyErrors !fv = foldr check [] knownFeatures
    124   where
    125     check !f !acc =
    126       if   isFeatureSet f fv
    127       then checkDeps f (featureDependencies f) ++ acc
    128       else acc
    129 
    130     checkDeps !f = foldr (checkOneDep f) []
    131 
    132     checkOneDep !f !depName !acc =
    133       case featureByName depName of
    134         Nothing   -> acc  -- unknown dep, skip
    135         Just !dep ->
    136           if   isFeatureSet dep fv
    137           then acc
    138           else MissingDependency (featureName f) depName : acc
    139 
    140 -- | Check for parity errors in C- and C+ contexts.
    141 parityErrors :: Context -> FeatureVector -> [ValidationError]
    142 parityErrors !ctx !fv = case channelParity ctx of
    143   Nothing       -> []
    144   Just wantEven -> foldr (checkParity wantEven) [] (setBits fv)
    145   where
    146     checkParity !wantEven !bit !acc =
    147       let isEven = bit `mod` 2 == 0
    148       in  if isEven /= wantEven
    149           then InvalidParity bit ctx : acc
    150           else acc
    151 
    152 -- Validated construction -------------------------------------------------------
    153 
    154 -- | Set a feature in a vector, validating that the feature is
    155 --   allowed in the given context and has correct parity.
    156 --
    157 --   Checks:
    158 --
    159 --   * The feature's context list includes the given context
    160 --     (or is empty, meaning all contexts are allowed)
    161 --   * For 'ChanAnnOdd', only 'Optional' (odd bit) is allowed
    162 --   * For 'ChanAnnEven', only 'Required' (even bit) is allowed
    163 --
    164 --   >>> import Data.Maybe (fromJust)
    165 --   >>> let pm = fromJust (featureByName "option_payment_metadata")
    166 --   >>> setFeatureForContext Invoice pm Optional empty
    167 --   Right ...
    168 --   >>> setFeatureForContext Init pm Optional empty
    169 --   Left (ContextNotAllowed "option_payment_metadata" Init)
    170 setFeatureForContext
    171   :: Context
    172   -> Feature
    173   -> FeatureLevel
    174   -> FeatureVector
    175   -> Either ValidationError FeatureVector
    176 setFeatureForContext !ctx !f !level !fv
    177   | not (null contexts)
    178   , not (contextAllowed ctx contexts)
    179   = Left (ContextNotAllowed (featureName f) ctx)
    180   | otherwise
    181   = case channelParity ctx of
    182       Just True | level == Optional ->
    183         Left (InvalidParity targetBit ctx)
    184       Just False | level == Required ->
    185         Left (InvalidParity targetBit ctx)
    186       _ -> Right (setFeature f level fv)
    187   where
    188     !contexts  = featureContexts f
    189     !baseBit   = featureBaseBit f
    190     !targetBit = case level of
    191       Required -> baseBit
    192       Optional -> baseBit + 1
    193 
    194 -- | Validate that no feature has both its required and optional
    195 --   bits set simultaneously.
    196 --
    197 --   Returns the input vector unchanged on success.
    198 --
    199 --   >>> validateNoBothBits empty
    200 --   Right ...
    201 validateNoBothBits
    202   :: FeatureVector -> Either ValidationError FeatureVector
    203 validateNoBothBits !fv = go knownFeatures
    204   where
    205     go [] = Right fv
    206     go (f:fs) =
    207       let !baseBit = featureBaseBit f
    208       in  if   testBit baseBit fv
    209             && testBit (baseBit + 1) fv
    210           then Left (BothBitsSet baseBit (featureName f))
    211           else go fs
    212 
    213 -- Remote validation ----------------------------------------------------------
    214 
    215 -- | Validate a feature vector received from a remote peer.
    216 --
    217 --   Checks:
    218 --
    219 --   * Unknown odd (optional) bits are acceptable (ignored)
    220 --   * Unknown even (required) bits are errors
    221 --   * If both bits of a pair are set, treat as required (not an error)
    222 --   * Context restrictions still apply for known features
    223 --
    224 --   >>> import Lightning.Protocol.BOLT9.Codec (setBit)
    225 --   >>> validateRemote Init (setBit 999 empty)  -- unknown odd bit: ok
    226 --   Right ()
    227 --   >>> validateRemote Init (setBit 998 empty)  -- unknown even bit: error
    228 --   Left [UnknownRequiredBit 998]
    229 validateRemote :: Context -> FeatureVector -> Either [ValidationError] ()
    230 validateRemote !ctx !fv =
    231   let errs = unknownRequiredErrors fv
    232           ++ contextErrors ctx fv
    233           ++ parityErrors ctx fv
    234   in  if null errs
    235       then Right ()
    236       else Left errs
    237 
    238 -- | Check for unknown required bits.
    239 unknownRequiredErrors :: FeatureVector -> [ValidationError]
    240 unknownRequiredErrors !fv = foldr check [] (setBits fv)
    241   where
    242     check !bit !acc
    243       | bit `mod` 2 == 1 = acc  -- odd bit, optional, ignore
    244       | otherwise = case featureByBit bit of
    245           Just _  -> acc  -- known feature
    246           Nothing -> UnknownRequiredBit bit : acc
    247 
    248 -- Helpers --------------------------------------------------------------------
    249 
    250 -- | Find the highest set bit in a feature vector.
    251 --
    252 --   Returns 'Nothing' if the vector is empty or has no bits set.
    253 highestSetBit :: FeatureVector -> Maybe Word16
    254 highestSetBit !fv =
    255   let !bs = unFeatureVector fv
    256   in  if BS.null bs
    257       then Nothing
    258       else findHighestBit bs
    259 
    260 -- | Find the highest set bit in a non-empty ByteString.
    261 findHighestBit :: ByteString -> Maybe Word16
    262 findHighestBit !bs = go 0
    263   where
    264     !len = BS.length bs
    265 
    266     go !i
    267       | i >= len  = Nothing
    268       | otherwise =
    269           let !byte = BS.index bs i
    270           in  if byte == 0
    271               then go (i + 1)
    272               else
    273                 let !bytePos = len - 1 - i
    274                     !highBit = 7 - B.countLeadingZeros byte
    275                     !bitIdx  = fromIntegral bytePos * 8 + fromIntegral highBit
    276                 in  Just bitIdx
    277 
    278 -- | Collect all set bits in a feature vector.
    279 --
    280 --   Returns a list of bit indices in ascending order.
    281 setBits :: FeatureVector -> [Word16]
    282 setBits !fv =
    283   let !bs  = unFeatureVector fv
    284       !len = BS.length bs
    285   in  collectBits bs len 0 []
    286 
    287 -- | Collect bits from a ByteString into a list.
    288 collectBits :: ByteString -> Int -> Int -> [Word16] -> [Word16]
    289 collectBits !bs !len !i !acc
    290   | i >= len  = acc
    291   | otherwise =
    292       let !byte    = BS.index bs (len - 1 - i)
    293           !baseIdx = fromIntegral i * 8
    294           !acc'    = collectByteBits byte baseIdx acc
    295       in  collectBits bs len (i + 1) acc'
    296 
    297 -- | Collect set bits from a single byte.
    298 collectByteBits :: B.Bits a => a -> Word16 -> [Word16] -> [Word16]
    299 collectByteBits !byte !baseIdx = go 7
    300   where
    301     go !bit !acc
    302       | bit < 0        = acc
    303       | B.testBit byte bit = go (bit - 1) ((baseIdx + fromIntegral bit) : acc)
    304       | otherwise          = go (bit - 1) acc