bolt7

Routing gossip protocol, per BOLT #7 (docs.ppad.tech/bolt7).
git clone git://git.ppad.tech/bolt7.git
Log | Files | Refs | README | LICENSE

commit 7ce07077840bf3dc8f86fa4d0d00e23d7813d905
parent b7b8b6f8d6b3ef9b2d3740af8768169a4bfca599
Author: Jared Tobin <jared@jtobin.io>
Date:   Mon, 20 Apr 2026 15:16:00 +0800

lib: type safety improvements

Replace boolean flags with descriptive sum types:
- Direction (NodeOne | NodeTwo) for channel_flags bit 0
- ChannelStatus (Enabled | Disabled) for channel_flags bit 1
- ChannelFlags now uses cfDirection :: Direction, cfStatus :: ChannelStatus

Eliminate MessageFlags / htlcMaxMsat inconsistency:
- Remove MessageFlags type entirely
- Remove chanUpdateMsgFlags field from ChannelUpdate
- Derive message_flags byte from Maybe HtlcMaximumMsat during encoding
- Impossible to construct inconsistent state

Add Hostname newtype for DNS addresses:
- Validates length (1-255 bytes) in smart constructor
- AddrDNS now uses Hostname instead of raw ByteString
- Decoder rejects invalid hostnames

Add BlockHeight / BlockCount newtypes:
- BlockHeight for absolute block heights (first_blocknum)
- BlockCount for relative durations (number_of_blocks)
- Prevents mixing in QueryChannelRange / ReplyChannelRange

Update flake.lock for latest ppad-bolt1.

Diffstat:
Mbench/Main.hs | 11+++++------
Mbench/Weight.hs | 11+++++------
Mlib/Lightning/Protocol/BOLT7/Codec.hs | 129++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------
Mlib/Lightning/Protocol/BOLT7/Messages.hs | 58++++++++++++++++++++++++++++++++++++----------------------
Mlib/Lightning/Protocol/BOLT7/Types.hs | 122+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------
Mlib/Lightning/Protocol/BOLT7/Validate.hs | 22++++++++++++++--------
Mtest/Main.hs | 50++++++++++++++++++++++++--------------------------
7 files changed, 255 insertions(+), 148 deletions(-)

