commit f9a9fa2f403dbe60e9f7002d166bd3e3f906c171
parent cecfa8c7a01795ebd39916a4c0b3e03e33b00c6c
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 25 Jan 2026 16:20:17 +0400
Refactor: improve type safety and reduce duplication
- Add FeatureLevel type (Required | Optional) to replace Bool in
setFeature, hasFeature, and listFeatures APIs
- Add isFeatureSet helper to abstract common "either bit set" pattern,
used in 3 places in Validate.hs
- Use Map/IntMap for O(log n) feature lookups instead of linear search
- Replace custom countLeadingZeros with Data.Bits.countLeadingZeros
- Add containers dependency (GHC boot library)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
6 files changed, 109 insertions(+), 56 deletions(-)
diff --git a/lib/Lightning/Protocol/BOLT9.hs b/lib/Lightning/Protocol/BOLT9.hs
@@ -31,9 +31,9 @@
--
-- >>> import Lightning.Protocol.BOLT9
-- >>> let Just mpp = featureByName "basic_mpp"
--- >>> let fv = setFeature mpp False empty -- optional support
+-- >>> let fv = setFeature mpp Optional empty
-- >>> hasFeature mpp fv
--- Just False
+-- Just Optional
--
-- Validate a feature vector for a specific context:
--
@@ -43,7 +43,7 @@
-- Fix by adding the dependency:
--
-- >>> let Just ps = featureByName "payment_secret"
--- >>> let fv' = setFeature ps False (setFeature mpp False empty)
+-- >>> let fv' = setFeature ps Optional (setFeature mpp Optional empty)
-- >>> validateLocal Init fv'
-- Right ()
--
@@ -70,6 +70,10 @@ module Lightning.Protocol.BOLT9 (
, unBitIndex
, bitIndex
+ -- * Required/optional level
+ -- | Whether a feature is set as required or optional.
+ , FeatureLevel(..)
+
-- * Required/optional bits
-- | Type-safe wrappers ensuring correct parity.
, RequiredBit
@@ -114,6 +118,7 @@ module Lightning.Protocol.BOLT9 (
-- | High-level operations using 'Feature' values.
, setFeature
, hasFeature
+ , isFeatureSet
, listFeatures
-- * Validation
diff --git a/lib/Lightning/Protocol/BOLT9/Codec.hs b/lib/Lightning/Protocol/BOLT9/Codec.hs
@@ -22,6 +22,7 @@ module Lightning.Protocol.BOLT9.Codec (
-- * Feature operations
, setFeature
, hasFeature
+ , isFeatureSet
, listFeatures
) where
@@ -30,6 +31,15 @@ import qualified Data.ByteString as BS
import Data.Word (Word16)
import Lightning.Protocol.BOLT9.Features
import Lightning.Protocol.BOLT9.Types
+ ( FeatureLevel(..)
+ , FeatureVector
+ , bitIndex
+ , clear
+ , fromByteString
+ , member
+ , set
+ , unFeatureVector
+ )
-- Parsing and rendering ------------------------------------------------------
@@ -76,63 +86,78 @@ testBit !idx = member (bitIndex idx)
-- Feature operations ---------------------------------------------------------
--- | Set or clear a feature's bit.
+-- | Set a feature's bit at the given level.
--
--- If the Bool is True, sets the required (even) bit.
--- If the Bool is False, sets the optional (odd) bit.
+-- 'Required' sets the even bit, 'Optional' sets the odd bit.
--
-- >>> import Data.Maybe (fromJust)
-- >>> let mpp = fromJust (featureByName "basic_mpp")
--- >>> setFeature mpp False empty -- set optional bit (17)
+-- >>> setFeature mpp Optional empty -- set optional bit (17)
-- FeatureVector {unFeatureVector = "\STX"}
--- >>> setFeature mpp True empty -- set required bit (16)
+-- >>> setFeature mpp Required empty -- set required bit (16)
-- FeatureVector {unFeatureVector = "\SOH"}
-setFeature :: Feature -> Bool -> FeatureVector -> FeatureVector
-setFeature !f !required = setBit targetBit
+setFeature :: Feature -> FeatureLevel -> FeatureVector -> FeatureVector
+setFeature !f !level = setBit targetBit
where
!baseBit = featureBaseBit f
- !targetBit = if required then baseBit else baseBit + 1
+ !targetBit = case level of
+ Required -> baseBit
+ Optional -> baseBit + 1
{-# INLINE setFeature #-}
-- | Check if a feature is set in the vector.
--
-- Returns:
--
--- * @Just True@ if the required (even) bit is set
--- * @Just False@ if the optional (odd) bit is set (and required is not)
+-- * @Just Required@ if the required (even) bit is set
+-- * @Just Optional@ if the optional (odd) bit is set (and required is not)
-- * @Nothing@ if neither bit is set
--
-- >>> import Data.Maybe (fromJust)
-- >>> let mpp = fromJust (featureByName "basic_mpp")
--- >>> hasFeature mpp (setFeature mpp False empty)
--- Just False
--- >>> hasFeature mpp (setFeature mpp True empty)
--- Just True
+-- >>> hasFeature mpp (setFeature mpp Optional empty)
+-- Just Optional
+-- >>> hasFeature mpp (setFeature mpp Required empty)
+-- Just Required
-- >>> hasFeature mpp empty
-- Nothing
-hasFeature :: Feature -> FeatureVector -> Maybe Bool
+hasFeature :: Feature -> FeatureVector -> Maybe FeatureLevel
hasFeature !f !fv
- | testBit baseBit fv = Just True -- required
- | testBit (baseBit + 1) fv = Just False -- optional
+ | testBit baseBit fv = Just Required
+ | testBit (baseBit + 1) fv = Just Optional
| otherwise = Nothing
where
!baseBit = featureBaseBit f
{-# INLINE hasFeature #-}
+-- | Check if either bit of a feature is set in the vector.
+--
+-- >>> import Data.Maybe (fromJust)
+-- >>> let mpp = fromJust (featureByName "basic_mpp")
+-- >>> isFeatureSet mpp (setFeature mpp Optional empty)
+-- True
+-- >>> isFeatureSet mpp empty
+-- False
+isFeatureSet :: Feature -> FeatureVector -> Bool
+isFeatureSet !f !fv =
+ let !baseBit = featureBaseBit f
+ in testBit baseBit fv || testBit (baseBit + 1) fv
+{-# INLINE isFeatureSet #-}
+
-- | List all known features that are set in the vector.
--
--- Returns pairs of (Feature, Bool) where the Bool indicates if the
--- required (even) bit is set (True) or the optional (odd) bit (False).
+-- Returns pairs of (Feature, FeatureLevel) indicating whether each
+-- feature is set as required or optional.
--
-- >>> import Data.Maybe (fromJust)
-- >>> let mpp = fromJust (featureByName "basic_mpp")
-- >>> let ps = fromJust (featureByName "payment_secret")
--- >>> let fv = setFeature mpp False (setFeature ps True empty)
--- >>> map (\(f, r) -> (featureName f, r)) (listFeatures fv)
--- [("payment_secret",True),("basic_mpp",False)]
-listFeatures :: FeatureVector -> [(Feature, Bool)]
+-- >>> let fv = setFeature mpp Optional (setFeature ps Required empty)
+-- >>> map (\(f, l) -> (featureName f, l)) (listFeatures fv)
+-- [("payment_secret",Required),("basic_mpp",Optional)]
+listFeatures :: FeatureVector -> [(Feature, FeatureLevel)]
listFeatures !fv = foldr check [] knownFeatures
where
check !f !acc = case hasFeature f fv of
- Just isReq -> (f, isReq) : acc
+ Just level -> (f, level) : acc
Nothing -> acc
diff --git a/lib/Lightning/Protocol/BOLT9/Features.hs b/lib/Lightning/Protocol/BOLT9/Features.hs
@@ -24,7 +24,10 @@ module Lightning.Protocol.BOLT9.Features (
) where
import Control.DeepSeq (NFData)
-import Data.List (find)
+import Data.IntMap.Strict (IntMap)
+import qualified Data.IntMap.Strict as IM
+import Data.Map.Strict (Map)
+import qualified Data.Map.Strict as M
import Data.Word (Word16)
import GHC.Generics (Generic)
import Lightning.Protocol.BOLT9.Types (Context(..))
@@ -90,8 +93,9 @@ knownFeatures = [
-- Nothing
featureByBit :: Word16 -> Maybe Feature
featureByBit !bit =
- let baseBit = bit - (bit `mod` 2) -- round down to even
- in find (\f -> featureBaseBit f == baseBit) knownFeatures
+ let !baseBit = fromIntegral bit - (fromIntegral bit `mod` 2)
+ in IM.lookup baseBit featuresByBit
+{-# INLINE featureByBit #-}
-- | Look up a feature by its canonical name.
--
@@ -100,4 +104,17 @@ featureByBit !bit =
-- >>> featureByName "nonexistent"
-- Nothing
featureByName :: String -> Maybe Feature
-featureByName !name = find (\f -> featureName f == name) knownFeatures
+featureByName !name = M.lookup name featuresByName
+{-# INLINE featureByName #-}
+
+-- Lookup tables -------------------------------------------------------------
+
+-- | Features indexed by base bit (even bit number).
+featuresByBit :: IntMap Feature
+featuresByBit = IM.fromList
+ [(fromIntegral (featureBaseBit f), f) | f <- knownFeatures]
+
+-- | Features indexed by canonical name.
+featuresByName :: Map String Feature
+featuresByName = M.fromList
+ [(featureName f, f) | f <- knownFeatures]
diff --git a/lib/Lightning/Protocol/BOLT9/Types.hs b/lib/Lightning/Protocol/BOLT9/Types.hs
@@ -21,6 +21,9 @@ module Lightning.Protocol.BOLT9.Types (
, unBitIndex
, bitIndex
+ -- * Required/optional level
+ , FeatureLevel(..)
+
-- * Required/optional bits
, RequiredBit
, unRequiredBit
@@ -93,6 +96,19 @@ channelParity ChanAnnEven = Just True -- even
channelParity _ = Nothing
{-# INLINE channelParity #-}
+-- FeatureLevel -------------------------------------------------------------
+
+-- | Whether a feature is set as required or optional.
+--
+-- Per BOLT #9, each feature has a pair of bits: the even bit indicates
+-- required (compulsory) support, the odd bit indicates optional support.
+data FeatureLevel
+ = Required -- ^ The feature is required (even bit set)
+ | Optional -- ^ The feature is optional (odd bit set)
+ deriving (Eq, Ord, Show, Generic)
+
+instance NFData FeatureLevel
+
-- BitIndex -----------------------------------------------------------------
-- | A bit index into a feature vector. Bit 0 is the least significant bit.
diff --git a/lib/Lightning/Protocol/BOLT9/Validate.hs b/lib/Lightning/Protocol/BOLT9/Validate.hs
@@ -31,7 +31,7 @@ import qualified Data.ByteString as BS
import qualified Data.Bits as B
import Data.Word (Word16)
import GHC.Generics (Generic)
-import Lightning.Protocol.BOLT9.Codec (testBit)
+import Lightning.Protocol.BOLT9.Codec (isFeatureSet, testBit)
import Lightning.Protocol.BOLT9.Features
import Lightning.Protocol.BOLT9.Types
@@ -100,10 +100,10 @@ contextErrors :: Context -> FeatureVector -> [ValidationError]
contextErrors !ctx !fv = foldr check [] knownFeatures
where
check !f !acc =
- let !baseBit = featureBaseBit f
- !contexts = featureContexts f
- !isSet = testBit baseBit fv || testBit (baseBit + 1) fv
- in if isSet && not (null contexts) && not (contextAllowed ctx contexts)
+ let !contexts = featureContexts f
+ in if isFeatureSet f fv
+ && not (null contexts)
+ && not (contextAllowed ctx contexts)
then ContextNotAllowed (featureName f) ctx : acc
else acc
@@ -118,11 +118,9 @@ dependencyErrors :: FeatureVector -> [ValidationError]
dependencyErrors !fv = foldr check [] knownFeatures
where
check !f !acc =
- let !baseBit = featureBaseBit f
- !isSet = testBit baseBit fv || testBit (baseBit + 1) fv
- in if isSet
- then checkDeps f (featureDependencies f) ++ acc
- else acc
+ if isFeatureSet f fv
+ then checkDeps f (featureDependencies f) ++ acc
+ else acc
checkDeps !f = foldr (checkOneDep f) []
@@ -130,10 +128,9 @@ dependencyErrors !fv = foldr check [] knownFeatures
case featureByName depName of
Nothing -> acc -- unknown dep, skip
Just !dep ->
- let !depBase = featureBaseBit dep
- in if testBit depBase fv || testBit (depBase + 1) fv
- then acc
- else MissingDependency (featureName f) depName : acc
+ if isFeatureSet dep fv
+ then acc
+ else MissingDependency (featureName f) depName : acc
-- | Check for parity errors in C- and C+ contexts.
parityErrors :: Context -> FeatureVector -> [ValidationError]
@@ -207,19 +204,11 @@ findHighestBit !bs = go 0
in if byte == 0
then go (i + 1)
else
- let !bytePos = len - 1 - i
- !highBit = 7 - countLeadingZeros byte
- !bitIdx = fromIntegral bytePos * 8 + fromIntegral highBit
+ let !bytePos = len - 1 - i
+ !highBit = 7 - B.countLeadingZeros byte
+ !bitIdx = fromIntegral bytePos * 8 + fromIntegral highBit
in Just bitIdx
- countLeadingZeros :: B.Bits a => a -> Int
- countLeadingZeros !b = go' 7
- where
- go' !n
- | n < 0 = 8
- | B.testBit b n = 7 - n
- | otherwise = go' (n - 1)
-
-- | Collect all set bits in a feature vector.
--
-- Returns a list of bit indices in ascending order.
diff --git a/ppad-bolt9.cabal b/ppad-bolt9.cabal
@@ -32,6 +32,7 @@ library
build-depends:
base >= 4.9 && < 5
, bytestring >= 0.9 && < 0.13
+ , containers >= 0.6 && < 0.9
, deepseq >= 1.4 && < 1.6
test-suite bolt9-tests