commit 0fc4d9fe07087ea81873be87bab89cd5baa555cf
parent 500140d29936670be7064731f40961a0915aebce
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 25 Jan 2026 18:31:42 +0400
Improve type safety: NonEmpty tx io, Maybe sighash, mkTxId
- Change tx_inputs and tx_outputs from [TxIn]/[TxOut] to NonEmpty,
enforcing Bitcoin's requirement of at least one input and output
at the type level
- Add mkTxId smart constructor that validates 32-byte length,
returning Maybe TxId rather than constructing an invalid value
(the raw TxId constructor is still exported via TxId(..))
- Change sighash_segwit to return Maybe BS.ByteString instead of
using error on invalid input index, making it total
- Refactor serialize_legacy_sighash to work directly without
constructing intermediate Tx (needed since SIGHASH_NONE can
produce empty output lists, now invalid for Tx type)
- Update all tests and benchmarks for NonEmpty syntax
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
5 files changed, 115 insertions(+), 83 deletions(-)
diff --git a/bench/Main.hs b/bench/Main.hs
@@ -6,6 +6,7 @@ module Main where
import Control.DeepSeq
import Criterion.Main
import qualified Data.ByteString as BS
+import Data.List.NonEmpty (NonEmpty(..))
import Bitcoin.Prim.Tx
import Bitcoin.Prim.Tx.Sighash
@@ -57,21 +58,23 @@ sampleWitness = Witness
]
-- | Create a legacy transaction with n inputs and m outputs.
+-- Requires n >= 1 and m >= 1.
-- The first argument is the input count (n), the second the output count (m).
-- The result uses version 1, locktime 0 and an empty witness list.
mkLegacyTx :: Int -> Int -> Tx
mkLegacyTx !numInputs !numOutputs = Tx
{ tx_version = 1
- , tx_inputs = replicate numInputs sampleInput
- , tx_outputs = replicate numOutputs sampleOutput
-- Each field is one head element followed by (count - 1) copies of it.
-- 'replicate' returns [] for a non-positive count, so a count below 1
-- does not fail: it silently produces exactly one element.
+ , tx_inputs = sampleInput :| replicate (numInputs - 1) sampleInput
+ , tx_outputs = sampleOutput :| replicate (numOutputs - 1) sampleOutput
, tx_witnesses = []
, tx_locktime = 0
}
-- | Create a segwit transaction with n inputs and m outputs.
+-- Requires n >= 1 and m >= 1.
-- The first argument is the input count (n), the second the output count (m).
-- The result uses version 2 and locktime 0.
mkSegwitTx :: Int -> Int -> Tx
mkSegwitTx !numInputs !numOutputs = Tx
{ tx_version = 2
- , tx_inputs = replicate numInputs sampleSegwitInput
- , tx_outputs = replicate numOutputs sampleOutput
-- Each field is one head element followed by (count - 1) copies of it.
-- 'replicate' returns [] for a non-positive count, so a count below 1
-- does not fail: it silently produces exactly one element.
+ , tx_inputs = sampleSegwitInput :| replicate (numInputs - 1) sampleSegwitInput
+ , tx_outputs = sampleOutput :| replicate (numOutputs - 1) sampleOutput
-- The witness count is taken from numInputs directly. It matches the number
-- of inputs only when numInputs >= 1; for smaller values the witness list is
-- empty while tx_inputs still holds one element.
, tx_witnesses = replicate numInputs sampleWitness
, tx_locktime = 0
}
diff --git a/bench/Weight.hs b/bench/Weight.hs
@@ -5,6 +5,7 @@ module Main where
import Control.DeepSeq
import qualified Data.ByteString as BS
+import Data.List.NonEmpty (NonEmpty(..))
import qualified Weigh as W
import Bitcoin.Prim.Tx
@@ -57,21 +58,23 @@ sampleWitness = Witness
]
-- | Create a legacy transaction with n inputs and m outputs.
+-- Requires n >= 1 and m >= 1.
-- The first argument is the input count (n), the second the output count (m).
-- The result uses version 1, locktime 0 and an empty witness list.
mkLegacyTx :: Int -> Int -> Tx
mkLegacyTx !numInputs !numOutputs = Tx
{ tx_version = 1
- , tx_inputs = replicate numInputs sampleInput
- , tx_outputs = replicate numOutputs sampleOutput
-- Each field is one head element followed by (count - 1) copies of it.
-- 'replicate' returns [] for a non-positive count, so a count below 1
-- does not fail: it silently produces exactly one element.
+ , tx_inputs = sampleInput :| replicate (numInputs - 1) sampleInput
+ , tx_outputs = sampleOutput :| replicate (numOutputs - 1) sampleOutput
, tx_witnesses = []
, tx_locktime = 0
}
-- | Create a segwit transaction with n inputs and m outputs.
+-- Requires n >= 1 and m >= 1.
-- The first argument is the input count (n), the second the output count (m).
-- The result uses version 2 and locktime 0.
mkSegwitTx :: Int -> Int -> Tx
mkSegwitTx !numInputs !numOutputs = Tx
{ tx_version = 2
- , tx_inputs = replicate numInputs sampleSegwitInput
- , tx_outputs = replicate numOutputs sampleOutput
-- Each field is one head element followed by (count - 1) copies of it.
-- 'replicate' returns [] for a non-positive count, so a count below 1
-- does not fail: it silently produces exactly one element.
+ , tx_inputs = sampleSegwitInput :| replicate (numInputs - 1) sampleSegwitInput
+ , tx_outputs = sampleOutput :| replicate (numOutputs - 1) sampleOutput
-- The witness count is taken from numInputs directly. It matches the number
-- of inputs only when numInputs >= 1; for smaller values the witness list is
-- empty while tx_inputs still holds one element.
, tx_witnesses = replicate numInputs sampleWitness
, tx_locktime = 0
}
diff --git a/lib/Bitcoin/Prim/Tx.hs b/lib/Bitcoin/Prim/Tx.hs
@@ -21,6 +21,7 @@ module Bitcoin.Prim.Tx (
, OutPoint(..)
, Witness(..)
, TxId(..)
+ , mkTxId
-- * Serialisation
, to_bytes
@@ -47,6 +48,8 @@ import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as B16
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Lazy as BL
+import Data.List.NonEmpty (NonEmpty(..))
+import qualified Data.List.NonEmpty as NE
import Data.Word (Word32, Word64)
import GHC.Generics (Generic)
@@ -54,6 +57,19 @@ import GHC.Generics (Generic)
newtype TxId = TxId BS.ByteString
deriving (Eq, Show, Generic)
+-- | Construct a TxId from a 32-byte ByteString.
+--
+-- Returns 'Nothing' if the input is not exactly 32 bytes.
+--
+-- @
+-- mkTxId (BS.replicate 32 0x00) == Just (TxId ...)
+-- mkTxId (BS.replicate 31 0x00) == Nothing
+-- @
+mkTxId :: BS.ByteString -> Maybe TxId
+mkTxId bs
-- Length is the only property checked; the bytes themselves are not inspected.
-- NOTE(review): the module export list still includes TxId(..), so callers can
-- bypass this check by applying the TxId constructor directly.
+ | BS.length bs == 32 = Just (TxId bs)
+ | otherwise = Nothing
+
-- | Transaction outpoint (txid + output index).
data OutPoint = OutPoint
{ op_txid :: {-# UNPACK #-} !TxId
@@ -78,10 +94,13 @@ newtype Witness = Witness [BS.ByteString]
deriving (Eq, Show, Generic)
-- | Complete transaction.
+--
+-- Bitcoin requires at least one input and one output, enforced here
+-- via 'NonEmpty' lists.
data Tx = Tx
{ tx_version :: {-# UNPACK #-} !Word32
- , tx_inputs :: ![TxIn]
- , tx_outputs :: ![TxOut]
-- A value of this type always carries at least one input and one output;
-- code that needs a plain (possibly empty) list must convert explicitly.
+ , tx_inputs :: !(NonEmpty TxIn)
+ , tx_outputs :: !(NonEmpty TxOut)
-- The witness list remains an ordinary list: its length is not tied to the
-- number of inputs by the type.
, tx_witnesses :: ![Witness] -- ^ empty list for legacy tx
, tx_locktime :: {-# UNPACK #-} !Word32
} deriving (Eq, Show, Generic)
@@ -103,9 +122,9 @@ to_bytes tx@Tx {..}
put_word32_le tx_version
<> BSB.word8 0x00 -- marker
<> BSB.word8 0x01 -- flag
- <> put_compact (fromIntegral (length tx_inputs))
+ <> put_compact (fromIntegral (NE.length tx_inputs))
<> foldMap put_txin tx_inputs
- <> put_compact (fromIntegral (length tx_outputs))
+ <> put_compact (fromIntegral (NE.length tx_outputs))
<> foldMap put_txout tx_outputs
<> foldMap put_witness tx_witnesses
<> put_word32_le tx_locktime
@@ -124,9 +143,9 @@ to_bytes tx@Tx {..}
to_bytes_legacy :: Tx -> BS.ByteString
to_bytes_legacy Tx {..} = to_strict $
put_word32_le tx_version
- <> put_compact (fromIntegral (length tx_inputs))
+ <> put_compact (fromIntegral (NE.length tx_inputs))
<> foldMap put_txin tx_inputs
- <> put_compact (fromIntegral (length tx_outputs))
+ <> put_compact (fromIntegral (NE.length tx_outputs))
<> foldMap put_txout tx_outputs
<> put_word32_le tx_locktime
@@ -248,12 +267,14 @@ parse_legacy :: BS.ByteString -> Word32 -> Int -> Maybe Tx
parse_legacy !bs !version !off0 = do
-- input count
(input_count, off1) <- get_compact bs off0
- -- inputs
- (inputs, off2) <- get_many get_txin bs off1 (fromIntegral input_count)
+ -- inputs (must have at least one)
+ (inputs_list, off2) <- get_many get_txin bs off1 (fromIntegral input_count)
+ inputs <- NE.nonEmpty inputs_list
-- output count
(output_count, off3) <- get_compact bs off2
- -- outputs
- (outputs, off4) <- get_many get_txout bs off3 (fromIntegral output_count)
+ -- outputs (must have at least one)
+ (outputs_list, off4) <- get_many get_txout bs off3 (fromIntegral output_count)
+ outputs <- NE.nonEmpty outputs_list
-- locktime (4 bytes)
guard (BS.length bs >= off4 + 4)
let !locktime = get_word32_le bs off4
@@ -267,12 +288,14 @@ parse_segwit :: BS.ByteString -> Word32 -> Int -> Maybe Tx
parse_segwit !bs !version !off0 = do
-- input count
(input_count, off1) <- get_compact bs off0
- -- inputs
- (inputs, off2) <- get_many get_txin bs off1 (fromIntegral input_count)
+ -- inputs (must have at least one)
+ (inputs_list, off2) <- get_many get_txin bs off1 (fromIntegral input_count)
+ inputs <- NE.nonEmpty inputs_list
-- output count
(output_count, off3) <- get_compact bs off2
- -- outputs
- (outputs, off4) <- get_many get_txout bs off3 (fromIntegral output_count)
+ -- outputs (must have at least one)
+ (outputs_list, off4) <- get_many get_txout bs off3 (fromIntegral output_count)
+ outputs <- NE.nonEmpty outputs_list
-- witnesses (one per input)
(witnesses, off5) <- get_many get_witness bs off4 (fromIntegral input_count)
-- locktime (4 bytes)
diff --git a/lib/Bitcoin/Prim/Tx/Sighash.hs b/lib/Bitcoin/Prim/Tx/Sighash.hs
@@ -36,6 +36,7 @@ import Bitcoin.Prim.Tx
import qualified Crypto.Hash.SHA256 as SHA256
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
+import qualified Data.List.NonEmpty as NE
import Data.Word (Word8, Word64)
import GHC.Generics (Generic)
@@ -116,25 +117,27 @@ sighash_legacy
-> BS.ByteString -- ^ 32-byte hash
sighash_legacy !tx !idx !script_pubkey !sighash_type
-- SIGHASH_SINGLE edge case: index >= number of outputs
- | base == SIGHASH_SINGLE && idx >= length (tx_outputs tx) =
+ | base == SIGHASH_SINGLE && idx >= NE.length (tx_outputs tx) =
sighash_single_bug
| otherwise =
- let !modified = modify_tx_legacy tx idx script_pubkey sighash_type
- !serialized = serialize_legacy_for_sighash modified sighash_type
+ let !serialized = serialize_legacy_sighash tx idx script_pubkey sighash_type
in hash256 serialized
where
!base = base_type sighash_type
--- | Modify transaction for legacy sighash computation.
-modify_tx_legacy
+-- | Serialize transaction for legacy sighash computation.
+-- Handles all sighash flags directly without constructing intermediate Tx.
+serialize_legacy_sighash
:: Tx
-> Int
-> BS.ByteString
-> SighashType
- -> Tx
-modify_tx_legacy Tx{..} !idx !script_pubkey !sighash_type =
+ -> BS.ByteString
+serialize_legacy_sighash Tx{..} !idx !script_pubkey !sighash_type =
let !base = base_type sighash_type
!anyonecanpay = is_anyonecanpay sighash_type
+ !inputs_list = NE.toList tx_inputs
+ !outputs_list = NE.toList tx_outputs
-- Clear all scriptSigs, set signing input's script to scriptPubKey
clear_scripts :: Int -> [TxIn] -> [TxIn]
@@ -154,7 +157,7 @@ modify_tx_legacy Tx{..} !idx !script_pubkey !sighash_type =
inp { txin_sequence = 0 } : zero_other_sequences (i + 1) rest
-- Process inputs based on sighash type
- !inputs_cleared = clear_scripts 0 tx_inputs
+ !inputs_cleared = clear_scripts 0 inputs_list
!inputs_processed = case base of
SIGHASH_NONE -> zero_other_sequences 0 inputs_cleared
@@ -171,10 +174,17 @@ modify_tx_legacy Tx{..} !idx !script_pubkey !sighash_type =
-- Process outputs based on sighash type
!final_outputs = case base of
SIGHASH_NONE -> []
- SIGHASH_SINGLE -> build_single_outputs tx_outputs idx
- _ -> tx_outputs
+ SIGHASH_SINGLE -> build_single_outputs outputs_list idx
+ _ -> outputs_list
- in Tx tx_version final_inputs final_outputs [] tx_locktime
+ in to_strict $
+ put_word32_le tx_version
+ <> put_compact (fromIntegral (length final_inputs))
+ <> foldMap put_txin_legacy final_inputs
+ <> put_compact (fromIntegral (length final_outputs))
+ <> foldMap put_txout final_outputs
+ <> put_word32_le tx_locktime
+ <> put_word32_le (fromIntegral (sighash_byte sighash_type))
-- | Build outputs for SIGHASH_SINGLE: keep only output at idx,
-- replace earlier outputs with empty/zero outputs.
@@ -201,17 +211,6 @@ safe_index (x : xs) !n
| otherwise = safe_index xs (n - 1)
{-# INLINE safe_index #-}
--- | Serialize modified transaction for legacy sighash, appending sighash type.
-serialize_legacy_for_sighash :: Tx -> SighashType -> BS.ByteString
-serialize_legacy_for_sighash Tx{..} !sighash_type = to_strict $
- put_word32_le tx_version
- <> put_compact (fromIntegral (length tx_inputs))
- <> foldMap put_txin_legacy tx_inputs
- <> put_compact (fromIntegral (length tx_outputs))
- <> foldMap put_txout tx_outputs
- <> put_word32_le tx_locktime
- <> put_word32_le (fromIntegral (sighash_byte sighash_type))
-
-- | Encode TxIn for legacy sighash (same as normal encoding).
put_txin_legacy :: TxIn -> BSB.Builder
put_txin_legacy TxIn{..} =
@@ -229,11 +228,13 @@ put_txin_legacy TxIn{..} =
-- sighash, this commits to the value being spent, preventing fee
-- manipulation attacks.
--
+-- Returns 'Nothing' if the input index is out of range.
+--
-- @
-- -- sign P2WPKH input 0
-- let scriptCode = ... -- P2WPKH scriptCode
-- let hash = sighash_segwit tx 0 scriptCode inputValue SIGHASH_ALL
--- -- use hash with ECDSA signing
+-- -- use hash with ECDSA signing (after checking Just)
-- @
sighash_segwit
:: Tx
@@ -241,20 +242,26 @@ sighash_segwit
-> BS.ByteString -- ^ scriptCode
-> Word64 -- ^ value being spent (satoshis)
-> SighashType
- -> BS.ByteString -- ^ 32-byte hash
-sighash_segwit !tx !idx !script_code !value !sighash_type =
- let !preimage = build_bip143_preimage tx idx script_code value sighash_type
- in hash256 preimage
+ -> Maybe BS.ByteString -- ^ 32-byte hash, or Nothing if index invalid
+sighash_segwit !tx !idx !script_code !value !sighash_type = do
+ preimage <- build_bip143_preimage tx idx script_code value sighash_type
+ pure $! hash256 preimage
-- | Build BIP143 preimage for signing.
+-- Returns Nothing if the input index is out of range.
build_bip143_preimage
:: Tx
-> Int
-> BS.ByteString
-> Word64
-> SighashType
- -> BS.ByteString
-build_bip143_preimage Tx{..} !idx !script_code !value !sighash_type =
+ -> Maybe BS.ByteString
+build_bip143_preimage Tx{..} !idx !script_code !value !sighash_type = do
+ -- Get the input being signed; fail if index out of range
+ let !inputs_list = NE.toList tx_inputs
+ !outputs_list = NE.toList tx_outputs
+ signing_input <- safe_index inputs_list idx
+
let !base = base_type sighash_type
!anyonecanpay = is_anyonecanpay sighash_type
@@ -277,28 +284,23 @@ build_bip143_preimage Tx{..} !idx !script_code !value !sighash_type =
!hash_outputs = case base of
SIGHASH_NONE -> zero32
SIGHASH_SINGLE ->
- case safe_index tx_outputs idx of
+ case safe_index outputs_list idx of
Nothing -> zero32 -- index out of range
Just out -> hash256 $ to_strict $ put_txout out
_ -> hash256 $ to_strict $ foldMap put_txout tx_outputs
- -- Get the input being signed
- !signing_input = case safe_index tx_inputs idx of
- Just inp -> inp
- Nothing -> error "sighash_segwit: invalid input index"
-
!outpoint = txin_prevout signing_input
!sequence_n = txin_sequence signing_input
- in to_strict $
- put_word32_le tx_version
- <> BSB.byteString hash_prevouts
- <> BSB.byteString hash_sequence
- <> put_outpoint outpoint
- <> put_compact (fromIntegral (BS.length script_code))
- <> BSB.byteString script_code
- <> put_word64_le value
- <> put_word32_le sequence_n
- <> BSB.byteString hash_outputs
- <> put_word32_le tx_locktime
- <> put_word32_le (fromIntegral (sighash_byte sighash_type))
+ pure $! to_strict $
+ put_word32_le tx_version
+ <> BSB.byteString hash_prevouts
+ <> BSB.byteString hash_sequence
+ <> put_outpoint outpoint
+ <> put_compact (fromIntegral (BS.length script_code))
+ <> BSB.byteString script_code
+ <> put_word64_le value
+ <> put_word32_le sequence_n
+ <> BSB.byteString hash_outputs
+ <> put_word32_le tx_locktime
+ <> put_word32_le (fromIntegral (sighash_byte sighash_type))
diff --git a/test/Main.hs b/test/Main.hs
@@ -5,6 +5,7 @@ module Main where
import Bitcoin.Prim.Tx
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as B16
+import Data.List.NonEmpty (NonEmpty(..))
import Test.Tasty
import qualified Test.Tasty.HUnit as H
@@ -73,8 +74,8 @@ roundtrip_legacy_simple = H.testCase "simple legacy tx" $
where
legacyTx = Tx
{ tx_version = 1
- , tx_inputs = [txin]
- , tx_outputs = [txout]
+ , tx_inputs = txin :| []
+ , tx_outputs = txout :| []
, tx_witnesses = []
, tx_locktime = 0
}
@@ -98,8 +99,8 @@ roundtrip_segwit = H.testCase "segwit tx with witnesses" $
where
segwitTx = Tx
{ tx_version = 2
- , tx_inputs = [txin]
- , tx_outputs = [txout]
+ , tx_inputs = txin :| []
+ , tx_outputs = txout :| []
, tx_witnesses = [witness]
, tx_locktime = 500000
}
@@ -127,8 +128,8 @@ roundtrip_multi_io = H.testCase "multiple inputs/outputs" $
where
multiTx = Tx
{ tx_version = 1
- , tx_inputs = [txin1, txin2, txin3]
- , tx_outputs = [txout1, txout2]
+ , tx_inputs = txin1 :| [txin2, txin3]
+ , tx_outputs = txout1 :| [txout2]
, tx_witnesses = []
, tx_locktime = 123456
}
@@ -220,8 +221,8 @@ edge_empty_scriptsig = H.testCase "empty scriptSig" $
where
tx = Tx
{ tx_version = 2
- , tx_inputs = [txin]
- , tx_outputs = [txout]
+ , tx_inputs = txin :| []
+ , tx_outputs = txout :| []
, tx_witnesses = [witness]
, tx_locktime = 0
}
@@ -246,8 +247,8 @@ edge_max_sequence = H.testCase "maximum sequence (0xffffffff)" $
where
tx = Tx
{ tx_version = 1
- , tx_inputs = [txin]
- , tx_outputs = [txout]
+ , tx_inputs = txin :| []
+ , tx_outputs = txout :| []
, tx_witnesses = []
, tx_locktime = 0
}
@@ -271,8 +272,8 @@ edge_zero_locktime = H.testCase "zero locktime" $
where
tx = Tx
{ tx_version = 1
- , tx_inputs = [txin]
- , tx_outputs = [txout]
+ , tx_inputs = txin :| []
+ , tx_outputs = txout :| []
, tx_witnesses = []
, tx_locktime = 0
}
@@ -296,8 +297,8 @@ edge_multi_witness = H.testCase "multiple witness items" $
where
tx = Tx
{ tx_version = 2
- , tx_inputs = [txin1, txin2]
- , tx_outputs = [txout]
+ , tx_inputs = txin1 :| [txin2]
+ , tx_outputs = txout :| []
, tx_witnesses = [witness1, witness2]
, tx_locktime = 0
}