commit 580036e8f5cb22d423a205abeb15fca33307267c
parent 20ea43188d781368e5e64c7c646285a6b0aaeb94
Author: Jared Tobin <jared@jtobin.io>
Date: Mon, 20 Apr 2026 15:03:34 +0800
Merge branch 'impl/type-safety'
Type safety improvements for MsgType, Envelope, ChainHash.
- Internal module with unsafe constructors for test use
- MsgUnknown hidden from public API; msgUnknown smart constructor
validates against known type codes
- Envelope constructor hidden; envelope smart constructor derives
envType from Message automatically
- unsafeChainHash added alongside existing unsafeChannelId
Diffstat:
5 files changed, 106 insertions(+), 30 deletions(-)
diff --git a/lib/Lightning/Protocol/BOLT1.hs b/lib/Lightning/Protocol/BOLT1.hs
@@ -12,7 +12,16 @@
module Lightning.Protocol.BOLT1 (
-- * Message types
Message(..)
- , MsgType(..)
+ , MsgType(
+ MsgInit
+ , MsgError
+ , MsgPing
+ , MsgPong
+ , MsgWarning
+ , MsgPeerStorage
+ , MsgPeerStorageRet
+ )
+ , msgUnknown
, msgTypeWord
-- * Channel identifiers
@@ -88,7 +97,11 @@ module Lightning.Protocol.BOLT1 (
, unChainHash
-- * Message envelope
- , Envelope(..)
+ , Envelope
+ , envelope
+ , envType
+ , envPayload
+ , envExtension
-- * Encoding
, EncodeError(..)
diff --git a/lib/Lightning/Protocol/BOLT1/Internal.hs b/lib/Lightning/Protocol/BOLT1/Internal.hs
@@ -0,0 +1,56 @@
+{-# OPTIONS_HADDOCK hide #-}
+
+-- |
+-- Module: Lightning.Protocol.BOLT1.Internal
+-- Copyright: (c) 2025 Jared Tobin
+-- License: MIT
+-- Maintainer: Jared Tobin <jared@ppad.tech>
+--
+-- Internal definitions for BOLT #1.
+--
+-- This module exports unsafe constructors that bypass
+-- validation. Use only in tests or trusted internal code.
+
+module Lightning.Protocol.BOLT1.Internal (
+ -- * Unsafe constructors
+ unsafeMsgUnknown
+ , unsafeEnvelope
+ , unsafeChainHash
+ , unsafeChannelId
+ ) where
+
+import qualified Data.ByteString as BS
+import Data.Word (Word16)
+import Lightning.Protocol.BOLT1.Message
+import Lightning.Protocol.BOLT1.Prim
+import Lightning.Protocol.BOLT1.TLV
+
+-- | Construct a 'MsgUnknown' without validation.
+--
+-- This bypasses the check that prevents wrapping known
+-- type codes. For test use only.
+unsafeMsgUnknown :: Word16 -> MsgType
+unsafeMsgUnknown = MsgUnknown
+
+-- | Construct an 'Envelope' without validation.
+--
+-- This bypasses the check that 'envType' matches the
+-- message. For test use only.
+unsafeEnvelope
+ :: MsgType
+ -> BS.ByteString
+ -> Maybe TlvStream
+ -> Envelope
+unsafeEnvelope = Envelope
+
+-- | Construct a 'ChainHash' without length validation.
+--
+-- For test use only.
+unsafeChainHash :: BS.ByteString -> ChainHash
+unsafeChainHash = ChainHash
+
+-- | Construct a 'ChannelId' without length validation.
+--
+-- For test use only.
+unsafeChannelId :: BS.ByteString -> ChannelId
+unsafeChannelId = ChannelId
diff --git a/lib/Lightning/Protocol/BOLT1/Message.hs b/lib/Lightning/Protocol/BOLT1/Message.hs
@@ -13,6 +13,7 @@
module Lightning.Protocol.BOLT1.Message (
-- * Message types
MsgType(..)
+ , msgUnknown
, msgTypeWord
, parseMsgType
@@ -39,6 +40,7 @@ module Lightning.Protocol.BOLT1.Message (
, Message(..)
, messageType
, Envelope(..)
+ , envelope
) where
import Control.DeepSeq (NFData)
@@ -87,6 +89,15 @@ parseMsgType 7 = MsgPeerStorage
parseMsgType 9 = MsgPeerStorageRet
parseMsgType w = MsgUnknown w
+-- | Smart constructor for unknown message types.
+--
+-- Returns the appropriate known constructor for known
+-- type codes (16, 17, 18, 19, 1, 7, 9) and only uses
+-- 'MsgUnknown' for truly unknown codes.
+msgUnknown :: Word16 -> MsgType
+msgUnknown = parseMsgType
+{-# INLINE msgUnknown #-}
+
-- Message ADTs ----------------------------------------------------------------
-- | The init message (type 16).
@@ -168,7 +179,8 @@ messageType (MsgPeerStorageRetrievalVal _) = MsgPeerStorageRet
-- Message envelope ------------------------------------------------------------
--- | A complete message envelope with type, payload, and optional extension.
+-- | A complete message envelope with type, payload,
+-- and optional extension.
data Envelope = Envelope
{ envType :: !MsgType
, envPayload :: !BS.ByteString
@@ -176,3 +188,14 @@ data Envelope = Envelope
} deriving stock (Eq, Show, Generic)
instance NFData Envelope
+
+-- | Construct an 'Envelope' from a 'Message' and optional
+-- extension TLV stream. The 'envType' is derived
+-- automatically from the 'Message'.
+envelope :: Message -> Maybe TlvStream -> Envelope
+envelope msg mext = Envelope
+ { envType = messageType msg
+ , envPayload = BS.empty
+ , envExtension = mext
+ }
+{-# INLINE envelope #-}
diff --git a/ppad-bolt1.cabal b/ppad-bolt1.cabal
@@ -26,6 +26,7 @@ library
exposed-modules:
Lightning.Protocol.BOLT1
Lightning.Protocol.BOLT1.Codec
+ Lightning.Protocol.BOLT1.Internal
Lightning.Protocol.BOLT1.Message
Lightning.Protocol.BOLT1.Prim
Lightning.Protocol.BOLT1.TLV
diff --git a/test/Main.hs b/test/Main.hs
@@ -5,6 +5,8 @@ module Main where
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as B16
import Lightning.Protocol.BOLT1
+import Lightning.Protocol.BOLT1.Internal
+ (unsafeChannelId, unsafeChainHash, unsafeMsgUnknown)
import Test.Tasty
import Test.Tasty.HUnit
import Test.Tasty.QuickCheck
@@ -454,11 +456,11 @@ message_tests = testGroup "Messages" [
]
, testGroup "Unknown types" [
testCase "decodeMessage unknown even type" $ do
- case decodeMessage (MsgUnknown 100) "payload" of
+ case decodeMessage (unsafeMsgUnknown 100) "payload" of
Left (DecodeUnknownEvenType 100) -> pure ()
other -> assertFailure $ "expected unknown even: " ++ show other
, testCase "decodeMessage unknown odd type" $ do
- case decodeMessage (MsgUnknown 101) "payload" of
+ case decodeMessage (unsafeMsgUnknown 101) "payload" of
Left (DecodeUnknownOddType 101) -> pure ()
other -> assertFailure $ "expected unknown odd: " ++ show other
]
@@ -534,8 +536,8 @@ extension_tests = testGroup "Extension TLV" [
-- Per BOLT #1: unknown even types must cause failure
let pingPayload = mconcat [encodeU16 10, encodeU16 0] -- numPong=10, len=0
extTlv = mconcat [encodeBigSize 100, encodeBigSize 3, "abc"] -- even!
- envelope = encodeU16 18 <> pingPayload <> extTlv -- type 18 = ping
- case decodeEnvelope envelope of
+ env = encodeU16 18 <> pingPayload <> extTlv -- type 18 = ping
+ case decodeEnvelope env of
Left (DecodeInvalidExtension (TlvUnknownEvenType 100)) -> pure ()
other -> assertFailure $ "expected unknown even error: " ++ show other
, testCase "decode envelope with invalid extension fails" $ do
@@ -545,8 +547,8 @@ extension_tests = testGroup "Extension TLV" [
encodeBigSize 101, encodeBigSize 1, "a" -- odd types for this test
, encodeBigSize 51, encodeBigSize 1, "b" -- 51 < 101, invalid
]
- envelope = encodeU16 18 <> pingPayload <> badTlv
- case decodeEnvelope envelope of
+ env = encodeU16 18 <> pingPayload <> badTlv
+ case decodeEnvelope env of
Left (DecodeInvalidExtension TlvNotStrictlyIncreasing) -> pure ()
other -> assertFailure $ "expected invalid extension: " ++ show other
, testCase "unknown even in extension fails even with odd types present" $ do
@@ -556,8 +558,8 @@ extension_tests = testGroup "Extension TLV" [
encodeBigSize 101, encodeBigSize 1, "a" -- odd, would be skipped
, encodeBigSize 200, encodeBigSize 1, "b" -- even, must fail
]
- envelope = encodeU16 18 <> pingPayload <> extTlv
- case decodeEnvelope envelope of
+ env = encodeU16 18 <> pingPayload <> extTlv
+ case decodeEnvelope env of
Left (DecodeInvalidExtension (TlvUnknownEvenType 200)) -> pure ()
other -> assertFailure $ "expected unknown even error: " ++ show other
]
@@ -669,15 +671,6 @@ property_tests = testGroup "Properties" [
-- Helpers ---------------------------------------------------------------------
--- | Construct a 'ChannelId' from a known-valid 32-byte 'BS.ByteString'.
---
--- Uses 'error' for invalid input since all channel IDs in tests are
--- known-valid compile-time constants.
-unsafeChannelId :: BS.ByteString -> ChannelId
-unsafeChannelId bs = case channelId bs of
- Just cid -> cid
- Nothing -> error $ "unsafeChannelId: invalid length: " ++ show (BS.length bs)
-
-- | Decode hex string (test-only helper).
--
-- Uses 'error' for invalid hex since all hex literals in tests are
@@ -687,13 +680,3 @@ unhex :: BS.ByteString -> BS.ByteString
unhex bs = case B16.decode bs of
Just r -> r
Nothing -> error $ "unhex: invalid hex literal: " ++ show bs
-
--- | Construct a ChainHash from a bytestring (test-only helper).
---
--- Uses 'error' for invalid input since all chain hashes in tests are
--- known-valid 32-byte constants. This is acceptable in test code where
--- the failure would indicate a bug in the test itself.
-unsafeChainHash :: BS.ByteString -> ChainHash
-unsafeChainHash bs = case chainHash bs of
- Just c -> c
- Nothing -> error $ "unsafeChainHash: not 32 bytes: " ++ show (BS.length bs)