commit 4f4dd33a850eb03978de3ea49924ff13130a453f
parent af22e72c4b9258f4ec5f48af5c9bc4bf5260cdfb
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 25 Jan 2026 14:30:55 +0400
feat: add ChainHash newtype for type-safe chain hashes
- Add ChainHash newtype with 32-byte validation in Prim module
- Add chainHash smart constructor and unChainHash accessor
- Update InitNetworks to use [ChainHash] instead of [ByteString]
- Update TLV parsing/encoding to wrap/unwrap ChainHash
- Add unsafeChainHash test helper for known-valid test data
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
4 files changed, 62 insertions(+), 6 deletions(-)
diff --git a/lib/Lightning/Protocol/BOLT1.hs b/lib/Lightning/Protocol/BOLT1.hs
@@ -39,6 +39,9 @@ module Lightning.Protocol.BOLT1 (
-- ** Init TLVs
, InitTlv(..)
+ , ChainHash
+ , chainHash
+ , unChainHash
-- * Message envelope
, Envelope(..)
diff --git a/lib/Lightning/Protocol/BOLT1/Prim.hs b/lib/Lightning/Protocol/BOLT1/Prim.hs
@@ -1,5 +1,7 @@
{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DerivingStrategies #-}
-- |
-- Module: Lightning.Protocol.BOLT1.Prim
@@ -10,8 +12,13 @@
-- Primitive type encoding and decoding for BOLT #1.
module Lightning.Protocol.BOLT1.Prim (
+ -- * Chain hash
+ ChainHash
+ , chainHash
+ , unChainHash
+
-- * Unsigned integer encoding
- encodeU16
+ , encodeU16
, encodeU32
, encodeU64
@@ -58,12 +65,36 @@ module Lightning.Protocol.BOLT1.Prim (
, encodeLength
) where
+import Control.DeepSeq (NFData)
import Data.Bits (unsafeShiftL, unsafeShiftR, (.|.))
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Lazy as BSL
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Word (Word16, Word32, Word64)
+import GHC.Generics (Generic)
+
+-- Chain hash ------------------------------------------------------------------
+
+-- | A chain hash (32-byte hash identifying a blockchain).
+newtype ChainHash = ChainHash BS.ByteString
+ deriving stock (Eq, Show, Generic)
+
+instance NFData ChainHash
+
+-- | Construct a chain hash from a 32-byte bytestring.
+--
+-- Returns 'Nothing' if the input is not exactly 32 bytes.
+chainHash :: BS.ByteString -> Maybe ChainHash
+chainHash bs
+ | BS.length bs == 32 = Just (ChainHash bs)
+ | otherwise = Nothing
+{-# INLINE chainHash #-}
+
+-- | Extract the raw bytes from a chain hash.
+unChainHash :: ChainHash -> BS.ByteString
+unChainHash (ChainHash bs) = bs
+{-# INLINE unChainHash #-}
-- Unsigned integer encoding ---------------------------------------------------
diff --git a/lib/Lightning/Protocol/BOLT1/TLV.hs b/lib/Lightning/Protocol/BOLT1/TLV.hs
@@ -30,6 +30,11 @@ module Lightning.Protocol.BOLT1.TLV (
, InitTlv(..)
, parseInitTlvs
, encodeInitTlvs
+
+ -- * Re-exports
+ , ChainHash
+ , chainHash
+ , unChainHash
) where
import Control.DeepSeq (NFData)
@@ -174,7 +179,7 @@ decodeTlvStream = decodeTlvStreamWith isInitTlvType
-- | TLV records for init message.
data InitTlv
- = InitNetworks ![BS.ByteString] -- ^ Type 1: chain hashes (32 bytes each)
+ = InitNetworks ![ChainHash] -- ^ Type 1: chain hashes (32 bytes each)
| InitRemoteAddr !BS.ByteString -- ^ Type 3: remote address
deriving stock (Eq, Show, Generic)
@@ -186,11 +191,18 @@ parseInitTlvs (TlvStream recs) = traverse parseOne recs
where
parseOne (TlvRecord 1 val)
| BS.length val `mod` 32 == 0 =
- Right (InitNetworks (chunksOf 32 val))
+ Right (InitNetworks (map mkChainHash (chunksOf 32 val)))
| otherwise = Left (TlvInvalidKnownType 1)
parseOne (TlvRecord 3 val) = Right (InitRemoteAddr val)
parseOne (TlvRecord t _) = Left (TlvUnknownEvenType t)
+ -- Each chunk is exactly 32 bytes from chunksOf, so chainHash always
+ -- succeeds. We use a partial pattern match as the Nothing case is
+ -- unreachable given our chunksOf guarantee.
+ mkChainHash bs = case chainHash bs of
+ Just ch -> ch
+ Nothing -> error "parseInitTlvs: impossible - chunk is not 32 bytes"
+
-- | Split bytestring into chunks of given size.
chunksOf :: Int -> BS.ByteString -> [BS.ByteString]
chunksOf !n !bs
@@ -204,6 +216,6 @@ encodeInitTlvs :: [InitTlv] -> TlvStream
encodeInitTlvs = TlvStream . map toRecord
where
toRecord (InitNetworks chains) =
- TlvRecord 1 (mconcat chains)
+ TlvRecord 1 (mconcat (map unChainHash chains))
toRecord (InitRemoteAddr addr) =
TlvRecord 3 addr
diff --git a/test/Main.hs b/test/Main.hs
@@ -356,8 +356,8 @@ message_tests = testGroup "Messages" [
Right (MsgInitVal decoded, _) -> decoded @?= msg
other -> assertFailure $ "unexpected: " ++ show other
, testCase "encode/decode init with networks TLV" $ do
- let chainHash = BS.replicate 32 0xab
- msg = Init "" "" [InitNetworks [chainHash]]
+ let ch = unsafeChainHash (BS.replicate 32 0xab)
+ msg = Init "" "" [InitNetworks [ch]]
case encodeMessage (MsgInitVal msg) of
Left e -> assertFailure $ "encode failed: " ++ show e
Right encoded -> case decodeMessage MsgInit encoded of
@@ -654,3 +654,13 @@ unhex :: BS.ByteString -> BS.ByteString
unhex bs = case B16.decode bs of
Just r -> r
Nothing -> error $ "unhex: invalid hex literal: " ++ show bs
+
+-- | Construct a ChainHash from a bytestring (test-only helper).
+--
+-- Uses 'error' for invalid input since all chain hashes in tests are
+-- known-valid 32-byte constants. This is acceptable in test code where
+-- the failure would indicate a bug in the test itself.
+unsafeChainHash :: BS.ByteString -> ChainHash
+unsafeChainHash bs = case chainHash bs of
+ Just c -> c
+ Nothing -> error $ "unsafeChainHash: not 32 bytes: " ++ show (BS.length bs)