commit 7db453176af2081e20de3aaca61886af22852d96
parent b7b8b6f8d6b3ef9b2d3740af8768169a4bfca599
Author: Jared Tobin <jared@jtobin.io>
Date: Mon, 20 Apr 2026 15:16:59 +0800
merge: type safety improvements
Replace boolean protocol flags with descriptive sum types
(Direction, ChannelStatus), eliminate the MessageFlags /
htlcMaxMsat inconsistency by deriving the wire flag from
field presence, add a Hostname newtype with length
validation for DNS addresses, and add BlockHeight /
BlockCount newtypes to prevent mixing absolute and
relative block values in query_channel_range and
reply_channel_range messages.
All 49 tests pass.
Diffstat:
7 files changed, 255 insertions(+), 148 deletions(-)
diff --git a/bench/Main.hs b/bench/Main.hs
@@ -173,8 +173,7 @@ testChannelUpdate = ChannelUpdate
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = MessageFlags True
- , chanUpdateChanFlags = ChannelFlags False False
+ , chanUpdateChanFlags = ChannelFlags NodeOne Enabled
, chanUpdateCltvExpDelta = CltvExpiryDelta 144
, chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000
, chanUpdateFeeBaseMsat = FeeBaseMsat 1000
@@ -239,8 +238,8 @@ encodedReplyShortChannelIdsEnd =
testQueryChannelRange :: QueryChannelRange
testQueryChannelRange = QueryChannelRange
{ queryRangeChainHash = testChainHash
- , queryRangeFirstBlock = 700000
- , queryRangeNumBlocks = 10000
+ , queryRangeFirstBlock = BlockHeight 700000
+ , queryRangeNumBlocks = BlockCount 10000
, queryRangeTlvs = emptyTlvs
}
{-# NOINLINE testQueryChannelRange #-}
@@ -254,8 +253,8 @@ encodedQueryChannelRange = encodeQueryChannelRange testQueryChannelRange
testReplyChannelRange :: ReplyChannelRange
testReplyChannelRange = ReplyChannelRange
{ replyRangeChainHash = testChainHash
- , replyRangeFirstBlock = 700000
- , replyRangeNumBlocks = 10000
+ , replyRangeFirstBlock = BlockHeight 700000
+ , replyRangeNumBlocks = BlockCount 10000
, replyRangeSyncComplete = 1
, replyRangeData = encodeShortChannelIdList [testShortChannelId]
, replyRangeTlvs = emptyTlvs
diff --git a/bench/Weight.hs b/bench/Weight.hs
@@ -165,8 +165,7 @@ mkChannelUpdate !sig !ch !scid = ChannelUpdate
, chanUpdateChainHash = ch
, chanUpdateShortChanId = scid
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = MessageFlags True
- , chanUpdateChanFlags = ChannelFlags False False
+ , chanUpdateChanFlags = ChannelFlags NodeOne Enabled
, chanUpdateCltvExpDelta = CltvExpiryDelta 144
, chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000
, chanUpdateFeeBaseMsat = FeeBaseMsat 1000
@@ -196,8 +195,8 @@ mkGossipTimestampFilter !ch = GossipTimestampFilter
mkQueryChannelRange :: ChainHash -> TlvStream -> QueryChannelRange
mkQueryChannelRange !ch !tlvs = QueryChannelRange
{ queryRangeChainHash = ch
- , queryRangeFirstBlock = 700000
- , queryRangeNumBlocks = 10000
+ , queryRangeFirstBlock = BlockHeight 700000
+ , queryRangeNumBlocks = BlockCount 10000
, queryRangeTlvs = tlvs
}
@@ -205,8 +204,8 @@ mkQueryChannelRange !ch !tlvs = QueryChannelRange
mkReplyChannelRange :: ChainHash -> TlvStream -> ReplyChannelRange
mkReplyChannelRange !ch !tlvs = ReplyChannelRange
{ replyRangeChainHash = ch
- , replyRangeFirstBlock = 700000
- , replyRangeNumBlocks = 10000
+ , replyRangeFirstBlock = BlockHeight 700000
+ , replyRangeNumBlocks = BlockCount 10000
, replyRangeSyncComplete = 1
, replyRangeData = encodeShortChannelIdList [testShortChannelId]
, replyRangeTlvs = tlvs
diff --git a/lib/Lightning/Protocol/BOLT7/Codec.hs b/lib/Lightning/Protocol/BOLT7/Codec.hs
@@ -50,6 +50,7 @@ module Lightning.Protocol.BOLT7.Codec (
) where
import Control.DeepSeq (NFData)
+import Data.Bits ((.&.))
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Word (Word8, Word16, Word32, Word64)
@@ -239,9 +240,12 @@ decodeAddresses bs = do
Just a -> Right (AddrTorV3 a port, d3)
5 -> do -- DNS hostname
(hostLen, d2) <- decodeU8 d1
- (hostBytes, d3) <- decodeBytes (fromIntegral hostLen) d2
+ (hostBytes, d3) <-
+ decodeBytes (fromIntegral hostLen) d2
(port, d4) <- decodeU16 d3
- Right (AddrDNS hostBytes port, d4)
+ case hostname hostBytes of
+ Nothing -> Left DecodeInvalidAddress
+ Just h -> Right (AddrDNS h port, d4)
_ -> Left DecodeInvalidAddress -- Unknown address type
-- Channel announcement --------------------------------------------------------
@@ -331,12 +335,14 @@ encodeAddresses addrs = Right $ mconcat (map encodeAddress addrs)
, getTorV3Addr a
, Prim.encodeU16 port
]
- encodeAddress (AddrDNS host port) = mconcat
- [ BS.singleton 5
- , BS.singleton (fromIntegral $ BS.length host)
- , host
- , Prim.encodeU16 port
- ]
+ encodeAddress (AddrDNS h port) =
+ let !host = getHostname h
+ in mconcat
+ [ BS.singleton 5
+ , BS.singleton (fromIntegral $ BS.length host)
+ , host
+ , Prim.encodeU16 port
+ ]
-- | Decode node_announcement message.
decodeNodeAnnouncement :: ByteString
@@ -363,41 +369,57 @@ decodeNodeAnnouncement bs = do
-- Channel update --------------------------------------------------------------
-- | Encode channel_update message.
+--
+-- The message_flags byte is derived from the presence of
+-- 'chanUpdateHtlcMaxMsat': bit 0 is set when the field is
+-- 'Just'.
encodeChannelUpdate :: ChannelUpdate -> ByteString
-encodeChannelUpdate msg = mconcat
- [ unSignature (chanUpdateSignature msg)
- , unChainHash (chanUpdateChainHash msg)
- , scidToBytes (chanUpdateShortChanId msg)
- , Prim.encodeU32 (chanUpdateTimestamp msg)
- , BS.singleton (encodeMessageFlags (chanUpdateMsgFlags msg))
- , BS.singleton (encodeChannelFlags (chanUpdateChanFlags msg))
- , Prim.encodeU16 (getCltvExpiryDelta (chanUpdateCltvExpDelta msg))
- , Prim.encodeU64 (getHtlcMinimumMsat (chanUpdateHtlcMinMsat msg))
- , Prim.encodeU32 (getFeeBaseMsat (chanUpdateFeeBaseMsat msg))
- , Prim.encodeU32 (getFeeProportionalMillionths (chanUpdateFeeProportional msg))
- , case chanUpdateHtlcMaxMsat msg of
- Nothing -> BS.empty
- Just m -> Prim.encodeU64 (getHtlcMaximumMsat m)
- ]
+encodeChannelUpdate msg =
+ let !msgFlagsByte = case chanUpdateHtlcMaxMsat msg of
+ Nothing -> 0x00 :: Word8
+ Just _ -> 0x01
+ in mconcat
+ [ unSignature (chanUpdateSignature msg)
+ , unChainHash (chanUpdateChainHash msg)
+ , scidToBytes (chanUpdateShortChanId msg)
+ , Prim.encodeU32 (chanUpdateTimestamp msg)
+ , BS.singleton msgFlagsByte
+ , BS.singleton (encodeChannelFlags (chanUpdateChanFlags msg))
+ , Prim.encodeU16
+ (getCltvExpiryDelta (chanUpdateCltvExpDelta msg))
+ , Prim.encodeU64
+ (getHtlcMinimumMsat (chanUpdateHtlcMinMsat msg))
+ , Prim.encodeU32
+ (getFeeBaseMsat (chanUpdateFeeBaseMsat msg))
+ , Prim.encodeU32
+ (getFeeProportionalMillionths
+ (chanUpdateFeeProportional msg))
+ , case chanUpdateHtlcMaxMsat msg of
+ Nothing -> BS.empty
+ Just m -> Prim.encodeU64 (getHtlcMaximumMsat m)
+ ]
-- | Decode channel_update message.
+--
+-- The message_flags byte is read from the wire but not stored;
+-- bit 0 determines whether 'htlc_maximum_msat' is present.
decodeChannelUpdate :: ByteString
- -> Either DecodeError (ChannelUpdate, ByteString)
+ -> Either DecodeError
+ (ChannelUpdate, ByteString)
decodeChannelUpdate bs = do
- (sig, bs1) <- decodeSignature bs
- (chainH, bs2) <- decodeChainHash bs1
- (scid, bs3) <- decodeShortChannelId bs2
- (timestamp, bs4) <- decodeU32 bs3
- (msgFlagsRaw, bs5) <- decodeU8 bs4
+ (sig, bs1) <- decodeSignature bs
+ (chainH, bs2) <- decodeChainHash bs1
+ (scid, bs3) <- decodeShortChannelId bs2
+ (timestamp, bs4) <- decodeU32 bs3
+ (msgFlagsRaw, bs5) <- decodeU8 bs4
(chanFlagsRaw, bs6) <- decodeU8 bs5
- (cltvDelta, bs7) <- decodeU16 bs6
- (htlcMin, bs8) <- decodeU64 bs7
- (feeBase, bs9) <- decodeU32 bs8
- (feeProp, bs10) <- decodeU32 bs9
- let msgFlags' = decodeMessageFlags msgFlagsRaw
- chanFlags' = decodeChannelFlags chanFlagsRaw
- -- htlc_maximum_msat is present if message_flags bit 0 is set
- (htlcMax, rest) <- if mfHtlcMaxPresent msgFlags'
+ (cltvDelta, bs7) <- decodeU16 bs6
+ (htlcMin, bs8) <- decodeU64 bs7
+ (feeBase, bs9) <- decodeU32 bs8
+ (feeProp, bs10) <- decodeU32 bs9
+ let !chanFlags' = decodeChannelFlags chanFlagsRaw
+ !htlcMaxPresent = msgFlagsRaw .&. 0x01 /= 0
+ (htlcMax, rest) <- if htlcMaxPresent
then do
(m, r) <- decodeU64 bs10
Right (Just (HtlcMaximumMsat m), r)
@@ -407,12 +429,12 @@ decodeChannelUpdate bs = do
, chanUpdateChainHash = chainH
, chanUpdateShortChanId = scid
, chanUpdateTimestamp = timestamp
- , chanUpdateMsgFlags = msgFlags'
, chanUpdateChanFlags = chanFlags'
, chanUpdateCltvExpDelta = CltvExpiryDelta cltvDelta
, chanUpdateHtlcMinMsat = HtlcMinimumMsat htlcMin
, chanUpdateFeeBaseMsat = FeeBaseMsat feeBase
- , chanUpdateFeeProportional = FeeProportionalMillionths feeProp
+ , chanUpdateFeeProportional =
+ FeeProportionalMillionths feeProp
, chanUpdateHtlcMaxMsat = htlcMax
}
Right (msg, rest)
@@ -501,14 +523,17 @@ decodeReplyShortChannelIdsEnd bs = do
encodeQueryChannelRange :: QueryChannelRange -> ByteString
encodeQueryChannelRange msg = mconcat
[ unChainHash (queryRangeChainHash msg)
- , Prim.encodeU32 (queryRangeFirstBlock msg)
- , Prim.encodeU32 (queryRangeNumBlocks msg)
+ , Prim.encodeU32
+ (getBlockHeight (queryRangeFirstBlock msg))
+ , Prim.encodeU32
+ (getBlockCount (queryRangeNumBlocks msg))
, TLV.encodeTlvStream (queryRangeTlvs msg)
]
-- | Decode query_channel_range message.
decodeQueryChannelRange :: ByteString
- -> Either DecodeError (QueryChannelRange, ByteString)
+ -> Either DecodeError
+ (QueryChannelRange, ByteString)
decodeQueryChannelRange bs = do
(chainH, bs1) <- decodeChainHash bs
(firstBlock, bs2) <- decodeU32 bs1
@@ -518,22 +543,25 @@ decodeQueryChannelRange bs = do
Right t -> t
let msg = QueryChannelRange
{ queryRangeChainHash = chainH
- , queryRangeFirstBlock = firstBlock
- , queryRangeNumBlocks = numBlocks
+ , queryRangeFirstBlock = BlockHeight firstBlock
+ , queryRangeNumBlocks = BlockCount numBlocks
, queryRangeTlvs = tlvs
}
Right (msg, BS.empty)
-- | Encode reply_channel_range message.
-encodeReplyChannelRange :: ReplyChannelRange -> Either EncodeError ByteString
+encodeReplyChannelRange :: ReplyChannelRange
+ -> Either EncodeError ByteString
encodeReplyChannelRange msg = do
let rangeData = replyRangeData msg
if BS.length rangeData > 65535
then Left EncodeLengthOverflow
else Right $ mconcat
[ unChainHash (replyRangeChainHash msg)
- , Prim.encodeU32 (replyRangeFirstBlock msg)
- , Prim.encodeU32 (replyRangeNumBlocks msg)
+ , Prim.encodeU32
+ (getBlockHeight (replyRangeFirstBlock msg))
+ , Prim.encodeU32
+ (getBlockCount (replyRangeNumBlocks msg))
, BS.singleton (replyRangeSyncComplete msg)
, encodeLenPrefixed rangeData
, TLV.encodeTlvStream (replyRangeTlvs msg)
@@ -541,7 +569,8 @@ encodeReplyChannelRange msg = do
-- | Decode reply_channel_range message.
decodeReplyChannelRange :: ByteString
- -> Either DecodeError (ReplyChannelRange, ByteString)
+ -> Either DecodeError
+ (ReplyChannelRange, ByteString)
decodeReplyChannelRange bs = do
(chainH, bs1) <- decodeChainHash bs
(firstBlock, bs2) <- decodeU32 bs1
@@ -553,8 +582,8 @@ decodeReplyChannelRange bs = do
Right t -> t
let msg = ReplyChannelRange
{ replyRangeChainHash = chainH
- , replyRangeFirstBlock = firstBlock
- , replyRangeNumBlocks = numBlocks
+ , replyRangeFirstBlock = BlockHeight firstBlock
+ , replyRangeNumBlocks = BlockCount numBlocks
, replyRangeSyncComplete = syncComplete
, replyRangeData = rangeData
, replyRangeTlvs = tlvs
diff --git a/lib/Lightning/Protocol/BOLT7/Messages.hs b/lib/Lightning/Protocol/BOLT7/Messages.hs
@@ -41,7 +41,7 @@ module Lightning.Protocol.BOLT7.Messages (
import Control.DeepSeq (NFData)
import Data.ByteString (ByteString)
-import Data.Word (Word8, Word16, Word32) -- Word8 still used by other messages
+import Data.Word (Word8, Word16, Word32)
import GHC.Generics (Generic)
import Lightning.Protocol.BOLT1 (TlvStream)
import Lightning.Protocol.BOLT7.Types
@@ -121,18 +121,32 @@ instance NFData NodeAnnouncement
-- | channel_update message (type 258).
--
-- Communicates per-direction routing parameters.
+--
+-- The message_flags field is derived automatically during
+-- encoding: bit 0 is set when 'chanUpdateHtlcMaxMsat' is
+-- 'Just'.
data ChannelUpdate = ChannelUpdate
- { chanUpdateSignature :: !Signature -- ^ Signature of message
- , chanUpdateChainHash :: !ChainHash -- ^ Chain identifier
- , chanUpdateShortChanId :: !ShortChannelId -- ^ Short channel ID
- , chanUpdateTimestamp :: !Timestamp -- ^ Unix timestamp
- , chanUpdateMsgFlags :: !MessageFlags -- ^ Message flags
- , chanUpdateChanFlags :: !ChannelFlags -- ^ Channel flags
- , chanUpdateCltvExpDelta :: !CltvExpiryDelta -- ^ CLTV expiry delta
- , chanUpdateHtlcMinMsat :: !HtlcMinimumMsat -- ^ Minimum HTLC msat
- , chanUpdateFeeBaseMsat :: !FeeBaseMsat -- ^ Base fee msat
- , chanUpdateFeeProportional :: !FeeProportionalMillionths -- ^ Prop fee
- , chanUpdateHtlcMaxMsat :: !(Maybe HtlcMaximumMsat) -- ^ Max HTLC (optional)
+ { chanUpdateSignature :: !Signature
+ -- ^ Signature of message
+ , chanUpdateChainHash :: !ChainHash
+ -- ^ Chain identifier
+ , chanUpdateShortChanId :: !ShortChannelId
+ -- ^ Short channel ID
+ , chanUpdateTimestamp :: !Timestamp
+ -- ^ Unix timestamp
+ , chanUpdateChanFlags :: !ChannelFlags
+ -- ^ Channel flags
+ , chanUpdateCltvExpDelta :: !CltvExpiryDelta
+ -- ^ CLTV expiry delta
+ , chanUpdateHtlcMinMsat :: !HtlcMinimumMsat
+ -- ^ Minimum HTLC msat
+ , chanUpdateFeeBaseMsat :: !FeeBaseMsat
+ -- ^ Base fee msat
+ , chanUpdateFeeProportional :: !FeeProportionalMillionths
+ -- ^ Proportional fee
+ , chanUpdateHtlcMaxMsat :: !(Maybe HtlcMaximumMsat)
+ -- ^ Max HTLC (optional; presence sets message_flags
+ -- bit 0)
}
deriving (Eq, Show, Generic)
@@ -182,10 +196,10 @@ instance NFData ReplyShortChannelIdsEnd
--
-- Queries channels within a block range.
data QueryChannelRange = QueryChannelRange
- { queryRangeChainHash :: !ChainHash -- ^ Chain identifier
- , queryRangeFirstBlock :: !Word32 -- ^ First block number
- , queryRangeNumBlocks :: !Word32 -- ^ Number of blocks
- , queryRangeTlvs :: !TlvStream -- ^ Optional TLV (query_option)
+ { queryRangeChainHash :: !ChainHash -- ^ Chain identifier
+ , queryRangeFirstBlock :: !BlockHeight -- ^ First block number
+ , queryRangeNumBlocks :: !BlockCount -- ^ Number of blocks
+ , queryRangeTlvs :: !TlvStream -- ^ Optional TLV
}
deriving (Eq, Show, Generic)
@@ -195,12 +209,12 @@ instance NFData QueryChannelRange
--
-- Responds to query_channel_range with channel IDs.
data ReplyChannelRange = ReplyChannelRange
- { replyRangeChainHash :: !ChainHash -- ^ Chain identifier
- , replyRangeFirstBlock :: !Word32 -- ^ First block number
- , replyRangeNumBlocks :: !Word32 -- ^ Number of blocks
- , replyRangeSyncComplete :: !Word8 -- ^ 1 if sync complete
- , replyRangeData :: !ByteString -- ^ Encoded short_channel_ids
- , replyRangeTlvs :: !TlvStream -- ^ Optional TLVs
+ { replyRangeChainHash :: !ChainHash -- ^ Chain identifier
+ , replyRangeFirstBlock :: !BlockHeight -- ^ First block
+ , replyRangeNumBlocks :: !BlockCount -- ^ Block count
+ , replyRangeSyncComplete :: !Word8 -- ^ 1 if complete
+ , replyRangeData :: !ByteString -- ^ Encoded SCIDs
+ , replyRangeTlvs :: !TlvStream -- ^ Optional TLVs
}
deriving (Eq, Show, Generic)
diff --git a/lib/Lightning/Protocol/BOLT7/Types.hs b/lib/Lightning/Protocol/BOLT7/Types.hs
@@ -64,11 +64,13 @@ module Lightning.Protocol.BOLT7.Types (
, TorV3Addr
, torV3Addr
, getTorV3Addr
+ , Hostname
+ , hostname
+ , getHostname
-- * Channel update flags
- , MessageFlags(..)
- , encodeMessageFlags
- , decodeMessageFlags
+ , Direction(..)
+ , ChannelStatus(..)
, ChannelFlags(..)
, encodeChannelFlags
, decodeChannelFlags
@@ -80,6 +82,10 @@ module Lightning.Protocol.BOLT7.Types (
, HtlcMinimumMsat(..)
, HtlcMaximumMsat(..)
+ -- * Block range types
+ , BlockHeight(..)
+ , BlockCount(..)
+
-- * Constants
, chainHashLen
, shortChannelIdLen
@@ -318,45 +324,68 @@ torV3Addr !bs
| otherwise = Nothing
{-# INLINE torV3Addr #-}
+-- | DNS hostname (1-255 bytes).
+--
+-- Per BOLT #7 address descriptor type 5, the hostname is
+-- a length-prefixed DNS name. The length byte limits it to
+-- 255 bytes.
+newtype Hostname = Hostname { getHostname :: ByteString }
+ deriving (Eq, Show, Generic)
+
+instance NFData Hostname
+
+-- | Smart constructor for Hostname.
+--
+-- Returns Nothing if the hostname is empty or exceeds
+-- 255 bytes.
+hostname :: ByteString -> Maybe Hostname
+hostname !bs
+ | BS.null bs = Nothing
+ | BS.length bs > 255 = Nothing
+ | otherwise = Just (Hostname bs)
+{-# INLINE hostname #-}
+
-- | Network address with port.
data Address
- = AddrIPv4 !IPv4Addr !Word16 -- ^ IPv4 address + port
- | AddrIPv6 !IPv6Addr !Word16 -- ^ IPv6 address + port
+ = AddrIPv4 !IPv4Addr !Word16 -- ^ IPv4 address + port
+ | AddrIPv6 !IPv6Addr !Word16 -- ^ IPv6 address + port
| AddrTorV3 !TorV3Addr !Word16 -- ^ Tor v3 address + port
- | AddrDNS !ByteString !Word16 -- ^ DNS hostname + port
+ | AddrDNS !Hostname !Word16 -- ^ DNS hostname + port
deriving (Eq, Show, Generic)
instance NFData Address
-- Channel update flags --------------------------------------------------------
--- | Message flags for channel_update.
+-- | Direction of a channel_update.
--
--- Bit 0: htlc_maximum_msat field is present.
-data MessageFlags = MessageFlags
- { mfHtlcMaxPresent :: !Bool -- ^ htlc_maximum_msat is present
- }
- deriving (Eq, Show, Generic)
+-- Per BOLT #7, bit 0 of channel_flags indicates which node
+-- is the origin of the update.
+data Direction
+ = NodeOne -- ^ Update from node_id_1 (bit 0 = 0)
+ | NodeTwo -- ^ Update from node_id_2 (bit 0 = 1)
+ deriving (Eq, Ord, Show, Generic)
-instance NFData MessageFlags
+instance NFData Direction
--- | Encode MessageFlags to Word8.
-encodeMessageFlags :: MessageFlags -> Word8
-encodeMessageFlags mf = if mfHtlcMaxPresent mf then 0x01 else 0x00
-{-# INLINE encodeMessageFlags #-}
+-- | Channel enabled\/disabled status.
+--
+-- Per BOLT #7, bit 1 of channel_flags indicates whether
+-- the channel is disabled.
+data ChannelStatus
+ = Enabled -- ^ Channel is active (bit 1 = 0)
+ | Disabled -- ^ Channel is disabled (bit 1 = 1)
+ deriving (Eq, Ord, Show, Generic)
--- | Decode Word8 to MessageFlags.
-decodeMessageFlags :: Word8 -> MessageFlags
-decodeMessageFlags w = MessageFlags { mfHtlcMaxPresent = w .&. 0x01 /= 0 }
-{-# INLINE decodeMessageFlags #-}
+instance NFData ChannelStatus
-- | Channel flags for channel_update.
--
--- Bit 0: direction (0 = node_id_1 is origin, 1 = node_id_2 is origin).
--- Bit 1: disabled (1 = channel disabled).
+-- Bit 0: direction (0 = node_id_1, 1 = node_id_2).
+-- Bit 1: disabled (0 = enabled, 1 = disabled).
data ChannelFlags = ChannelFlags
- { cfDirection :: !Bool -- ^ True = node_id_2 is origin
- , cfDisabled :: !Bool -- ^ True = channel is disabled
+ { cfDirection :: !Direction -- ^ Update origin
+ , cfStatus :: !ChannelStatus -- ^ Channel status
}
deriving (Eq, Show, Generic)
@@ -365,15 +394,25 @@ instance NFData ChannelFlags
-- | Encode ChannelFlags to Word8.
encodeChannelFlags :: ChannelFlags -> Word8
encodeChannelFlags cf =
- (if cfDirection cf then 0x01 else 0x00) .|.
- (if cfDisabled cf then 0x02 else 0x00)
+ dir .|. sta
+ where
+ dir = case cfDirection cf of
+ NodeOne -> 0x00
+ NodeTwo -> 0x01
+ sta = case cfStatus cf of
+ Enabled -> 0x00
+ Disabled -> 0x02
{-# INLINE encodeChannelFlags #-}
-- | Decode Word8 to ChannelFlags.
decodeChannelFlags :: Word8 -> ChannelFlags
decodeChannelFlags w = ChannelFlags
- { cfDirection = w .&. 0x01 /= 0
- , cfDisabled = w .&. 0x02 /= 0
+ { cfDirection = if w .&. 0x01 /= 0
+ then NodeTwo
+ else NodeOne
+ , cfStatus = if w .&. 0x02 /= 0
+ then Disabled
+ else Enabled
}
{-# INLINE decodeChannelFlags #-}
@@ -405,7 +444,30 @@ newtype HtlcMinimumMsat = HtlcMinimumMsat { getHtlcMinimumMsat :: Word64 }
instance NFData HtlcMinimumMsat
-- | Maximum HTLC value in millisatoshis.
-newtype HtlcMaximumMsat = HtlcMaximumMsat { getHtlcMaximumMsat :: Word64 }
+newtype HtlcMaximumMsat = HtlcMaximumMsat
+ { getHtlcMaximumMsat :: Word64 }
deriving (Eq, Ord, Show, Generic)
instance NFData HtlcMaximumMsat
+
+-- Block range types -----------------------------------------------------------
+
+-- | Absolute block height.
+--
+-- Used in query_channel_range and reply_channel_range for
+-- the first_blocknum field.
+newtype BlockHeight = BlockHeight
+ { getBlockHeight :: Word32 }
+ deriving (Eq, Ord, Show, Generic)
+
+instance NFData BlockHeight
+
+-- | Block count (relative duration).
+--
+-- Used in query_channel_range and reply_channel_range for
+-- the number_of_blocks field.
+newtype BlockCount = BlockCount
+ { getBlockCount :: Word32 }
+ deriving (Eq, Ord, Show, Generic)
+
+instance NFData BlockCount
diff --git a/lib/Lightning/Protocol/BOLT7/Validate.hs b/lib/Lightning/Protocol/BOLT7/Validate.hs
@@ -79,12 +79,13 @@ validateNodeAnnouncement msg = do
--
-- Checks:
--
--- * htlc_minimum_msat <= htlc_maximum_msat (if htlc_maximum_msat present)
+-- * htlc_minimum_msat <= htlc_maximum_msat (if present)
--
--- Note: The spec says message_flags bit 0 MUST be set if htlc_maximum_msat
--- is advertised. We don't enforce this at validation time since the codec
--- already handles the conditional field based on the flag.
-validateChannelUpdate :: ChannelUpdate -> Either ValidationError ()
+-- Note: message_flags consistency is enforced at the type
+-- level -- the flag is derived from the presence of
+-- 'chanUpdateHtlcMaxMsat'.
+validateChannelUpdate :: ChannelUpdate
+ -> Either ValidationError ()
validateChannelUpdate msg = do
case chanUpdateHtlcMaxMsat msg of
Nothing -> Right ()
@@ -99,10 +100,15 @@ validateChannelUpdate msg = do
-- Checks:
--
-- * first_blocknum + number_of_blocks does not overflow
-validateQueryChannelRange :: QueryChannelRange -> Either ValidationError ()
+validateQueryChannelRange :: QueryChannelRange
+ -> Either ValidationError ()
validateQueryChannelRange msg = do
- let first = fromIntegral (queryRangeFirstBlock msg) :: Word64
- num = fromIntegral (queryRangeNumBlocks msg) :: Word64
+ let first = fromIntegral
+ (getBlockHeight (queryRangeFirstBlock msg))
+ :: Word64
+ num = fromIntegral
+ (getBlockCount (queryRangeNumBlocks msg))
+ :: Word64
if first + num > fromIntegral (maxBound :: Word32)
then Left ValidateBlockOverflow
else Right ()
diff --git a/test/Main.hs b/test/Main.hs
@@ -207,9 +207,8 @@ channel_update_tests = testGroup "ChannelUpdate" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False }
, chanUpdateChanFlags = ChannelFlags
- { cfDirection = True, cfDisabled = False }
+ { cfDirection = NodeTwo, cfStatus = Enabled }
, chanUpdateCltvExpDelta = CltvExpiryDelta 144
, chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000
, chanUpdateFeeBaseMsat = FeeBaseMsat 1000
@@ -226,9 +225,8 @@ channel_update_tests = testGroup "ChannelUpdate" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = True }
, chanUpdateChanFlags = ChannelFlags
- { cfDirection = False, cfDisabled = False }
+ { cfDirection = NodeOne, cfStatus = Enabled }
, chanUpdateCltvExpDelta = CltvExpiryDelta 40
, chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000
, chanUpdateFeeBaseMsat = FeeBaseMsat 500
@@ -292,8 +290,8 @@ query_tests = testGroup "Query Messages" [
testCase "encode/decode roundtrip" $ do
let msg = QueryChannelRange
{ queryRangeChainHash = testChainHash
- , queryRangeFirstBlock = 600000
- , queryRangeNumBlocks = 10000
+ , queryRangeFirstBlock = BlockHeight 600000
+ , queryRangeNumBlocks = BlockCount 10000
, queryRangeTlvs = emptyTlvs
}
encoded = encodeQueryChannelRange msg
@@ -308,8 +306,8 @@ query_tests = testGroup "Query Messages" [
testCase "encode/decode roundtrip" $ do
let msg = ReplyChannelRange
{ replyRangeChainHash = testChainHash
- , replyRangeFirstBlock = 600000
- , replyRangeNumBlocks = 10000
+ , replyRangeFirstBlock = BlockHeight 600000
+ , replyRangeNumBlocks = BlockCount 10000
, replyRangeSyncComplete = 1
, replyRangeData = BS.replicate 16 0xcd
, replyRangeTlvs = emptyTlvs
@@ -427,9 +425,9 @@ hash_tests = testGroup "Hash Functions" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False }
, chanUpdateChanFlags = ChannelFlags
- { cfDirection = False, cfDisabled = False }
+ { cfDirection = NodeOne
+ , cfStatus = Enabled }
, chanUpdateCltvExpDelta = CltvExpiryDelta 144
, chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000
, chanUpdateFeeBaseMsat = FeeBaseMsat 1000
@@ -447,9 +445,9 @@ hash_tests = testGroup "Hash Functions" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False }
, chanUpdateChanFlags = ChannelFlags
- { cfDirection = False, cfDisabled = False }
+ { cfDirection = NodeOne
+ , cfStatus = Enabled }
, chanUpdateCltvExpDelta = CltvExpiryDelta 144
, chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000
, chanUpdateFeeBaseMsat = FeeBaseMsat 1000
@@ -466,9 +464,9 @@ hash_tests = testGroup "Hash Functions" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1000000000
- , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False }
, chanUpdateChanFlags = ChannelFlags
- { cfDirection = False, cfDisabled = False }
+ { cfDirection = NodeOne
+ , cfStatus = Enabled }
, chanUpdateCltvExpDelta = CltvExpiryDelta 144
, chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000
, chanUpdateFeeBaseMsat = FeeBaseMsat 1000
@@ -480,9 +478,9 @@ hash_tests = testGroup "Hash Functions" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 2000000000
- , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False }
, chanUpdateChanFlags = ChannelFlags
- { cfDirection = False, cfDisabled = False }
+ { cfDirection = NodeOne
+ , cfStatus = Enabled }
, chanUpdateCltvExpDelta = CltvExpiryDelta 144
, chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000
, chanUpdateFeeBaseMsat = FeeBaseMsat 1000
@@ -536,9 +534,9 @@ validation_tests = testGroup "Validation" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = True }
, chanUpdateChanFlags = ChannelFlags
- { cfDirection = False, cfDisabled = False }
+ { cfDirection = NodeOne
+ , cfStatus = Enabled }
, chanUpdateCltvExpDelta = CltvExpiryDelta 144
, chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000
, chanUpdateFeeBaseMsat = FeeBaseMsat 1000
@@ -552,9 +550,9 @@ validation_tests = testGroup "Validation" [
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = 1234567890
- , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = True }
, chanUpdateChanFlags = ChannelFlags
- { cfDirection = False, cfDisabled = False }
+ { cfDirection = NodeOne
+ , cfStatus = Enabled }
, chanUpdateCltvExpDelta = CltvExpiryDelta 144
, chanUpdateHtlcMinMsat = HtlcMinimumMsat 2000000000 -- > htlcMax
, chanUpdateFeeBaseMsat = FeeBaseMsat 1000
@@ -567,16 +565,16 @@ validation_tests = testGroup "Validation" [
testCase "valid range passes" $ do
let msg = QueryChannelRange
{ queryRangeChainHash = testChainHash
- , queryRangeFirstBlock = 600000
- , queryRangeNumBlocks = 10000
+ , queryRangeFirstBlock = BlockHeight 600000
+ , queryRangeNumBlocks = BlockCount 10000
, queryRangeTlvs = emptyTlvs
}
validateQueryChannelRange msg @?= Right ()
, testCase "rejects overflow" $ do
let msg = QueryChannelRange
{ queryRangeChainHash = testChainHash
- , queryRangeFirstBlock = maxBound -- 0xFFFFFFFF
- , queryRangeNumBlocks = 10
+ , queryRangeFirstBlock = BlockHeight maxBound
+ , queryRangeNumBlocks = BlockCount 10
, queryRangeTlvs = emptyTlvs
}
validateQueryChannelRange msg @?= Left ValidateBlockOverflow
@@ -648,9 +646,9 @@ propChannelUpdateRoundtrip timestamp cltvDelta = property $ do
, chanUpdateChainHash = testChainHash
, chanUpdateShortChanId = testShortChannelId
, chanUpdateTimestamp = timestamp
- , chanUpdateMsgFlags = MessageFlags { mfHtlcMaxPresent = False }
, chanUpdateChanFlags = ChannelFlags
- { cfDirection = False, cfDisabled = False }
+ { cfDirection = NodeOne
+ , cfStatus = Enabled }
, chanUpdateCltvExpDelta = CltvExpiryDelta cltvDelta
, chanUpdateHtlcMinMsat = HtlcMinimumMsat 1000
, chanUpdateFeeBaseMsat = FeeBaseMsat 1000