bolt1

Base Lightning protocol, per BOLT #1 (docs.ppad.tech/bolt1).
git clone git://git.ppad.tech/bolt1.git
Log | Files | Refs | README | LICENSE

commit 135531bf49bc5e479928803cfa1b09d3a0787dd2
parent 20ea43188d781368e5e64c7c646285a6b0aaeb94
Author: Jared Tobin <jared@jtobin.io>
Date:   Mon, 20 Apr 2026 14:55:58 +0800

lib: tighten type safety for MsgType, Envelope, ChainHash

Add Internal module with unsafe constructors for test use:
- unsafeMsgUnknown: bypass known-code validation
- unsafeEnvelope: bypass type/message consistency check
- unsafeChainHash: bypass length validation
- unsafeChannelId: bypass length validation

Hide MsgUnknown constructor from public API (BOLT1.hs);
add msgUnknown smart constructor that maps known codes to
their proper constructors.

Hide Envelope data constructor from public API; add
envelope smart constructor that derives envType from the
Message, and export field accessors.

Update tests to import unsafe constructors from Internal.
Rename local 'envelope' bindings to avoid shadowing.

No behaviour changes; all 151 tests pass.

Diffstat:
Mlib/Lightning/Protocol/BOLT1.hs | 17+++++++++++++++--
Alib/Lightning/Protocol/BOLT1/Internal.hs | 56++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mlib/Lightning/Protocol/BOLT1/Message.hs | 25++++++++++++++++++++++++-
Mppad-bolt1.cabal | 1+
Mtest/Main.hs | 37++++++++++---------------------------
5 files changed, 106 insertions(+), 30 deletions(-)