diff --git a/bench/Main.hs b/bench/Main.hs @@ -173,8 +173,7 @@ testChannelUpdate = ChannelUpdate , chanUpdateChainHash = testChainHash , chanUpdateShortChanId = testShortChannelId , chanUpdateTimestamp = 1234567890 - , chanUpdateMsgFlags = MessageFlags True - , chanUpdateChanFlags = ChannelFlags False False + , chanUpdateChanFlags = ChannelFlags NodeOne Enabled , chanUpdateCltvExpDelta = CltvExpiryDelta 144 , chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000 , chanUpdateFeeBaseMsat = FeeBaseMsat 1000 @@ -239,8 +238,8 @@ encodedReplyShortChannelIdsEnd = testQueryChannelRange :: QueryChannelRange testQueryChannelRange = QueryChannelRange { queryRangeChainHash = testChainHash - , queryRangeFirstBlock = 700000 - , queryRangeNumBlocks = 10000 + , queryRangeFirstBlock = BlockHeight 700000 + , queryRangeNumBlocks = BlockCount 10000 , queryRangeTlvs = emptyTlvs } {-# NOINLINE testQueryChannelRange #-} @@ -254,8 +253,8 @@ encodedQueryChannelRange = encodeQueryChannelRange testQueryChannelRange testReplyChannelRange :: ReplyChannelRange testReplyChannelRange = ReplyChannelRange { replyRangeChainHash = testChainHash - , replyRangeFirstBlock = 700000 - , replyRangeNumBlocks = 10000 + , replyRangeFirstBlock = BlockHeight 700000 + , replyRangeNumBlocks = BlockCount 10000 , replyRangeSyncComplete = 1 , replyRangeData = encodeShortChannelIdList [testShortChannelId] , replyRangeTlvs = emptyTlvs diff --git a/bench/Weight.hs b/bench/Weight.hs @@ -165,8 +165,7 @@ mkChannelUpdate !sig !ch !scid = ChannelUpdate , chanUpdateChainHash = ch , chanUpdateShortChanId = scid , chanUpdateTimestamp = 1234567890 - , chanUpdateMsgFlags = MessageFlags True - , chanUpdateChanFlags = ChannelFlags False False + , chanUpdateChanFlags = ChannelFlags NodeOne Enabled , chanUpdateCltvExpDelta = CltvExpiryDelta 144 , chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000 , chanUpdateFeeBaseMsat = FeeBaseMsat 1000 @@ -196,8 +195,8 @@ mkGossipTimestampFilter !ch = GossipTimestampFilter mkQueryChannelRange :: ChainHash -> TlvStream -> QueryChannelRange mkQueryChannelRange !ch !tlvs = QueryChannelRange { queryRangeChainHash = ch - , queryRangeFirstBlock = 700000 - , queryRangeNumBlocks = 10000 + , queryRangeFirstBlock = BlockHeight 700000 + , queryRangeNumBlocks = BlockCount 10000 , queryRangeTlvs = tlvs } @@ -205,8 +204,8 @@ mkQueryChannelRange !ch !tlvs = QueryChannelRange mkReplyChannelRange :: ChainHash -> TlvStream -> ReplyChannelRange mkReplyChannelRange !ch !tlvs = ReplyChannelRange { replyRangeChainHash = ch - , replyRangeFirstBlock = 700000 - , replyRangeNumBlocks = 10000 + , replyRangeFirstBlock = BlockHeight 700000 + , replyRangeNumBlocks = BlockCount 10000 , replyRangeSyncComplete = 1 , replyRangeData = encodeShortChannelIdList [testShortChannelId] , replyRangeTlvs = tlvs diff --git a/lib/Lightning/Protocol/BOLT7/Codec.hs b/lib/Lightning/Protocol/BOLT7/Codec.hs @@ -50,6 +50,7 @@ module Lightning.Protocol.BOLT7.Codec ( ) where import Control.DeepSeq (NFData) +import Data.Bits ((.&.)) import Data.ByteString (ByteString) import qualified Data.ByteString as BS import Data.Word (Word8, Word16, Word32, Word64) @@ -239,9 +240,12 @@ decodeAddresses bs = do Just a -> Right (AddrTorV3 a port, d3) 5 -> do -- DNS hostname (hostLen, d2) <- decodeU8 d1 - (hostBytes, d3) <- decodeBytes (fromIntegral hostLen) d2 + (hostBytes, d3) <- + decodeBytes (fromIntegral hostLen) d2 (port, d4) <- decodeU16 d3 - Right (AddrDNS hostBytes port, d4) + case hostname hostBytes of + Nothing -> Left DecodeInvalidAddress + Just h -> Right (AddrDNS h port, d4) _ -> Left DecodeInvalidAddress -- Unknown address type -- Channel announcement -------------------------------------------------------- @@ -331,12 +335,14 @@ encodeAddresses addrs = Right $ mconcat (map encodeAddress addrs) , getTorV3Addr a , Prim.encodeU16 port ] - encodeAddress (AddrDNS host port) = mconcat - [ BS.singleton 5 - , BS.singleton (fromIntegral $ BS.length host) - , host - , Prim.encodeU16 port - ] + encodeAddress (AddrDNS h port) = + let !host = getHostname h + in mconcat + [ BS.singleton 5 + , BS.singleton (fromIntegral $ BS.length host) + , host + , Prim.encodeU16 port + ] -- | Decode node_announcement message. decodeNodeAnnouncement :: ByteString @@ -363,41 +369,57 @@ decodeNodeAnnouncement bs = do -- Channel update -------------------------------------------------------------- -- | Encode channel_update message. +-- +-- The message_flags byte is derived from the presence of +-- 'chanUpdateHtlcMaxMsat': bit 0 is set when the field is +-- 'Just'. encodeChannelUpdate :: ChannelUpdate -> ByteString -encodeChannelUpdate msg = mconcat - [ unSignature (chanUpdateSignature msg) - , unChainHash (chanUpdateChainHash msg) - , scidToBytes (chanUpdateShortChanId msg) - , Prim.encodeU32 (chanUpdateTimestamp msg) - , BS.singleton (encodeMessageFlags (chanUpdateMsgFlags msg)) - , BS.singleton (encodeChannelFlags (chanUpdateChanFlags msg)) - , Prim.encodeU16 (getCltvExpiryDelta (chanUpdateCltvExpDelta msg)) - , Prim.encodeU64 (getHtlcMinimumMsat (chanUpdateHtlcMinMsat msg)) - , Prim.encodeU32 (getFeeBaseMsat (chanUpdateFeeBaseMsat msg)) - , Prim.encodeU32 (getFeeProportionalMillionths (chanUpdateFeeProportional msg)) - , case chanUpdateHtlcMaxMsat msg of - Nothing -> BS.empty - Just m -> Prim.encodeU64 (getHtlcMaximumMsat m) - ] +encodeChannelUpdate msg = + let !msgFlagsByte = case chanUpdateHtlcMaxMsat msg of + Nothing -> 0x00 :: Word8 + Just _ -> 0x01 + in mconcat + [ unSignature (chanUpdateSignature msg) + , unChainHash (chanUpdateChainHash msg) + , scidToBytes (chanUpdateShortChanId msg) + , Prim.encodeU32 (chanUpdateTimestamp msg) + , BS.singleton msgFlagsByte + , BS.singleton (encodeChannelFlags (chanUpdateChanFlags msg)) + , Prim.encodeU16 + (getCltvExpiryDelta (chanUpdateCltvExpDelta msg)) + , Prim.encodeU64 + (getHtlcMinimumMsat (chanUpdateHtlcMinMsat msg)) + , Prim.encodeU32 + (getFeeBaseMsat (chanUpdateFeeBaseMsat msg)) + , Prim.encodeU32 + (getFeeProportionalMillionths + (chanUpdateFeeProportional msg)) + , case chanUpdateHtlcMaxMsat msg of + Nothing -> BS.empty + Just m -> Prim.encodeU64 (getHtlcMaximumMsat m) + ] -- | Decode channel_update message. +-- +-- The message_flags byte is read from the wire but not stored; +-- bit 0 determines whether 'htlc_maximum_msat' is present. decodeChannelUpdate :: ByteString - -> Either DecodeError (ChannelUpdate, ByteString) + -> Either DecodeError + (ChannelUpdate, ByteString) decodeChannelUpdate bs = do - (sig, bs1) <- decodeSignature bs - (chainH, bs2) <- decodeChainHash bs1 - (scid, bs3) <- decodeShortChannelId bs2 - (timestamp, bs4) <- decodeU32 bs3 - (msgFlagsRaw, bs5) <- decodeU8 bs4 + (sig, bs1) <- decodeSignature bs + (chainH, bs2) <- decodeChainHash bs1 + (scid, bs3) <- decodeShortChannelId bs2 + (timestamp, bs4) <- decodeU32 bs3 + (msgFlagsRaw, bs5) <- decodeU8 bs4 (chanFlagsRaw, bs6) <- decodeU8 bs5 - (cltvDelta, bs7) <- decodeU16 bs6 - (htlcMin, bs8) <- decodeU64 bs7 - (feeBase, bs9) <- decodeU32 bs8 - (feeProp, bs10) <- decodeU32 bs9 - let msgFlags' = decodeMessageFlags msgFlagsRaw - chanFlags' = decodeChannelFlags chanFlagsRaw - -- htlc_maximum_msat is present if message_flags bit 0 is set - (htlcMax, rest) <- if mfHtlcMaxPresent msgFlags' + (cltvDelta, bs7) <- decodeU16 bs6 + (htlcMin, bs8) <- decodeU64 bs7 + (feeBase, bs9) <- decodeU32 bs8 + (feeProp, bs10) <- decodeU32 bs9 + let !chanFlags' = decodeChannelFlags chanFlagsRaw + !htlcMaxPresent = msgFlagsRaw .&. 0x01 /= 0 + (htlcMax, rest) <- if htlcMaxPresent then do (m, r) <- decodeU64 bs10 Right (Just (HtlcMaximumMsat m), r) @@ -407,12 +429,12 @@ decodeChannelUpdate bs = do , chanUpdateChainHash = chainH , chanUpdateShortChanId = scid , chanUpdateTimestamp = timestamp - , chanUpdateMsgFlags = msgFlags' , chanUpdateChanFlags = chanFlags' , chanUpdateCltvExpDelta = CltvExpiryDelta cltvDelta , chanUpdateHtlcMinMsat = HtlcMinimumMsat htlcMin , chanUpdateFeeBaseMsat = FeeBaseMsat feeBase - , chanUpdateFeeProportional = FeeProportionalMillionths feeProp + , chanUpdateFeeProportional = + FeeProportionalMillionths feeProp , chanUpdateHtlcMaxMsat = htlcMax } Right (msg, rest) @@ -501,14 +523,17 @@ decodeReplyShortChannelIdsEnd bs = do encodeQueryChannelRange :: QueryChannelRange -> ByteString encodeQueryChannelRange msg = mconcat [ unChainHash (queryRangeChainHash msg) - , Prim.encodeU32 (queryRangeFirstBlock msg) - , Prim.encodeU32 (queryRangeNumBlocks msg) + , Prim.encodeU32 + (getBlockHeight (queryRangeFirstBlock msg)) + , Prim.encodeU32 + (getBlockCount (queryRangeNumBlocks msg)) , TLV.encodeTlvStream (queryRangeTlvs msg) ] -- | Decode query_channel_range message. decodeQueryChannelRange :: ByteString - -> Either DecodeError (QueryChannelRange, ByteString) + -> Either DecodeError + (QueryChannelRange, ByteString) decodeQueryChannelRange bs = do (chainH, bs1) <- decodeChainHash bs (firstBlock, bs2) <- decodeU32 bs1 @@ -518,22 +543,25 @@ decodeQueryChannelRange bs = do Right t -> t let msg = QueryChannelRange { queryRangeChainHash = chainH - , queryRangeFirstBlock = firstBlock - , queryRangeNumBlocks = numBlocks + , queryRangeFirstBlock = BlockHeight firstBlock + , queryRangeNumBlocks = BlockCount numBlocks , queryRangeTlvs = tlvs } Right (msg, BS.empty) -- | Encode reply_channel_range message. -encodeReplyChannelRange :: ReplyChannelRange -> Either EncodeError ByteString +encodeReplyChannelRange :: ReplyChannelRange + -> Either EncodeError ByteString encodeReplyChannelRange msg = do let rangeData = replyRangeData msg if BS.length rangeData > 65535 then Left EncodeLengthOverflow else Right $ mconcat [ unChainHash (replyRangeChainHash msg) - , Prim.encodeU32 (replyRangeFirstBlock msg) - , Prim.encodeU32 (replyRangeNumBlocks msg) + , Prim.encodeU32 + (getBlockHeight (replyRangeFirstBlock msg)) + , Prim.encodeU32 + (getBlockCount (replyRangeNumBlocks msg)) , BS.singleton (replyRangeSyncComplete msg) , encodeLenPrefixed rangeData , TLV.encodeTlvStream (replyRangeTlvs msg) @@ -541,7 +569,8 @@ encodeReplyChannelRange msg = do -- | Decode reply_channel_range message. decodeReplyChannelRange :: ByteString - -> Either DecodeError (ReplyChannelRange, ByteString) + -> Either DecodeError + (ReplyChannelRange, ByteString) decodeReplyChannelRange bs = do (chainH, bs1) <- decodeChainHash bs (firstBlock, bs2) <- decodeU32 bs1 @@ -553,8 +582,8 @@ decodeReplyChannelRange bs = do Right t -> t let msg = ReplyChannelRange { replyRangeChainHash = chainH - , replyRangeFirstBlock = firstBlock - , replyRangeNumBlocks = numBlocks + , replyRangeFirstBlock = BlockHeight firstBlock + , replyRangeNumBlocks = BlockCount numBlocks , replyRangeSyncComplete = syncComplete , replyRangeData = rangeData , replyRangeTlvs = tlvs diff --git a/lib/Lightning/Protocol/BOLT7/Messages.hs b/lib/Lightning/Protocol/BOLT7/Messages.hs @@ -41,7 +41,7 @@ module Lightning.Protocol.BOLT7.Messages ( import Control.DeepSeq (NFData) import Data.ByteString (ByteString) -import Data.Word (Word8, Word16, Word32) -- Word8 still used by other messages +import Data.Word (Word8, Word16, Word32) import GHC.Generics (Generic) import Lightning.Protocol.BOLT1 (TlvStream) import Lightning.Protocol.BOLT7.Types @@ -121,18 +121,32 @@ instance NFData NodeAnnouncement -- | channel_update message (type 258). -- -- Communicates per-direction routing parameters. +-- +-- The message_flags field is derived automatically during +-- encoding: bit 0 is set when 'chanUpdateHtlcMaxMsat' is +-- 'Just'. data ChannelUpdate = ChannelUpdate - { chanUpdateSignature :: !Signature -- ^ Signature of message - , chanUpdateChainHash :: !ChainHash -- ^ Chain identifier - , chanUpdateShortChanId :: !ShortChannelId -- ^ Short channel ID - , chanUpdateTimestamp :: !Timestamp -- ^ Unix timestamp - , chanUpdateMsgFlags :: !MessageFlags -- ^ Message flags - , chanUpdateChanFlags :: !ChannelFlags -- ^ Channel flags - , chanUpdateCltvExpDelta :: !CltvExpiryDelta -- ^ CLTV expiry delta - , chanUpdateHtlcMinMsat :: !HtlcMinimumMsat -- ^ Minimum HTLC msat - , chanUpdateFeeBaseMsat :: !FeeBaseMsat -- ^ Base fee msat - , chanUpdateFeeProportional :: !FeeProportionalMillionths -- ^ Prop fee - , chanUpdateHtlcMaxMsat :: !(Maybe HtlcMaximumMsat) -- ^ Max HTLC (optional) + { chanUpdateSignature :: !Signature + -- ^ Signature of message + , chanUpdateChainHash :: !ChainHash + -- ^ Chain identifier + , chanUpdateShortChanId :: !ShortChannelId + -- ^ Short channel ID + , chanUpdateTimestamp :: !Timestamp + -- ^ Unix timestamp + , chanUpdateChanFlags :: !ChannelFlags + -- ^ Channel flags + , chanUpdateCltvExpDelta :: !CltvExpiryDelta + -- ^ CLTV expiry delta + , chanUpdateHtlcMinMsat :: !HtlcMinimumMsat + -- ^ Minimum HTLC msat + , chanUpdateFeeBaseMsat :: !FeeBaseMsat + -- ^ Base fee msat + , chanUpdateFeeProportional :: !FeeProportionalMillionths + -- ^ Proportional fee + , chanUpdateHtlcMaxMsat :: !(Maybe HtlcMaximumMsat) + -- ^ Max HTLC (optional; presence sets message_flags + -- bit 0) } deriving (Eq, Show, Generic) @@ -182,10 +196,10 @@ instance NFData ReplyShortChannelIdsEnd -- -- Queries channels within a block range. data QueryChannelRange = QueryChannelRange - { queryRangeChainHash :: !ChainHash -- ^ Chain identifier - , queryRangeFirstBlock :: !Word32 -- ^ First block number - , queryRangeNumBlocks :: !Word32 -- ^ Number of blocks - , queryRangeTlvs :: !TlvStream -- ^ Optional TLV (query_option) + { queryRangeChainHash :: !ChainHash -- ^ Chain identifier + , queryRangeFirstBlock :: !BlockHeight -- ^ First block number + , queryRangeNumBlocks :: !BlockCount -- ^ Number of blocks + , queryRangeTlvs :: !TlvStream -- ^ Optional TLV } deriving (Eq, Show, Generic) @@ -195,12 +209,12 @@ instance NFData QueryChannelRange -- -- Responds to query_channel_range with channel IDs. data ReplyChannelRange = ReplyChannelRange - { replyRangeChainHash :: !ChainHash -- ^ Chain identifier - , replyRangeFirstBlock :: !Word32 -- ^ First block number - , replyRangeNumBlocks :: !Word32 -- ^ Number of blocks - , replyRangeSyncComplete :: !Word8 -- ^ 1 if sync complete - , replyRangeData :: !ByteString -- ^ Encoded short_channel_ids - , replyRangeTlvs :: !TlvStream -- ^ Optional TLVs + { replyRangeChainHash :: !ChainHash -- ^ Chain identifier + , replyRangeFirstBlock :: !BlockHeight -- ^ First block + , replyRangeNumBlocks :: !BlockCount -- ^ Block count + , replyRangeSyncComplete :: !Word8 -- ^ 1 if complete + , replyRangeData :: !ByteString -- ^ Encoded SCIDs + , replyRangeTlvs :: !TlvStream -- ^ Optional TLVs } deriving (Eq, Show, Generic) diff --git a/lib/Lightning/Protocol/BOLT7/Types.hs b/lib/Lightning/Protocol/BOLT7/Types.hs @@ -64,11 +64,13 @@ module Lightning.Protocol.BOLT7.Types ( , TorV3Addr , torV3Addr , getTorV3Addr + , Hostname + , hostname + , getHostname -- * Channel update flags - , MessageFlags(..) - , encodeMessageFlags - , decodeMessageFlags + , Direction(..) + , ChannelStatus(..) , ChannelFlags(..) , encodeChannelFlags , decodeChannelFlags @@ -80,6 +82,10 @@ module Lightning.Protocol.BOLT7.Types ( , HtlcMinimumMsat(..) , HtlcMaximumMsat(..) + -- * Block range types + , BlockHeight(..) + , BlockCount(..) + -- * Constants , chainHashLen , shortChannelIdLen @@ -318,45 +324,68 @@ torV3Addr !bs | otherwise = Nothing {-# INLINE torV3Addr #-} +-- | DNS hostname (1-255 bytes). +-- +-- Per BOLT #7 address descriptor type 5, the hostname is +-- a length-prefixed DNS name. The length byte limits it to +-- 255 bytes. +newtype Hostname = Hostname { getHostname :: ByteString } + deriving (Eq, Show, Generic) + +instance NFData Hostname + +-- | Smart constructor for Hostname. +-- +-- Returns Nothing if the hostname is empty or exceeds +-- 255 bytes. +hostname :: ByteString -> Maybe Hostname +hostname !bs + | BS.null bs = Nothing + | BS.length bs > 255 = Nothing + | otherwise = Just (Hostname bs) +{-# INLINE hostname #-} + -- | Network address with port. data Address - = AddrIPv4 !IPv4Addr !Word16 -- ^ IPv4 address + port - | AddrIPv6 !IPv6Addr !Word16 -- ^ IPv6 address + port + = AddrIPv4 !IPv4Addr !Word16 -- ^ IPv4 address + port + | AddrIPv6 !IPv6Addr !Word16 -- ^ IPv6 address + port | AddrTorV3 !TorV3Addr !Word16 -- ^ Tor v3 address + port - | AddrDNS !ByteString !Word16 -- ^ DNS hostname + port + | AddrDNS !Hostname !Word16 -- ^ DNS hostname + port deriving (Eq, Show, Generic) instance NFData Address -- Channel update flags -------------------------------------------------------- --- | Message flags for channel_update. +-- | Direction of a channel_update. -- --- Bit 0: htlc_maximum_msat field is present. -data MessageFlags = MessageFlags - { mfHtlcMaxPresent :: !Bool -- ^ htlc_maximum_msat is present - } - deriving (Eq, Show, Generic) +-- Per BOLT #7, bit 0 of channel_flags indicates which node +-- is the origin of the update. +data Direction + = NodeOne -- ^ Update from node_id_1 (bit 0 = 0) + | NodeTwo -- ^ Update from node_id_2 (bit 0 = 1) + deriving (Eq, Ord, Show, Generic) -instance NFData MessageFlags +instance NFData Direction --- | Encode MessageFlags to Word8. -encodeMessageFlags :: MessageFlags -> Word8 -encodeMessageFlags mf = if mfHtlcMaxPresent mf then 0x01 else 0x00 -{-# INLINE encodeMessageFlags #-} +-- | Channel enabled\/disabled status. +-- +-- Per BOLT #7, bit 1 of channel_flags indicates whether +-- the channel is disabled. +data ChannelStatus + = Enabled -- ^ Channel is active (bit 1 = 0) + | Disabled -- ^ Channel is disabled (bit 1 = 1) + deriving (Eq, Ord, Show, Generic) --- | Decode Word8 to MessageFlags. -decodeMessageFlags :: Word8 -> MessageFlags -decodeMessageFlags w = MessageFlags { mfHtlcMaxPresent = w .&. 0x01 /= 0 } -{-# INLINE decodeMessageFlags #-} +instance NFData ChannelStatus -- | Channel flags for channel_update. -- --- Bit 0: direction (0 = node_id_1 is origin, 1 = node_id_2 is origin). --- Bit 1: disabled (1 = channel disabled). +-- Bit 0: direction (0 = node_id_1, 1 = node_id_2). +-- Bit 1: disabled (0 = enabled, 1 = disabled). data ChannelFlags = ChannelFlags - { cfDirection :: !Bool -- ^ True = node_id_2 is origin - , cfDisabled :: !Bool -- ^ True = channel is disabled + { cfDirection :: !Direction -- ^ Update origin + , cfStatus :: !ChannelStatus -- ^ Channel status } deriving (Eq, Show, Generic) @@ -365,15 +394,25 @@ instance NFData ChannelFlags -- | Encode ChannelFlags to Word8. encodeChannelFlags :: ChannelFlags -> Word8 encodeChannelFlags cf = - (if cfDirection cf then 0x01 else 0x00) .|. - (if cfDisabled cf then 0x02 else 0x00) + dir .|. sta + where + dir = case cfDirection cf of + NodeOne -> 0x00 + NodeTwo -> 0x01 + sta = case cfStatus cf of + Enabled -> 0x00 + Disabled -> 0x02 {-# INLINE encodeChannelFlags #-} -- | Decode Word8 to ChannelFlags. decodeChannelFlags :: Word8 -> ChannelFlags decodeChannelFlags w = ChannelFlags - { cfDirection = w .&. 0x01 /= 0 - , cfDisabled = w .&. 0x02 /= 0 + { cfDirection = if w .&. 0x01 /= 0 + then NodeTwo + else NodeOne + , cfStatus = if w .&. 0x02 /= 0 + then Disabled + else Enabled } {-# INLINE decodeChannelFlags #-} @@ -405,7 +444,30 @@ newtype HtlcMinimumMsat = HtlcMinimumMsat { getHtlcMinimumMsat :: Word64 } instance NFData HtlcMinimumMsat -- | Maximum HTLC value in millisatoshis. -newtype HtlcMaximumMsat = HtlcMaximumMsat { getHtlcMaximumMsat :: Word64 } +newtype HtlcMaximumMsat = HtlcMaximumMsat + { getHtlcMaximumMsat :: Word64 } deriving (Eq, Ord, Show, Generic) instance NFData HtlcMaximumMsat + +-- Block range types ----------------------------------------------------------- + +-- | Absolute block height. +-- +-- Used in query_channel_range and reply_channel_range for +-- the first_blocknum field. +newtype BlockHeight = BlockHeight + { getBlockHeight :: Word32 } + deriving (Eq, Ord, Show, Generic) + +instance NFData BlockHeight + +-- | Block count (relative duration). +-- +-- Used in query_channel_range and reply_channel_range for +-- the number_of_blocks field. +newtype BlockCount = BlockCount + { getBlockCount :: Word32 } + deriving (Eq, Ord, Show, Generic) + +instance NFData BlockCount diff --git a/lib/Lightning/Protocol/BOLT7/Validate.hs b/lib/Lightning/Protocol/BOLT7/Validate.hs @@ -79,12 +79,13 @@ validateNodeAnnouncement msg = do -- -- Checks: -- --- * htlc_minimum_msat <= htlc_maximum_msat (if htlc_maximum_msat present) +-- * htlc_minimum_msat <= htlc_maximum_msat (if present) -- --- Note: The spec says message_flags bit 0 MUST be set if htlc_maximum_msat --- is advertised. We don't enforce this at validation time since the codec --- already handles the conditional field based on the flag. -validateChannelUpdate :: ChannelUpdate -> Either ValidationError () +-- Note: message_flags consistency is enforced at the type +-- level -- the flag is derived from the presence of +-- 'chanUpdateHtlcMaxMsat'. +validateChannelUpdate :: ChannelUpdate + -> Either ValidationError () validateChannelUpdate msg = do case chanUpdateHtlcMaxMsat msg of Nothing -> Right () @@ -99,10 +100,15 @@ validateChannelUpdate msg = do -- Checks: -- -- * first_blocknum + number_of_blocks does not overflow -validateQueryChannelRange :: QueryChannelRange -> Either ValidationError () +validateQueryChannelRange :: QueryChannelRange + -> Either ValidationError () validateQueryChannelRange msg = do - let first = fromIntegral (queryRangeFirstBlock msg) :: Word64 - num = fromIntegral (queryRangeNumBlocks msg) :: Word64 + let first = fromIntegral + (getBlockHeight (queryRangeFirstBlock msg)) + :: Word64 + num = fromIntegral + (getBlockCount (queryRangeNumBlocks msg)) + :: Word64 if first + num > fromIntegral (maxBound :: Word32) then Left ValidateBlockOverflow else Right () diff --git a/test/Main.hs b/test/Main.hs @@ -207,9 +207,8 @@ channel_update_tests = testGroup "ChannelUpdate" [ , chanUpdateChainHash = testChainHash , chanUpdateShortChanId = testShortChannelId , chanUpdateTimestamp = 1234567890 - , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False } , chanUpdateChanFlags = ChannelFlags - { cfDirection = True, cfDisabled = False } + { cfDirection = NodeTwo, cfStatus = Enabled } , chanUpdateCltvExpDelta = CltvExpiryDelta 144 , chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000 , chanUpdateFeeBaseMsat = FeeBaseMsat 1000 @@ -226,9 +225,8 @@ channel_update_tests = testGroup "ChannelUpdate" [ , chanUpdateChainHash = testChainHash , chanUpdateShortChanId = testShortChannelId , chanUpdateTimestamp = 1234567890 - , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = True } , chanUpdateChanFlags = ChannelFlags - { cfDirection = False, cfDisabled = False } + { cfDirection = NodeOne, cfStatus = Enabled } , chanUpdateCltvExpDelta = CltvExpiryDelta 40 , chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000 , chanUpdateFeeBaseMsat = FeeBaseMsat 500 @@ -292,8 +290,8 @@ query_tests = testGroup "Query Messages" [ testCase "encode/decode roundtrip" $ do let msg = QueryChannelRange { queryRangeChainHash = testChainHash - , queryRangeFirstBlock = 600000 - , queryRangeNumBlocks = 10000 + , queryRangeFirstBlock = BlockHeight 600000 + , queryRangeNumBlocks = BlockCount 10000 , queryRangeTlvs = emptyTlvs } encoded = encodeQueryChannelRange msg @@ -308,8 +306,8 @@ query_tests = testGroup "Query Messages" [ testCase "encode/decode roundtrip" $ do let msg = ReplyChannelRange { replyRangeChainHash = testChainHash - , replyRangeFirstBlock = 600000 - , replyRangeNumBlocks = 10000 + , replyRangeFirstBlock = BlockHeight 600000 + , replyRangeNumBlocks = BlockCount 10000 , replyRangeSyncComplete = 1 , replyRangeData = BS.replicate 16 0xcd , replyRangeTlvs = emptyTlvs @@ -427,9 +425,9 @@ hash_tests = testGroup "Hash Functions" [ , chanUpdateChainHash = testChainHash , chanUpdateShortChanId = testShortChannelId , chanUpdateTimestamp = 1234567890 - , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False } , chanUpdateChanFlags = ChannelFlags - { cfDirection = False, cfDisabled = False } + { cfDirection = NodeOne + , cfStatus = Enabled } , chanUpdateCltvExpDelta = CltvExpiryDelta 144 , chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000 , chanUpdateFeeBaseMsat = FeeBaseMsat 1000 @@ -447,9 +445,9 @@ hash_tests = testGroup "Hash Functions" [ , chanUpdateChainHash = testChainHash , chanUpdateShortChanId = testShortChannelId , chanUpdateTimestamp = 1234567890 - , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False } , chanUpdateChanFlags = ChannelFlags - { cfDirection = False, cfDisabled = False } + { cfDirection = NodeOne + , cfStatus = Enabled } , chanUpdateCltvExpDelta = CltvExpiryDelta 144 , chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000 , chanUpdateFeeBaseMsat = FeeBaseMsat 1000 @@ -466,9 +464,9 @@ hash_tests = testGroup "Hash Functions" [ , chanUpdateChainHash = testChainHash , chanUpdateShortChanId = testShortChannelId , chanUpdateTimestamp = 1000000000 - , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False } , chanUpdateChanFlags = ChannelFlags - { cfDirection = False, cfDisabled = False } + { cfDirection = NodeOne + , cfStatus = Enabled } , chanUpdateCltvExpDelta = CltvExpiryDelta 144 , chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000 , chanUpdateFeeBaseMsat = FeeBaseMsat 1000 @@ -480,9 +478,9 @@ hash_tests = testGroup "Hash Functions" [ , chanUpdateChainHash = testChainHash , chanUpdateShortChanId = testShortChannelId , chanUpdateTimestamp = 2000000000 - , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False } , chanUpdateChanFlags = ChannelFlags - { cfDirection = False, cfDisabled = False } + { cfDirection = NodeOne + , cfStatus = Enabled } , chanUpdateCltvExpDelta = CltvExpiryDelta 144 , chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000 , chanUpdateFeeBaseMsat = FeeBaseMsat 1000 @@ -536,9 +534,9 @@ validation_tests = testGroup "Validation" [ , chanUpdateChainHash = testChainHash , chanUpdateShortChanId = testShortChannelId , chanUpdateTimestamp = 1234567890 - , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = True } , chanUpdateChanFlags = ChannelFlags - { cfDirection = False, cfDisabled = False } + { cfDirection = NodeOne + , cfStatus = Enabled } , chanUpdateCltvExpDelta = CltvExpiryDelta 144 , chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000 , chanUpdateFeeBaseMsat = FeeBaseMsat 1000 @@ -552,9 +550,9 @@ validation_tests = testGroup "Validation" [ , chanUpdateChainHash = testChainHash , chanUpdateShortChanId = testShortChannelId , chanUpdateTimestamp = 1234567890 - , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = True } , chanUpdateChanFlags = ChannelFlags - { cfDirection = False, cfDisabled = False } + { cfDirection = NodeOne + , cfStatus = Enabled } , chanUpdateCltvExpDelta = CltvExpiryDelta 144 , chanUpdateHtlcMinMsat = HtlcMinimumMsat 2000000000 -- > htlcMax , chanUpdateFeeBaseMsat = FeeBaseMsat 1000 @@ -567,16 +565,16 @@ validation_tests = testGroup "Validation" [ testCase "valid range passes" $ do let msg = QueryChannelRange { queryRangeChainHash = testChainHash - , queryRangeFirstBlock = 600000 - , queryRangeNumBlocks = 10000 + , queryRangeFirstBlock = BlockHeight 600000 + , queryRangeNumBlocks = BlockCount 10000 , queryRangeTlvs = emptyTlvs } validateQueryChannelRange msg @?= Right () , testCase "rejects overflow" $ do let msg = QueryChannelRange { queryRangeChainHash = testChainHash - , queryRangeFirstBlock = maxBound -- 0xFFFFFFFF - , queryRangeNumBlocks = 10 + , queryRangeFirstBlock = BlockHeight maxBound + , queryRangeNumBlocks = BlockCount 10 , queryRangeTlvs = emptyTlvs } validateQueryChannelRange msg @?= Left ValidateBlockOverflow @@ -648,9 +646,9 @@ propChannelUpdateRoundtrip timestamp cltvDelta = property $ do , chanUpdateChainHash = testChainHash , chanUpdateShortChanId = testShortChannelId , chanUpdateTimestamp = timestamp - , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False } , chanUpdateChanFlags = ChannelFlags - { cfDirection = False, cfDisabled = False } + { cfDirection = NodeOne + , cfStatus = Enabled } , chanUpdateCltvExpDelta = CltvExpiryDelta cltvDelta , chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000 , chanUpdateFeeBaseMsat = FeeBaseMsat 1000