bolt1

Base Lightning protocol, per BOLT #1.
git clone git://git.ppad.tech/bolt1.git
Log | Files | Refs | README | LICENSE

commit 5b57c48bc7effea56205c9b9dcbdd0ef10fdc37f
parent af22e72c4b9258f4ec5f48af5c9bc4bf5260cdfb
Author: Jared Tobin <jared@jtobin.io>
Date:   Sun, 25 Jan 2026 14:30:07 +0400

feat: add ChannelId newtype for type-safe channel identifiers

- Add ChannelId newtype wrapping 32-byte ByteString
- Add channelId smart constructor with length validation
- Add allChannels constant for connection-level errors
- Add unChannelId accessor for Codec use
- Update Error/Warning records to use ChannelId
- Update Codec encode/decode to handle ChannelId
- Update tests to use unsafeChannelId helper

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Diffstat:
Mlib/Lightning/Protocol/BOLT1.hs | 5+++++
Mlib/Lightning/Protocol/BOLT1/Codec.hs | 10++++++----
Mlib/Lightning/Protocol/BOLT1/Message.hs | 47+++++++++++++++++++++++++++++++++++++++++++++--
Mtest/Main.hs | 18++++++++++++++----
4 files changed, 70 insertions(+), 10 deletions(-)

diff --git a/lib/Lightning/Protocol/BOLT1.hs b/lib/Lightning/Protocol/BOLT1.hs @@ -15,6 +15,11 @@ module Lightning.Protocol.BOLT1 ( , MsgType(..) , msgTypeWord + -- * Channel identifiers + , ChannelId + , channelId + , allChannels + -- ** Setup messages , Init(..) , Error(..) diff --git a/lib/Lightning/Protocol/BOLT1/Codec.hs b/lib/Lightning/Protocol/BOLT1/Codec.hs @@ -81,13 +81,13 @@ encodeInit (Init gf feat tlvs) = do encodeError :: Error -> Either EncodeError BS.ByteString encodeError (Error cid dat) = do datLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength dat) - Right $ mconcat [cid, datLen, dat] + Right $ mconcat [unChannelId cid, datLen, dat] -- | Encode a Warning message payload. encodeWarning :: Warning -> Either EncodeError BS.ByteString encodeWarning (Warning cid dat) = do datLen <- maybe (Left EncodeLengthOverflow) Right (encodeLength dat) - Right $ mconcat [cid, datLen, dat] + Right $ mconcat [unChannelId cid, datLen, dat] -- | Encode a Ping message payload. encodePing :: Ping -> Either EncodeError BS.ByteString @@ -194,8 +194,9 @@ decodeInit !bs = do decodeError :: BS.ByteString -> Either DecodeError (Error, BS.ByteString) decodeError !bs = do unless (BS.length bs >= 32) $ Left DecodeInsufficientBytes - let !cid = BS.take 32 bs + let !cidBytes = BS.take 32 bs !rest1 = BS.drop 32 bs + cid <- maybe (Left DecodeInvalidChannelId) Right (channelId cidBytes) (dLen, rest2) <- maybe (Left DecodeInsufficientBytes) Right (decodeU16 rest1) unless (BS.length rest2 >= fromIntegral dLen) $ @@ -208,8 +209,9 @@ decodeError !bs = do decodeWarning :: BS.ByteString -> Either DecodeError (Warning, BS.ByteString) decodeWarning !bs = do unless (BS.length bs >= 32) $ Left DecodeInsufficientBytes - let !cid = BS.take 32 bs + let !cidBytes = BS.take 32 bs !rest1 = BS.drop 32 bs + cid <- maybe (Left DecodeInvalidChannelId) Right (channelId cidBytes) (dLen, rest2) <- maybe (Left DecodeInsufficientBytes) Right (decodeU16 rest1) unless (BS.length rest2 >= fromIntegral dLen) $ diff --git a/lib/Lightning/Protocol/BOLT1/Message.hs b/lib/Lightning/Protocol/BOLT1/Message.hs @@ -16,6 +16,12 @@ module Lightning.Protocol.BOLT1.Message ( , msgTypeWord , parseMsgType + -- * Channel identifiers + , ChannelId + , channelId + , unChannelId + , allChannels + -- * Setup messages , Init(..) , Error(..) @@ -79,6 +85,43 @@ parseMsgType 7 = MsgPeerStorage parseMsgType 9 = MsgPeerStorageRet parseMsgType w = MsgUnknown w +-- Channel identifiers --------------------------------------------------------- + +-- | A 32-byte channel identifier. +-- +-- Use 'channelId' to construct, which validates the length. +-- Use 'allChannels' for connection-level errors (all-zeros channel ID). +newtype ChannelId = ChannelId BS.ByteString + deriving stock (Eq, Show, Generic) + +instance NFData ChannelId + +-- | Construct a 'ChannelId' from a 32-byte 'BS.ByteString'. +-- +-- Returns 'Nothing' if the input is not exactly 32 bytes. +-- +-- >>> channelId (BS.replicate 32 0x00) +-- Just (ChannelId "\NUL\NUL...") +-- >>> channelId "too short" +-- Nothing +channelId :: BS.ByteString -> Maybe ChannelId +channelId bs + | BS.length bs == 32 = Just (ChannelId bs) + | otherwise = Nothing +{-# INLINE channelId #-} + +-- | The all-zeros channel ID, used for connection-level errors. +-- +-- Per BOLT #1, setting channel_id to all zeros means the error applies +-- to the connection rather than a specific channel. +allChannels :: ChannelId +allChannels = ChannelId (BS.replicate 32 0x00) + +-- | Extract the raw bytes from a 'ChannelId'. +unChannelId :: ChannelId -> BS.ByteString +unChannelId (ChannelId bs) = bs +{-# INLINE unChannelId #-} + -- Message ADTs ---------------------------------------------------------------- -- | The init message (type 16). @@ -92,7 +135,7 @@ instance NFData Init -- | The error message (type 17). data Error = Error - { errorChannelId :: !BS.ByteString -- ^ 32 bytes + { errorChannelId :: !ChannelId , errorData :: !BS.ByteString } deriving stock (Eq, Show, Generic) @@ -100,7 +143,7 @@ instance NFData Error -- | The warning message (type 1). data Warning = Warning - { warningChannelId :: !BS.ByteString -- ^ 32 bytes + { warningChannelId :: !ChannelId , warningData :: !BS.ByteString } deriving stock (Eq, Show, Generic) diff --git a/test/Main.hs b/test/Main.hs @@ -366,7 +366,7 @@ message_tests = testGroup "Messages" [ ] , testGroup "Error" [ testCase "encode/decode error" $ do - let cid = BS.replicate 32 0xff + let cid = unsafeChannelId (BS.replicate 32 0xff) msg = Error cid "something went wrong" case encodeMessage (MsgErrorVal msg) of Left e -> assertFailure $ "encode failed: " ++ show e @@ -380,7 +380,7 @@ message_tests = testGroup "Messages" [ ] , testGroup "Warning" [ testCase "encode/decode warning" $ do - let cid = BS.replicate 32 0x00 + let cid = unsafeChannelId (BS.replicate 32 0x00) msg = Warning cid "be careful" case encodeMessage (MsgWarningVal msg) of Left e -> assertFailure $ "encode failed: " ++ show e @@ -555,7 +555,8 @@ bounds_tests = testGroup "Bounds checking" [ Left EncodeLengthOverflow -> pure () other -> assertFailure $ "expected overflow: " ++ show other , testCase "encode error with oversized data fails" $ do - let msg = Error (BS.replicate 32 0x00) (BS.replicate 70000 0x00) + let cid = unsafeChannelId (BS.replicate 32 0x00) + msg = Error cid (BS.replicate 70000 0x00) case encodeMessage (MsgErrorVal msg) of Left EncodeLengthOverflow -> pure () other -> assertFailure $ "expected overflow: " ++ show other @@ -621,7 +622,7 @@ property_tests = testGroup "Properties" [ decoded == msg && BS.null rest _ -> False , testProperty "Error roundtrip" $ \bs -> - let cid = BS.replicate 32 0x00 + let cid = unsafeChannelId (BS.replicate 32 0x00) dat = BS.pack (take 1000 bs) msg = Error cid dat in case encodeMessage (MsgErrorVal msg) of @@ -645,6 +646,15 @@ property_tests = testGroup "Properties" [ -- Helpers --------------------------------------------------------------------- +-- | Construct a 'ChannelId' from a known-valid 32-byte 'BS.ByteString'. +-- +-- Uses 'error' for invalid input since all channel IDs in tests are +-- known-valid compile-time constants. +unsafeChannelId :: BS.ByteString -> ChannelId +unsafeChannelId bs = case channelId bs of + Just cid -> cid + Nothing -> error $ "unsafeChannelId: invalid length: " ++ show (BS.length bs) + -- | Decode hex string (test-only helper). -- -- Uses 'error' for invalid hex since all hex literals in tests are