diff --git a/lib/Lightning/Protocol/BOLT1.hs b/lib/Lightning/Protocol/BOLT1.hs @@ -12,7 +12,16 @@ module Lightning.Protocol.BOLT1 ( -- * Message types Message(..) - , MsgType(..) + , MsgType( + MsgInit + , MsgError + , MsgPing + , MsgPong + , MsgWarning + , MsgPeerStorage + , MsgPeerStorageRet + ) + , msgUnknown , msgTypeWord -- * Channel identifiers @@ -88,7 +97,11 @@ module Lightning.Protocol.BOLT1 ( , unChainHash -- * Message envelope - , Envelope(..) + , Envelope + , envelope + , envType + , envPayload + , envExtension -- * Encoding , EncodeError(..) diff --git a/lib/Lightning/Protocol/BOLT1/Internal.hs b/lib/Lightning/Protocol/BOLT1/Internal.hs @@ -0,0 +1,56 @@ +{-# OPTIONS_HADDOCK hide #-} + +-- | +-- Module: Lightning.Protocol.BOLT1.Internal +-- Copyright: (c) 2025 Jared Tobin +-- License: MIT +-- Maintainer: Jared Tobin <jared@ppad.tech> +-- +-- Internal definitions for BOLT #1. +-- +-- This module exports unsafe constructors that bypass +-- validation. Use only in tests or trusted internal code. + +module Lightning.Protocol.BOLT1.Internal ( + -- * Unsafe constructors + unsafeMsgUnknown + , unsafeEnvelope + , unsafeChainHash + , unsafeChannelId + ) where + +import qualified Data.ByteString as BS +import Data.Word (Word16) +import Lightning.Protocol.BOLT1.Message +import Lightning.Protocol.BOLT1.Prim +import Lightning.Protocol.BOLT1.TLV + +-- | Construct a 'MsgUnknown' without validation. +-- +-- This bypasses the check that prevents wrapping known +-- type codes. For test use only. +unsafeMsgUnknown :: Word16 -> MsgType +unsafeMsgUnknown = MsgUnknown + +-- | Construct an 'Envelope' without validation. +-- +-- This bypasses the check that 'envType' matches the +-- message. For test use only. +unsafeEnvelope + :: MsgType + -> BS.ByteString + -> Maybe TlvStream + -> Envelope +unsafeEnvelope = Envelope + +-- | Construct a 'ChainHash' without length validation. +-- +-- For test use only. +unsafeChainHash :: BS.ByteString -> ChainHash +unsafeChainHash = ChainHash + +-- | Construct a 'ChannelId' without length validation. +-- +-- For test use only. +unsafeChannelId :: BS.ByteString -> ChannelId +unsafeChannelId = ChannelId diff --git a/lib/Lightning/Protocol/BOLT1/Message.hs b/lib/Lightning/Protocol/BOLT1/Message.hs @@ -13,6 +13,7 @@ module Lightning.Protocol.BOLT1.Message ( -- * Message types MsgType(..) + , msgUnknown , msgTypeWord , parseMsgType @@ -39,6 +40,7 @@ module Lightning.Protocol.BOLT1.Message ( , Message(..) , messageType , Envelope(..) + , envelope ) where import Control.DeepSeq (NFData) @@ -87,6 +89,15 @@ parseMsgType 7 = MsgPeerStorage parseMsgType 9 = MsgPeerStorageRet parseMsgType w = MsgUnknown w +-- | Smart constructor for unknown message types. +-- +-- Returns the appropriate known constructor for known +-- type codes (16, 17, 18, 19, 1, 7, 9) and only uses +-- 'MsgUnknown' for truly unknown codes. +msgUnknown :: Word16 -> MsgType +msgUnknown = parseMsgType +{-# INLINE msgUnknown #-} + -- Message ADTs ---------------------------------------------------------------- -- | The init message (type 16). @@ -168,7 +179,8 @@ messageType (MsgPeerStorageRetrievalVal _) = MsgPeerStorageRet -- Message envelope ------------------------------------------------------------ --- | A complete message envelope with type, payload, and optional extension. +-- | A complete message envelope with type, payload, +-- and optional extension. data Envelope = Envelope { envType :: !MsgType , envPayload :: !BS.ByteString @@ -176,3 +188,14 @@ data Envelope = Envelope } deriving stock (Eq, Show, Generic) instance NFData Envelope + +-- | Construct an 'Envelope' from a 'Message' and optional +-- extension TLV stream. The 'envType' is derived +-- automatically from the 'Message'. +envelope :: Message -> Maybe TlvStream -> Envelope +envelope msg mext = Envelope + { envType = messageType msg + , envPayload = BS.empty + , envExtension = mext + } +{-# INLINE envelope #-} diff --git a/ppad-bolt1.cabal b/ppad-bolt1.cabal @@ -26,6 +26,7 @@ library exposed-modules: Lightning.Protocol.BOLT1 Lightning.Protocol.BOLT1.Codec + Lightning.Protocol.BOLT1.Internal Lightning.Protocol.BOLT1.Message Lightning.Protocol.BOLT1.Prim Lightning.Protocol.BOLT1.TLV diff --git a/test/Main.hs b/test/Main.hs @@ -5,6 +5,8 @@ module Main where import qualified Data.ByteString as BS import qualified Data.ByteString.Base16 as B16 import Lightning.Protocol.BOLT1 +import Lightning.Protocol.BOLT1.Internal + (unsafeChannelId, unsafeChainHash, unsafeMsgUnknown) import Test.Tasty import Test.Tasty.HUnit import Test.Tasty.QuickCheck @@ -454,11 +456,11 @@ message_tests = testGroup "Messages" [ ] , testGroup "Unknown types" [ testCase "decodeMessage unknown even type" $ do - case decodeMessage (MsgUnknown 100) "payload" of + case decodeMessage (unsafeMsgUnknown 100) "payload" of Left (DecodeUnknownEvenType 100) -> pure () other -> assertFailure $ "expected unknown even: " ++ show other , testCase "decodeMessage unknown odd type" $ do - case decodeMessage (MsgUnknown 101) "payload" of + case decodeMessage (unsafeMsgUnknown 101) "payload" of Left (DecodeUnknownOddType 101) -> pure () other -> assertFailure $ "expected unknown odd: " ++ show other ] @@ -534,8 +536,8 @@ extension_tests = testGroup "Extension TLV" [ -- Per BOLT #1: unknown even types must cause failure let pingPayload = mconcat [encodeU16 10, encodeU16 0] -- numPong=10, len=0 extTlv = mconcat [encodeBigSize 100, encodeBigSize 3, "abc"] -- even! - envelope = encodeU16 18 <> pingPayload <> extTlv -- type 18 = ping - case decodeEnvelope envelope of + env = encodeU16 18 <> pingPayload <> extTlv -- type 18 = ping + case decodeEnvelope env of Left (DecodeInvalidExtension (TlvUnknownEvenType 100)) -> pure () other -> assertFailure $ "expected unknown even error: " ++ show other , testCase "decode envelope with invalid extension fails" $ do @@ -545,8 +547,8 @@ extension_tests = testGroup "Extension TLV" [ encodeBigSize 101, encodeBigSize 1, "a" -- odd types for this test , encodeBigSize 51, encodeBigSize 1, "b" -- 51 < 101, invalid ] - envelope = encodeU16 18 <> pingPayload <> badTlv - case decodeEnvelope envelope of + env = encodeU16 18 <> pingPayload <> badTlv + case decodeEnvelope env of Left (DecodeInvalidExtension TlvNotStrictlyIncreasing) -> pure () other -> assertFailure $ "expected invalid extension: " ++ show other , testCase "unknown even in extension fails even with odd types present" $ do @@ -556,8 +558,8 @@ extension_tests = testGroup "Extension TLV" [ encodeBigSize 101, encodeBigSize 1, "a" -- odd, would be skipped , encodeBigSize 200, encodeBigSize 1, "b" -- even, must fail ] - envelope = encodeU16 18 <> pingPayload <> extTlv - case decodeEnvelope envelope of + env = encodeU16 18 <> pingPayload <> extTlv + case decodeEnvelope env of Left (DecodeInvalidExtension (TlvUnknownEvenType 200)) -> pure () other -> assertFailure $ "expected unknown even error: " ++ show other ] @@ -669,15 +671,6 @@ property_tests = testGroup "Properties" [ -- Helpers --------------------------------------------------------------------- --- | Construct a 'ChannelId' from a known-valid 32-byte 'BS.ByteString'. --- --- Uses 'error' for invalid input since all channel IDs in tests are --- known-valid compile-time constants. -unsafeChannelId :: BS.ByteString -> ChannelId -unsafeChannelId bs = case channelId bs of - Just cid -> cid - Nothing -> error $ "unsafeChannelId: invalid length: " ++ show (BS.length bs) - -- | Decode hex string (test-only helper). -- -- Uses 'error' for invalid hex since all hex literals in tests are @@ -687,13 +680,3 @@ unhex :: BS.ByteString -> BS.ByteString unhex bs = case B16.decode bs of Just r -> r Nothing -> error $ "unhex: invalid hex literal: " ++ show bs - --- | Construct a ChainHash from a bytestring (test-only helper). --- --- Uses 'error' for invalid input since all chain hashes in tests are --- known-valid 32-byte constants. This is acceptable in test code where --- the failure would indicate a bug in the test itself. -unsafeChainHash :: BS.ByteString -> ChainHash -unsafeChainHash bs = case chainHash bs of - Just c -> c - Nothing -> error $ "unsafeChainHash: not 32 bytes: " ++ show (BS.length bs)