commit 7dc02f80403bf219e8322850a3575563217e7dc1
parent af22e72c4b9258f4ec5f48af5c9bc4bf5260cdfb
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 25 Jan 2026 14:30:12 +0400
feat: add ChannelId newtype for type-safe channel identifiers
- Add ChannelId newtype with 32-byte validation
- Add channelId smart constructor and allChannels constant
- Update Error/Warning to use ChannelId
- Update Codec encode/decode for ChannelId
- Update tests
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
4 files changed, 70 insertions(+), 10 deletions(-)
diff --git a/lib/Lightning/Protocol/BOLT1.hs b/lib/Lightning/Protocol/BOLT1.hs
@@ -15,6 +15,11 @@ module Lightning.Protocol.BOLT1 (
, MsgType(..)
, msgTypeWord
+ -- * Channel identifiers
+ , ChannelId
+ , channelId
+ , allChannels
+
-- ** Setup messages
, Init(..)
, Error(..)
diff --git a/lib/Lightning/Protocol/BOLT1/Codec.hs b/lib/Lightning/Protocol/BOLT1/Codec.hs
@@ -81,13 +81,13 @@ encodeInit (Init gf feat tlvs) = do
encodeError :: Error -> Either EncodeError BS.ByteString
encodeError (Error cid dat) = do
datLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength dat)
- Right $ mconcat [cid, datLen, dat]
+ Right $ mconcat [unChannelId cid, datLen, dat]
-- | Encode a Warning message payload.
encodeWarning :: Warning -> Either EncodeError BS.ByteString
encodeWarning (Warning cid dat) = do
datLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength dat)
- Right $ mconcat [cid, datLen, dat]
+ Right $ mconcat [unChannelId cid, datLen, dat]
-- | Encode a Ping message payload.
encodePing :: Ping -> Either EncodeError BS.ByteString
@@ -194,8 +194,9 @@ decodeInit !bs = do
decodeError :: BS.ByteString -> Either DecodeError (Error, BS.ByteString)
decodeError !bs = do
unless (BS.length bs >= 32) $ Left DecodeInsufficientBytes
- let !cid = BS.take 32 bs
+ let !cidBytes = BS.take 32 bs
!rest1 = BS.drop 32 bs
+ cid <- maybe (Left DecodeInvalidChannelId) Right (channelId cidBytes)
(dLen, rest2) <- maybe (Left DecodeInsufficientBytes) Right
(decodeU16 rest1)
unless (BS.length rest2 >= fromIntegral dLen) $
@@ -208,8 +209,9 @@ decodeError !bs = do
decodeWarning :: BS.ByteString -> Either DecodeError (Warning, BS.ByteString)
decodeWarning !bs = do
unless (BS.length bs >= 32) $ Left DecodeInsufficientBytes
- let !cid = BS.take 32 bs
+ let !cidBytes = BS.take 32 bs
!rest1 = BS.drop 32 bs
+ cid <- maybe (Left DecodeInvalidChannelId) Right (channelId cidBytes)
(dLen, rest2) <- maybe (Left DecodeInsufficientBytes) Right
(decodeU16 rest1)
unless (BS.length rest2 >= fromIntegral dLen) $
diff --git a/lib/Lightning/Protocol/BOLT1/Message.hs b/lib/Lightning/Protocol/BOLT1/Message.hs
@@ -16,6 +16,12 @@ module Lightning.Protocol.BOLT1.Message (
, msgTypeWord
, parseMsgType
+ -- * Channel identifiers
+ , ChannelId
+ , channelId
+ , unChannelId
+ , allChannels
+
-- * Setup messages
, Init(..)
, Error(..)
@@ -79,6 +85,43 @@ parseMsgType 7 = MsgPeerStorage
parseMsgType 9 = MsgPeerStorageRet
parseMsgType w = MsgUnknown w
+-- Channel identifiers ---------------------------------------------------------
+
+-- | A 32-byte channel identifier.
+--
+-- Use 'channelId' to construct, which validates the length.
+-- Use 'allChannels' for connection-level errors (all-zeros channel ID).
+newtype ChannelId = ChannelId BS.ByteString
+ deriving stock (Eq, Show, Generic)
+
+instance NFData ChannelId
+
+-- | Construct a 'ChannelId' from a 32-byte 'BS.ByteString'.
+--
+-- Returns 'Nothing' if the input is not exactly 32 bytes.
+--
+-- >>> channelId (BS.replicate 32 0x00)
+-- Just (ChannelId "\NUL\NUL...")
+-- >>> channelId "too short"
+-- Nothing
+channelId :: BS.ByteString -> Maybe ChannelId
+channelId bs
+ | BS.length bs == 32 = Just (ChannelId bs)
+ | otherwise = Nothing
+{-# INLINE channelId #-}
+
+-- | The all-zeros channel ID, used for connection-level errors.
+--
+-- Per BOLT #1, setting channel_id to all zeros means the error applies
+-- to the connection rather than a specific channel.
+allChannels :: ChannelId
+allChannels = ChannelId (BS.replicate 32 0x00)
+
+-- | Extract the raw bytes from a 'ChannelId'.
+unChannelId :: ChannelId -> BS.ByteString
+unChannelId (ChannelId bs) = bs
+{-# INLINE unChannelId #-}
+
-- Message ADTs ----------------------------------------------------------------
-- | The init message (type 16).
@@ -92,7 +135,7 @@ instance NFData Init
-- | The error message (type 17).
data Error = Error
- { errorChannelId :: !BS.ByteString -- ^ 32 bytes
+ { errorChannelId :: !ChannelId
, errorData :: !BS.ByteString
} deriving stock (Eq, Show, Generic)
@@ -100,7 +143,7 @@ instance NFData Error
-- | The warning message (type 1).
data Warning = Warning
- { warningChannelId :: !BS.ByteString -- ^ 32 bytes
+ { warningChannelId :: !ChannelId
, warningData :: !BS.ByteString
} deriving stock (Eq, Show, Generic)
diff --git a/test/Main.hs b/test/Main.hs
@@ -366,7 +366,7 @@ message_tests = testGroup "Messages" [
]
, testGroup "Error" [
testCase "encode/decode error" $ do
- let cid = BS.replicate 32 0xff
+ let cid = unsafeChannelId (BS.replicate 32 0xff)
msg = Error cid "something went wrong"
case encodeMessage (MsgErrorVal msg) of
Left e -> assertFailure $ "encode failed: " ++ show e
@@ -380,7 +380,7 @@ message_tests = testGroup "Messages" [
]
, testGroup "Warning" [
testCase "encode/decode warning" $ do
- let cid = BS.replicate 32 0x00
+ let cid = unsafeChannelId (BS.replicate 32 0x00)
msg = Warning cid "be careful"
case encodeMessage (MsgWarningVal msg) of
Left e -> assertFailure $ "encode failed: " ++ show e
@@ -555,7 +555,8 @@ bounds_tests = testGroup "Bounds checking" [
Left EncodeLengthOverflow -> pure ()
other -> assertFailure $ "expected overflow: " ++ show other
, testCase "encode error with oversized data fails" $ do
- let msg = Error (BS.replicate 32 0x00) (BS.replicate 70000 0x00)
+ let cid = unsafeChannelId (BS.replicate 32 0x00)
+ msg = Error cid (BS.replicate 70000 0x00)
case encodeMessage (MsgErrorVal msg) of
Left EncodeLengthOverflow -> pure ()
other -> assertFailure $ "expected overflow: " ++ show other
@@ -621,7 +622,7 @@ property_tests = testGroup "Properties" [
decoded == msg && BS.null rest
_ -> False
, testProperty "Error roundtrip" $ \bs ->
- let cid = BS.replicate 32 0x00
+ let cid = unsafeChannelId (BS.replicate 32 0x00)
dat = BS.pack (take 1000 bs)
msg = Error cid dat
in case encodeMessage (MsgErrorVal msg) of
@@ -645,6 +646,15 @@ property_tests = testGroup "Properties" [
-- Helpers ---------------------------------------------------------------------
+-- | Construct a 'ChannelId' from a known-valid 32-byte 'BS.ByteString'.
+--
+-- Uses 'error' for invalid input since all channel IDs in tests are
+-- known-valid compile-time constants.
+unsafeChannelId :: BS.ByteString -> ChannelId
+unsafeChannelId bs = case channelId bs of
+ Just cid -> cid
+ Nothing -> error $ "unsafeChannelId: invalid length: " ++ show (BS.length bs)
+
-- | Decode hex string (test-only helper).
--
-- Uses 'error' for invalid hex since all hex literals in tests are