commit 5fbd48961c36baa346f0f5343ca3cd40ab0510cf
parent cf2177ebbb8c31b41f49fd99be95d597c7ffe08d
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 25 Jan 2026 16:04:11 +0400
Refactor: use ADTs for channel update flags
Replace raw Word8 message/channel flags with structured ADTs that provide
type-safe access to individual flag bits:
- MessageFlags: encapsulates htlc_maximum_msat presence (bit 0)
- ChannelFlags: encapsulates direction (bit 0) and disabled (bit 1)
Add encode/decode functions for wire format conversion. Update Codec.hs
to use the ADT accessor for htlc_max presence check. Update all tests
to use the new structured types.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
4 files changed, 104 insertions(+), 37 deletions(-)
diff --git a/lib/Lightning/Protocol/BOLT7/Codec.hs b/lib/Lightning/Protocol/BOLT7/Codec.hs
@@ -50,7 +50,6 @@ module Lightning.Protocol.BOLT7.Codec (
) where
import Control.DeepSeq (NFData)
-import Data.Bits ((.&.))
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Word (Word8, Word16, Word32, Word64)
@@ -370,8 +369,8 @@ encodeChannelUpdate msg = mconcat
, getChainHash (chanUpdateChainHash msg)
, getShortChannelId (chanUpdateShortChanId msg)
, Prim.encodeU32 (chanUpdateTimestamp msg)
- , BS.singleton (chanUpdateMsgFlags msg)
- , BS.singleton (chanUpdateChanFlags msg)
+ , BS.singleton (encodeMessageFlags (chanUpdateMsgFlags msg))
+ , BS.singleton (encodeChannelFlags (chanUpdateChanFlags msg))
, Prim.encodeU16 (chanUpdateCltvExpDelta msg)
, Prim.encodeU64 (chanUpdateHtlcMinMsat msg)
, Prim.encodeU32 (chanUpdateFeeBaseMsat msg)
@@ -385,18 +384,20 @@ encodeChannelUpdate msg = mconcat
decodeChannelUpdate :: ByteString
-> Either DecodeError (ChannelUpdate, ByteString)
decodeChannelUpdate bs = do
- (sig, bs1) <- decodeSignature bs
- (chainH, bs2) <- decodeChainHash bs1
- (scid, bs3) <- decodeShortChannelId bs2
- (timestamp, bs4) <- decodeU32 bs3
- (msgFlags, bs5) <- decodeU8 bs4
- (chanFlags, bs6) <- decodeU8 bs5
- (cltvDelta, bs7) <- decodeU16 bs6
- (htlcMin, bs8) <- decodeU64 bs7
- (feeBase, bs9) <- decodeU32 bs8
- (feeProp, bs10) <- decodeU32 bs9
+ (sig, bs1) <- decodeSignature bs
+ (chainH, bs2) <- decodeChainHash bs1
+ (scid, bs3) <- decodeShortChannelId bs2
+ (timestamp, bs4) <- decodeU32 bs3
+ (msgFlagsRaw, bs5) <- decodeU8 bs4
+ (chanFlagsRaw, bs6) <- decodeU8 bs5
+ (cltvDelta, bs7) <- decodeU16 bs6
+ (htlcMin, bs8) <- decodeU64 bs7
+ (feeBase, bs9) <- decodeU32 bs8
+ (feeProp, bs10) <- decodeU32 bs9
+ let msgFlags' = decodeMessageFlags msgFlagsRaw
+ chanFlags' = decodeChannelFlags chanFlagsRaw
-- htlc_maximum_msat is present if message_flags bit 0 is set
- (htlcMax, rest) <- if msgFlags .&. 0x01 /= 0
+ (htlcMax, rest) <- if mfHtlcMaxPresent msgFlags'
then do
(m, r) <- decodeU64 bs10
Right (Just m, r)
@@ -406,8 +407,8 @@ decodeChannelUpdate bs = do
, chanUpdateChainHash = chainH
, chanUpdateShortChanId = scid
, chanUpdateTimestamp = timestamp
- , chanUpdateMsgFlags = msgFlags
- , chanUpdateChanFlags = chanFlags
+ , chanUpdateMsgFlags = msgFlags'
+ , chanUpdateChanFlags = chanFlags'
, chanUpdateCltvExpDelta = cltvDelta
, chanUpdateHtlcMinMsat = htlcMin
, chanUpdateFeeBaseMsat = feeBase
diff --git a/lib/Lightning/Protocol/BOLT7/Messages.hs b/lib/Lightning/Protocol/BOLT7/Messages.hs
@@ -41,7 +41,7 @@ module Lightning.Protocol.BOLT7.Messages (
import Control.DeepSeq (NFData)
import Data.ByteString (ByteString)
-import Data.Word (Word8, Word16, Word32)
+import Data.Word (Word8, Word16, Word32) -- Word8 still used by other messages
import GHC.Generics (Generic)
import Lightning.Protocol.BOLT1 (TlvStream)
import Lightning.Protocol.BOLT7.Types
@@ -126,8 +126,8 @@ data ChannelUpdate = ChannelUpdate
, chanUpdateChainHash :: !ChainHash -- ^ Chain identifier
, chanUpdateShortChanId :: !ShortChannelId -- ^ Short channel ID
, chanUpdateTimestamp :: !Timestamp -- ^ Unix timestamp
- , chanUpdateMsgFlags :: !Word8 -- ^ Message flags
- , chanUpdateChanFlags :: !Word8 -- ^ Channel flags
+ , chanUpdateMsgFlags :: !MessageFlags -- ^ Message flags
+ , chanUpdateChanFlags :: !ChannelFlags -- ^ Channel flags
, chanUpdateCltvExpDelta :: !CltvExpiryDelta -- ^ CLTV expiry delta
, chanUpdateHtlcMinMsat :: !HtlcMinimumMsat -- ^ Minimum HTLC msat
, chanUpdateFeeBaseMsat :: !FeeBaseMsat -- ^ Base fee msat
diff --git a/lib/Lightning/Protocol/BOLT7/Types.hs b/lib/Lightning/Protocol/BOLT7/Types.hs
@@ -64,6 +64,14 @@ module Lightning.Protocol.BOLT7.Types (
, torV3Addr
, getTorV3Addr
+ -- * Channel update flags
+ , MessageFlags(..)
+ , encodeMessageFlags
+ , decodeMessageFlags
+ , ChannelFlags(..)
+ , encodeChannelFlags
+ , decodeChannelFlags
+
-- * Routing parameters
, CltvExpiryDelta
, FeeBaseMsat
@@ -401,6 +409,55 @@ data Address
instance NFData Address
+-- Channel update flags --------------------------------------------------------
+
+-- | Message flags for channel_update.
+--
+-- Bit 0: htlc_maximum_msat field is present.
+data MessageFlags = MessageFlags
+ { mfHtlcMaxPresent :: !Bool -- ^ htlc_maximum_msat is present
+ }
+ deriving (Eq, Show, Generic)
+
+instance NFData MessageFlags
+
+-- | Encode MessageFlags to Word8.
+encodeMessageFlags :: MessageFlags -> Word8
+encodeMessageFlags mf = if mfHtlcMaxPresent mf then 0x01 else 0x00
+{-# INLINE encodeMessageFlags #-}
+
+-- | Decode Word8 to MessageFlags.
+decodeMessageFlags :: Word8 -> MessageFlags
+decodeMessageFlags w = MessageFlags { mfHtlcMaxPresent = w .&. 0x01 /= 0 }
+{-# INLINE decodeMessageFlags #-}
+
+-- | Channel flags for channel_update.
+--
+-- Bit 0: direction (0 = node_id_1 is origin, 1 = node_id_2 is origin).
+-- Bit 1: disabled (1 = channel disabled).
+data ChannelFlags = ChannelFlags
+ { cfDirection :: !Bool -- ^ True = node_id_2 is origin
+ , cfDisabled :: !Bool -- ^ True = channel is disabled
+ }
+ deriving (Eq, Show, Generic)
+
+instance NFData ChannelFlags
+
+-- | Encode ChannelFlags to Word8.
+encodeChannelFlags :: ChannelFlags -> Word8
+encodeChannelFlags cf =
+ (if cfDirection cf then 0x01 else 0x00) .|.
+ (if cfDisabled cf then 0x02 else 0x00)
+{-# INLINE encodeChannelFlags #-}
+
+-- | Decode Word8 to ChannelFlags.
+decodeChannelFlags :: Word8 -> ChannelFlags
+decodeChannelFlags w = ChannelFlags
+ { cfDirection = w .&. 0x01 /= 0
+ , cfDisabled = w .&. 0x02 /= 0
+ }
+{-# INLINE decodeChannelFlags #-}
+
-- Routing parameters ----------------------------------------------------------
-- | CLTV expiry delta.
diff --git a/test/Main.hs b/test/Main.hs
@@ -201,8 +201,9 @@ channel_update_tests = testGroup "ChannelUpdate" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = 0x00
- , chanUpdateChanFlags = 0x01
+ , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False }
+ , chanUpdateChanFlags = ChannelFlags
+ { cfDirection = True, cfDisabled = False }
, chanUpdateCltvExpDelta = 144
, chanUpdateHtlcMinMsat = 1000
, chanUpdateFeeBaseMsat = 1000
@@ -219,8 +220,9 @@ channel_update_tests = testGroup "ChannelUpdate" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = 0x01 -- bit 0 set
- , chanUpdateChanFlags = 0x00
+ , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = True }
+ , chanUpdateChanFlags = ChannelFlags
+ { cfDirection = False, cfDisabled = False }
, chanUpdateCltvExpDelta = 40
, chanUpdateHtlcMinMsat = 1000
, chanUpdateFeeBaseMsat = 500
@@ -419,8 +421,9 @@ hash_tests = testGroup "Hash Functions" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = 0x00
- , chanUpdateChanFlags = 0x00
+ , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False }
+ , chanUpdateChanFlags = ChannelFlags
+ { cfDirection = False, cfDisabled = False }
, chanUpdateCltvExpDelta = 144
, chanUpdateHtlcMinMsat = 1000
, chanUpdateFeeBaseMsat = 1000
@@ -438,8 +441,9 @@ hash_tests = testGroup "Hash Functions" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = 0x00
- , chanUpdateChanFlags = 0x00
+ , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False }
+ , chanUpdateChanFlags = ChannelFlags
+ { cfDirection = False, cfDisabled = False }
, chanUpdateCltvExpDelta = 144
, chanUpdateHtlcMinMsat = 1000
, chanUpdateFeeBaseMsat = 1000
@@ -456,8 +460,9 @@ hash_tests = testGroup "Hash Functions" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1000000000
- , chanUpdateMsgFlags = 0x00
- , chanUpdateChanFlags = 0x00
+ , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False }
+ , chanUpdateChanFlags = ChannelFlags
+ { cfDirection = False, cfDisabled = False }
, chanUpdateCltvExpDelta = 144
, chanUpdateHtlcMinMsat = 1000
, chanUpdateFeeBaseMsat = 1000
@@ -469,8 +474,9 @@ hash_tests = testGroup "Hash Functions" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 2000000000
- , chanUpdateMsgFlags = 0x00
- , chanUpdateChanFlags = 0x00
+ , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False }
+ , chanUpdateChanFlags = ChannelFlags
+ { cfDirection = False, cfDisabled = False }
, chanUpdateCltvExpDelta = 144
, chanUpdateHtlcMinMsat = 1000
, chanUpdateFeeBaseMsat = 1000
@@ -524,8 +530,9 @@ validation_tests = testGroup "Validation" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = 0x01
- , chanUpdateChanFlags = 0x00
+ , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = True }
+ , chanUpdateChanFlags = ChannelFlags
+ { cfDirection = False, cfDisabled = False }
, chanUpdateCltvExpDelta = 144
, chanUpdateHtlcMinMsat = 1000
, chanUpdateFeeBaseMsat = 1000
@@ -539,8 +546,9 @@ validation_tests = testGroup "Validation" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = 0x01
- , chanUpdateChanFlags = 0x00
+ , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = True }
+ , chanUpdateChanFlags = ChannelFlags
+ { cfDirection = False, cfDisabled = False }
, chanUpdateCltvExpDelta = 144
, chanUpdateHtlcMinMsat = 2000000000 -- > htlcMax
, chanUpdateFeeBaseMsat = 1000
@@ -634,8 +642,9 @@ propChannelUpdateRoundtrip timestamp cltvDelta = property $ do
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = timestamp
- , chanUpdateMsgFlags = 0x00
- , chanUpdateChanFlags = 0x00
+ , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False }
+ , chanUpdateChanFlags = ChannelFlags
+ { cfDirection = False, cfDisabled = False }
, chanUpdateCltvExpDelta = cltvDelta
, chanUpdateHtlcMinMsat = 1000
, chanUpdateFeeBaseMsat = 1000