commit 4e310c774c1683b2f284a726ee0003b150efd96f
parent 367ebbd8a17c670bc0b249ec4520ebb3ac869c00
Author: Jared Tobin <jared@jtobin.io>
Date: Mon, 20 Apr 2026 15:13:29 +0800
lib: add HtlcId, SerialId newtypes and Initiator sum type
Introduce domain-specific newtypes to prevent mixing bare Word64
values across different identifier namespaces:
- HtlcId: wraps Word64 for HTLC identifiers (update_add_htlc,
update_fulfill_htlc, update_fail_htlc,
update_fail_malformed_htlc)
- SerialId: wraps Word64 for interactive-tx serial identifiers
(tx_add_input, tx_add_output, tx_remove_input,
tx_remove_output)
Replace the Word8 initiator flag in Stfu with an Initiator sum
type (IsInitiator | NotInitiator) to make illegal values
unrepresentable and clarify semantics.
Add DecodeInvalidInitiator error variant for values other than
0 or 1. Update codec, tests, and benchmarks throughout.
Diffstat:
6 files changed, 148 insertions(+), 79 deletions(-)
diff --git a/bench/Main.hs b/bench/Main.hs
@@ -186,7 +186,7 @@ encodedClosingSigned = encodeClosingSigned testClosingSigned
testUpdateAddHtlc :: UpdateAddHtlc
testUpdateAddHtlc = UpdateAddHtlc
{ updateAddHtlcChannelId = testChannelId
- , updateAddHtlcId = 0
+ , updateAddHtlcId = htlcId 0
, updateAddHtlcAmountMsat = MilliSatoshi 10000000
, updateAddHtlcPaymentHash = testPaymentHash
, updateAddHtlcCltvExpiry = 800000
diff --git a/bench/Weight.hs b/bench/Weight.hs
@@ -156,7 +156,7 @@ mkUpdateAddHtlc
:: ChannelId -> PaymentHash -> OnionPacket -> TlvStream -> UpdateAddHtlc
mkUpdateAddHtlc !cid !ph !onion !tlvs = UpdateAddHtlc
{ updateAddHtlcChannelId = cid
- , updateAddHtlcId = 0
+ , updateAddHtlcId = htlcId 0
, updateAddHtlcAmountMsat = MilliSatoshi 10000000
, updateAddHtlcPaymentHash = ph
, updateAddHtlcCltvExpiry = 800000
diff --git a/lib/Lightning/Protocol/BOLT2/Codec.hs b/lib/Lightning/Protocol/BOLT2/Codec.hs
@@ -128,6 +128,7 @@ data DecodeError
| DecodeInvalidPaymentPreimage
| DecodeInvalidOnionPacket
| DecodeInvalidSecret
+ | DecodeInvalidInitiator
| DecodeTlvError !TlvError
deriving stock (Eq, Show, Generic)
@@ -580,18 +581,24 @@ decodeChannelReady !bs = do
encodeStfu :: Stfu -> BS.ByteString
encodeStfu !msg = mconcat
[ unChannelId (stfuChannelId msg)
- , BS.singleton (stfuInitiator msg)
+ , BS.singleton $! case stfuInitiator msg of
+ IsInitiator -> 1
+ NotInitiator -> 0
]
-- | Decode a Stfu message (type 2).
decodeStfu :: BS.ByteString -> Either DecodeError (Stfu, BS.ByteString)
decodeStfu !bs = do
(chanId, rest1) <- decodeChannelIdBytes bs
- (initiator, rest2) <- maybe (Left DecodeInsufficientBytes) Right
- (decodeU8 rest1)
+ (raw, rest2) <- maybe (Left DecodeInsufficientBytes) Right
+ (decodeU8 rest1)
+ ini <- case raw of
+ 1 -> Right IsInitiator
+ 0 -> Right NotInitiator
+ _ -> Left DecodeInvalidInitiator
let !msg = Stfu
{ stfuChannelId = chanId
- , stfuInitiator = initiator
+ , stfuInitiator = ini
}
Right (msg, rest2)
@@ -898,7 +905,7 @@ encodeTxAddInput !msg = do
prevTxEnc <- encodeU16BytesE (txAddInputPrevTx msg)
Right $! mconcat
[ unChannelId (txAddInputChannelId msg)
- , encodeU64 (txAddInputSerialId msg)
+ , encodeU64 (unSerialId (txAddInputSerialId msg))
, prevTxEnc
, encodeU32 (txAddInputPrevVout msg)
, encodeU32 (txAddInputSequence msg)
@@ -909,14 +916,14 @@ decodeTxAddInput
:: BS.ByteString -> Either DecodeError (TxAddInput, BS.ByteString)
decodeTxAddInput !bs = do
(cid, rest1) <- decodeChannelIdBytes bs
- (serialId, rest2) <- maybe (Left DecodeInsufficientBytes) Right
- (decodeU64 rest1)
+ (sid, rest2) <- maybe (Left DecodeInsufficientBytes) Right
+ (decodeU64 rest1)
(prevTx, rest3) <- decodeU16Bytes rest2
(prevVout, rest4) <- decodeU32E rest3
(seqNum, rest5) <- decodeU32E rest4
let !msg = TxAddInput
{ txAddInputChannelId = cid
- , txAddInputSerialId = serialId
+ , txAddInputSerialId = serialId sid
, txAddInputPrevTx = prevTx
, txAddInputPrevVout = prevVout
, txAddInputSequence = seqNum
@@ -929,7 +936,7 @@ encodeTxAddOutput !msg = do
scriptEnc <- encodeU16BytesE (unScriptPubKey (txAddOutputScript msg))
Right $! mconcat
[ unChannelId (txAddOutputChannelId msg)
- , encodeU64 (txAddOutputSerialId msg)
+ , encodeU64 (unSerialId (txAddOutputSerialId msg))
, encodeU64 (unSatoshi (txAddOutputSats msg))
, scriptEnc
]
@@ -939,13 +946,13 @@ decodeTxAddOutput
:: BS.ByteString -> Either DecodeError (TxAddOutput, BS.ByteString)
decodeTxAddOutput !bs = do
(cid, rest1) <- decodeChannelIdBytes bs
- (serialId, rest2) <- maybe (Left DecodeInsufficientBytes) Right
- (decodeU64 rest1)
+ (sid, rest2) <- maybe (Left DecodeInsufficientBytes) Right
+ (decodeU64 rest1)
(sats, rest3) <- decodeSatoshi rest2
(scriptBs, rest4) <- decodeU16Bytes rest3
let !msg = TxAddOutput
{ txAddOutputChannelId = cid
- , txAddOutputSerialId = serialId
+ , txAddOutputSerialId = serialId sid
, txAddOutputSats = sats
, txAddOutputScript = scriptPubKey scriptBs
}
@@ -955,7 +962,7 @@ decodeTxAddOutput !bs = do
encodeTxRemoveInput :: TxRemoveInput -> BS.ByteString
encodeTxRemoveInput !msg = mconcat
[ unChannelId (txRemoveInputChannelId msg)
- , encodeU64 (txRemoveInputSerialId msg)
+ , encodeU64 (unSerialId (txRemoveInputSerialId msg))
]
-- | Decode a TxRemoveInput message (type 68).
@@ -963,11 +970,11 @@ decodeTxRemoveInput
:: BS.ByteString -> Either DecodeError (TxRemoveInput, BS.ByteString)
decodeTxRemoveInput !bs = do
(cid, rest1) <- decodeChannelIdBytes bs
- (serialId, rest2) <- maybe (Left DecodeInsufficientBytes) Right
- (decodeU64 rest1)
+ (sid, rest2) <- maybe (Left DecodeInsufficientBytes) Right
+ (decodeU64 rest1)
let !msg = TxRemoveInput
{ txRemoveInputChannelId = cid
- , txRemoveInputSerialId = serialId
+ , txRemoveInputSerialId = serialId sid
}
Right (msg, rest2)
@@ -975,7 +982,7 @@ decodeTxRemoveInput !bs = do
encodeTxRemoveOutput :: TxRemoveOutput -> BS.ByteString
encodeTxRemoveOutput !msg = mconcat
[ unChannelId (txRemoveOutputChannelId msg)
- , encodeU64 (txRemoveOutputSerialId msg)
+ , encodeU64 (unSerialId (txRemoveOutputSerialId msg))
]
-- | Decode a TxRemoveOutput message (type 69).
@@ -983,11 +990,11 @@ decodeTxRemoveOutput
:: BS.ByteString -> Either DecodeError (TxRemoveOutput, BS.ByteString)
decodeTxRemoveOutput !bs = do
(cid, rest1) <- decodeChannelIdBytes bs
- (serialId, rest2) <- maybe (Left DecodeInsufficientBytes) Right
- (decodeU64 rest1)
+ (sid, rest2) <- maybe (Left DecodeInsufficientBytes) Right
+ (decodeU64 rest1)
let !msg = TxRemoveOutput
{ txRemoveOutputChannelId = cid
- , txRemoveOutputSerialId = serialId
+ , txRemoveOutputSerialId = serialId sid
}
Right (msg, rest2)
@@ -1119,7 +1126,7 @@ decodeTxAbort !bs = do
encodeUpdateAddHtlc :: UpdateAddHtlc -> BS.ByteString
encodeUpdateAddHtlc !m = mconcat
[ unChannelId (updateAddHtlcChannelId m)
- , encodeU64 (updateAddHtlcId m)
+ , encodeU64 (unHtlcId (updateAddHtlcId m))
, encodeU64 (unMilliSatoshi (updateAddHtlcAmountMsat m))
, unPaymentHash (updateAddHtlcPaymentHash m)
, encodeU32 (updateAddHtlcCltvExpiry m)
@@ -1133,8 +1140,8 @@ decodeUpdateAddHtlc
:: BS.ByteString -> Either DecodeError (UpdateAddHtlc, BS.ByteString)
decodeUpdateAddHtlc !bs = do
(cid, rest1) <- decodeChannelIdBytes bs
- (htlcId, rest2) <- maybe (Left DecodeInsufficientBytes) Right
- (decodeU64 rest1)
+ (hid, rest2) <- maybe (Left DecodeInsufficientBytes) Right
+ (decodeU64 rest1)
(amtMsat, rest3) <- maybe (Left DecodeInsufficientBytes) Right
(decodeU64 rest2)
(pHash, rest4) <- decodePaymentHashBytes rest3
@@ -1143,7 +1150,7 @@ decodeUpdateAddHtlc !bs = do
(tlvs, rest7) <- decodeOptionalTlvs rest6
let !msg = UpdateAddHtlc
{ updateAddHtlcChannelId = cid
- , updateAddHtlcId = htlcId
+ , updateAddHtlcId = htlcId hid
, updateAddHtlcAmountMsat = MilliSatoshi amtMsat
, updateAddHtlcPaymentHash = pHash
, updateAddHtlcCltvExpiry = cltvExp
@@ -1157,7 +1164,7 @@ decodeUpdateAddHtlc !bs = do
encodeUpdateFulfillHtlc :: UpdateFulfillHtlc -> BS.ByteString
encodeUpdateFulfillHtlc !m = mconcat
[ unChannelId (updateFulfillHtlcChannelId m)
- , encodeU64 (updateFulfillHtlcId m)
+ , encodeU64 (unHtlcId (updateFulfillHtlcId m))
, unPaymentPreimage (updateFulfillHtlcPaymentPreimage m)
, encodeTlvStream (updateFulfillHtlcTlvs m)
]
@@ -1167,13 +1174,13 @@ decodeUpdateFulfillHtlc
:: BS.ByteString -> Either DecodeError (UpdateFulfillHtlc, BS.ByteString)
decodeUpdateFulfillHtlc !bs = do
(cid, rest1) <- decodeChannelIdBytes bs
- (htlcId, rest2) <- maybe (Left DecodeInsufficientBytes) Right
- (decodeU64 rest1)
+ (hid, rest2) <- maybe (Left DecodeInsufficientBytes) Right
+ (decodeU64 rest1)
(preimage, rest3) <- decodePaymentPreimageBytes rest2
(tlvs, rest4) <- decodeOptionalTlvs rest3
let !msg = UpdateFulfillHtlc
{ updateFulfillHtlcChannelId = cid
- , updateFulfillHtlcId = htlcId
+ , updateFulfillHtlcId = htlcId hid
, updateFulfillHtlcPaymentPreimage = preimage
, updateFulfillHtlcTlvs = tlvs
}
@@ -1185,7 +1192,7 @@ encodeUpdateFailHtlc !m = do
reasonEnc <- encodeU16BytesE (updateFailHtlcReason m)
Right $! mconcat
[ unChannelId (updateFailHtlcChannelId m)
- , encodeU64 (updateFailHtlcId m)
+ , encodeU64 (unHtlcId (updateFailHtlcId m))
, reasonEnc
, encodeTlvStream (updateFailHtlcTlvs m)
]
@@ -1195,13 +1202,13 @@ decodeUpdateFailHtlc
:: BS.ByteString -> Either DecodeError (UpdateFailHtlc, BS.ByteString)
decodeUpdateFailHtlc !bs = do
(cid, rest1) <- decodeChannelIdBytes bs
- (htlcId, rest2) <- maybe (Left DecodeInsufficientBytes) Right
- (decodeU64 rest1)
+ (hid, rest2) <- maybe (Left DecodeInsufficientBytes) Right
+ (decodeU64 rest1)
(reason, rest3) <- decodeU16Bytes rest2
(tlvs, rest4) <- decodeOptionalTlvs rest3
let !msg = UpdateFailHtlc
{ updateFailHtlcChannelId = cid
- , updateFailHtlcId = htlcId
+ , updateFailHtlcId = htlcId hid
, updateFailHtlcReason = reason
, updateFailHtlcTlvs = tlvs
}
@@ -1211,7 +1218,7 @@ decodeUpdateFailHtlc !bs = do
encodeUpdateFailMalformedHtlc :: UpdateFailMalformedHtlc -> BS.ByteString
encodeUpdateFailMalformedHtlc !m = mconcat
[ unChannelId (updateFailMalformedHtlcChannelId m)
- , encodeU64 (updateFailMalformedHtlcId m)
+ , encodeU64 (unHtlcId (updateFailMalformedHtlcId m))
, unPaymentHash (updateFailMalformedHtlcSha256Onion m)
, encodeU16 (updateFailMalformedHtlcFailureCode m)
]
@@ -1221,13 +1228,13 @@ decodeUpdateFailMalformedHtlc
:: BS.ByteString -> Either DecodeError (UpdateFailMalformedHtlc, BS.ByteString)
decodeUpdateFailMalformedHtlc !bs = do
(cid, rest1) <- decodeChannelIdBytes bs
- (htlcId, rest2) <- maybe (Left DecodeInsufficientBytes) Right
- (decodeU64 rest1)
+ (hid, rest2) <- maybe (Left DecodeInsufficientBytes) Right
+ (decodeU64 rest1)
(sha256Onion, rest3) <- decodePaymentHashBytes rest2
(failCode, rest4) <- decodeU16E rest3
let !msg = UpdateFailMalformedHtlc
{ updateFailMalformedHtlcChannelId = cid
- , updateFailMalformedHtlcId = htlcId
+ , updateFailMalformedHtlcId = htlcId hid
, updateFailMalformedHtlcSha256Onion = sha256Onion
, updateFailMalformedHtlcFailureCode = failCode
}
diff --git a/lib/Lightning/Protocol/BOLT2/Messages.hs b/lib/Lightning/Protocol/BOLT2/Messages.hs
@@ -292,7 +292,7 @@ instance NFData AcceptChannel2
-- Adds a transaction input to the collaborative transaction.
data TxAddInput = TxAddInput
{ txAddInputChannelId :: !ChannelId
- , txAddInputSerialId :: {-# UNPACK #-} !Word64
+ , txAddInputSerialId :: !SerialId
, txAddInputPrevTx :: !BS.ByteString
, txAddInputPrevVout :: {-# UNPACK #-} !Word32
, txAddInputSequence :: {-# UNPACK #-} !Word32
@@ -305,7 +305,7 @@ instance NFData TxAddInput
-- Adds a transaction output to the collaborative transaction.
data TxAddOutput = TxAddOutput
{ txAddOutputChannelId :: !ChannelId
- , txAddOutputSerialId :: {-# UNPACK #-} !Word64
+ , txAddOutputSerialId :: !SerialId
, txAddOutputSats :: {-# UNPACK #-} !Satoshi
, txAddOutputScript :: !ScriptPubKey
} deriving stock (Eq, Show, Generic)
@@ -317,7 +317,7 @@ instance NFData TxAddOutput
-- Removes a previously added input from the collaborative transaction.
data TxRemoveInput = TxRemoveInput
{ txRemoveInputChannelId :: !ChannelId
- , txRemoveInputSerialId :: {-# UNPACK #-} !Word64
+ , txRemoveInputSerialId :: !SerialId
} deriving stock (Eq, Show, Generic)
instance NFData TxRemoveInput
@@ -327,7 +327,7 @@ instance NFData TxRemoveInput
-- Removes a previously added output from the collaborative transaction.
data TxRemoveOutput = TxRemoveOutput
{ txRemoveOutputChannelId :: !ChannelId
- , txRemoveOutputSerialId :: {-# UNPACK #-} !Word64
+ , txRemoveOutputSerialId :: !SerialId
} deriving stock (Eq, Show, Generic)
instance NFData TxRemoveOutput
@@ -399,7 +399,7 @@ instance NFData TxAbort
-- quiescence.
data Stfu = Stfu
{ stfuChannelId :: !ChannelId
- , stfuInitiator :: {-# UNPACK #-} !Word8
+ , stfuInitiator :: !Initiator
} deriving stock (Eq, Show, Generic)
instance NFData Stfu
@@ -462,7 +462,7 @@ instance NFData ClosingSig
-- preimage.
data UpdateAddHtlc = UpdateAddHtlc
{ updateAddHtlcChannelId :: !ChannelId
- , updateAddHtlcId :: {-# UNPACK #-} !Word64
+ , updateAddHtlcId :: !HtlcId
, updateAddHtlcAmountMsat :: {-# UNPACK #-} !MilliSatoshi
, updateAddHtlcPaymentHash :: !PaymentHash
, updateAddHtlcCltvExpiry :: {-# UNPACK #-} !Word32
@@ -477,7 +477,7 @@ instance NFData UpdateAddHtlc
-- Supplies the preimage to fulfill an HTLC.
data UpdateFulfillHtlc = UpdateFulfillHtlc
{ updateFulfillHtlcChannelId :: !ChannelId
- , updateFulfillHtlcId :: {-# UNPACK #-} !Word64
+ , updateFulfillHtlcId :: !HtlcId
, updateFulfillHtlcPaymentPreimage :: !PaymentPreimage
, updateFulfillHtlcTlvs :: !TlvStream
} deriving stock (Eq, Show, Generic)
@@ -489,7 +489,7 @@ instance NFData UpdateFulfillHtlc
-- Indicates an HTLC has failed.
data UpdateFailHtlc = UpdateFailHtlc
{ updateFailHtlcChannelId :: !ChannelId
- , updateFailHtlcId :: {-# UNPACK #-} !Word64
+ , updateFailHtlcId :: !HtlcId
, updateFailHtlcReason :: !BS.ByteString
, updateFailHtlcTlvs :: !TlvStream
} deriving stock (Eq, Show, Generic)
@@ -501,7 +501,7 @@ instance NFData UpdateFailHtlc
-- Indicates an HTLC could not be parsed.
data UpdateFailMalformedHtlc = UpdateFailMalformedHtlc
{ updateFailMalformedHtlcChannelId :: !ChannelId
- , updateFailMalformedHtlcId :: {-# UNPACK #-} !Word64
+ , updateFailMalformedHtlcId :: !HtlcId
, updateFailMalformedHtlcSha256Onion :: !PaymentHash
, updateFailMalformedHtlcFailureCode :: {-# UNPACK #-} !Word16
} deriving stock (Eq, Show, Generic)
diff --git a/lib/Lightning/Protocol/BOLT2/Types.hs b/lib/Lightning/Protocol/BOLT2/Types.hs
@@ -62,6 +62,17 @@ module Lightning.Protocol.BOLT2.Types (
, scriptPubKey
, unScriptPubKey
+ -- * Protocol identifiers
+ , HtlcId
+ , htlcId
+ , unHtlcId
+ , SerialId
+ , serialId
+ , unSerialId
+
+ -- * Quiescence
+ , Initiator(..)
+
-- * Protocol types
, FeatureBits
, featureBits
@@ -85,6 +96,7 @@ module Lightning.Protocol.BOLT2.Types (
import Bitcoin.Prim.Tx (TxId(..), mkTxId, OutPoint(..))
import Control.DeepSeq (NFData)
import qualified Data.ByteString as BS
+import Data.Word (Word64)
import GHC.Generics (Generic)
import Lightning.Protocol.BOLT1.Prim
( ChannelId(..), channelId, unChannelId
@@ -164,6 +176,48 @@ unScriptPubKey :: ScriptPubKey -> BS.ByteString
unScriptPubKey (ScriptPubKey bs) = bs
{-# INLINE unScriptPubKey #-}
+-- protocol identifiers -------------------------------------------------------
+
+-- | An HTLC identifier, unique per channel per direction.
+newtype HtlcId = HtlcId Word64
+ deriving stock (Eq, Ord, Show, Generic)
+ deriving newtype NFData
+
+-- | Construct an 'HtlcId' from a 'Word64'.
+htlcId :: Word64 -> HtlcId
+htlcId = HtlcId
+{-# INLINE htlcId #-}
+
+-- | Extract the underlying 'Word64' from an 'HtlcId'.
+unHtlcId :: HtlcId -> Word64
+unHtlcId (HtlcId w) = w
+{-# INLINE unHtlcId #-}
+
+-- | A serial identifier for interactive transaction construction.
+newtype SerialId = SerialId Word64
+ deriving stock (Eq, Ord, Show, Generic)
+ deriving newtype NFData
+
+-- | Construct a 'SerialId' from a 'Word64'.
+serialId :: Word64 -> SerialId
+serialId = SerialId
+{-# INLINE serialId #-}
+
+-- | Extract the underlying 'Word64' from a 'SerialId'.
+unSerialId :: SerialId -> Word64
+unSerialId (SerialId w) = w
+{-# INLINE unSerialId #-}
+
+-- quiescence -----------------------------------------------------------------
+
+-- | Role in quiescence (STFU) protocol.
+data Initiator
+ = IsInitiator -- ^ This node initiated quiescence.
+ | NotInitiator -- ^ This node did not initiate quiescence.
+ deriving stock (Eq, Ord, Show, Generic)
+
+instance NFData Initiator
+
-- protocol types --------------------------------------------------------------
-- | Feature bits (variable length).
diff --git a/test/Main.hs b/test/Main.hs
@@ -237,7 +237,7 @@ v2_establishment_tests = testGroup "V2 Channel Establishment" [
testCase "encode/decode roundtrip" $ do
let msg = TxAddInput
{ txAddInputChannelId = testChannelId
- , txAddInputSerialId = 12345
+ , txAddInputSerialId = serialId 12345
, txAddInputPrevTx = BS.pack [0x01, 0x02, 0x03, 0x04]
, txAddInputPrevVout = 0
, txAddInputSequence = 0xfffffffe
@@ -250,7 +250,7 @@ v2_establishment_tests = testGroup "V2 Channel Establishment" [
, testCase "roundtrip with empty prevTx" $ do
let msg = TxAddInput
{ txAddInputChannelId = testChannelId
- , txAddInputSerialId = 0
+ , txAddInputSerialId = serialId 0
, txAddInputPrevTx = BS.empty
, txAddInputPrevVout = 0
, txAddInputSequence = 0
@@ -265,7 +265,7 @@ v2_establishment_tests = testGroup "V2 Channel Establishment" [
testCase "encode/decode roundtrip" $ do
let msg = TxAddOutput
{ txAddOutputChannelId = testChannelId
- , txAddOutputSerialId = 54321
+ , txAddOutputSerialId = serialId 54321
, txAddOutputSats = Satoshi 100000
, txAddOutputScript = scriptPubKey (BS.pack [0x00, 0x14] <>
BS.replicate 20 0xaa)
@@ -280,7 +280,7 @@ v2_establishment_tests = testGroup "V2 Channel Establishment" [
testCase "encode/decode roundtrip" $ do
let msg = TxRemoveInput
{ txRemoveInputChannelId = testChannelId
- , txRemoveInputSerialId = 12345
+ , txRemoveInputSerialId = serialId 12345
}
encoded = encodeTxRemoveInput msg
case decodeTxRemoveInput encoded of
@@ -291,7 +291,7 @@ v2_establishment_tests = testGroup "V2 Channel Establishment" [
testCase "encode/decode roundtrip" $ do
let msg = TxRemoveOutput
{ txRemoveOutputChannelId = testChannelId
- , txRemoveOutputSerialId = 54321
+ , txRemoveOutputSerialId = serialId 54321
}
encoded = encodeTxRemoveOutput msg
case decodeTxRemoveOutput encoded of
@@ -385,19 +385,19 @@ v2_establishment_tests = testGroup "V2 Channel Establishment" [
close_tests :: TestTree
close_tests = testGroup "Channel Close" [
testGroup "Stfu" [
- testCase "encode/decode initiator=1" $ do
+ testCase "encode/decode IsInitiator" $ do
let msg = Stfu
{ stfuChannelId = testChannelId
- , stfuInitiator = 1
+ , stfuInitiator = IsInitiator
}
encoded = encodeStfu msg
case decodeStfu encoded of
Right (decoded, _) -> decoded @?= msg
Left e -> assertFailure $ "decode failed: " ++ show e
- , testCase "encode/decode initiator=0" $ do
+ , testCase "encode/decode NotInitiator" $ do
let msg = Stfu
{ stfuChannelId = testChannelId
- , stfuInitiator = 0
+ , stfuInitiator = NotInitiator
}
encoded = encodeStfu msg
case decodeStfu encoded of
@@ -493,7 +493,7 @@ normal_operation_tests = testGroup "Normal Operation" [
testCase "encode/decode roundtrip" $ do
let msg = UpdateAddHtlc
{ updateAddHtlcChannelId = testChannelId
- , updateAddHtlcId = 0
+ , updateAddHtlcId = htlcId 0
, updateAddHtlcAmountMsat = MilliSatoshi 10000000
, updateAddHtlcPaymentHash = testPaymentHash
, updateAddHtlcCltvExpiry = 800144
@@ -509,7 +509,7 @@ normal_operation_tests = testGroup "Normal Operation" [
testCase "encode/decode roundtrip" $ do
let msg = UpdateFulfillHtlc
{ updateFulfillHtlcChannelId = testChannelId
- , updateFulfillHtlcId = 42
+ , updateFulfillHtlcId = htlcId 42
, updateFulfillHtlcPaymentPreimage = testPaymentPreimage
, updateFulfillHtlcTlvs = emptyTlvs
}
@@ -522,7 +522,7 @@ normal_operation_tests = testGroup "Normal Operation" [
testCase "encode/decode roundtrip" $ do
let msg = UpdateFailHtlc
{ updateFailHtlcChannelId = testChannelId
- , updateFailHtlcId = 42
+ , updateFailHtlcId = htlcId 42
, updateFailHtlcReason = BS.replicate 32 0xaa
, updateFailHtlcTlvs = emptyTlvs
}
@@ -534,7 +534,7 @@ normal_operation_tests = testGroup "Normal Operation" [
, testCase "roundtrip with empty reason" $ do
let msg = UpdateFailHtlc
{ updateFailHtlcChannelId = testChannelId
- , updateFailHtlcId = 0
+ , updateFailHtlcId = htlcId 0
, updateFailHtlcReason = BS.empty
, updateFailHtlcTlvs = emptyTlvs
}
@@ -548,7 +548,7 @@ normal_operation_tests = testGroup "Normal Operation" [
testCase "encode/decode roundtrip" $ do
let msg = UpdateFailMalformedHtlc
{ updateFailMalformedHtlcChannelId = testChannelId
- , updateFailMalformedHtlcId = 42
+ , updateFailMalformedHtlcId = htlcId 42
, updateFailMalformedHtlcSha256Onion = testPaymentHash
, updateFailMalformedHtlcFailureCode = 0x8002
}
@@ -679,6 +679,13 @@ error_tests = testGroup "Error Conditions" [
case decodeShutdown (BS.replicate 32 0x00) of
Left DecodeInsufficientBytes -> pure ()
other -> assertFailure $ "expected insufficient: " ++ show other
+ , testCase "decodeStfu invalid initiator byte" $ do
+ -- channel_id (32 bytes) + initiator (1 byte, value 2)
+ let encoded = BS.replicate 32 0xab <> BS.singleton 0x02
+ case decodeStfu encoded of
+ Left DecodeInvalidInitiator -> pure ()
+ other -> assertFailure $
+ "expected invalid initiator: " ++ show other
, testCase "decodeUpdateAddHtlc too short" $ do
case decodeUpdateAddHtlc (BS.replicate 100 0x00) of
Left DecodeInsufficientBytes -> pure ()
@@ -970,7 +977,7 @@ propTxAddInputRoundtrip prevTxBytes vout seqNum = property $ do
let prevTx = BS.pack (take 1000 prevTxBytes) -- limit size
msg = TxAddInput
{ txAddInputChannelId = testChannelId
- , txAddInputSerialId = 12345
+ , txAddInputSerialId = serialId 12345
, txAddInputPrevTx = prevTx
, txAddInputPrevVout = vout
, txAddInputSequence = seqNum
@@ -987,7 +994,7 @@ propTxAddOutputRoundtrip sats scriptBytes = property $ do
let script = scriptPubKey (BS.pack (take 100 scriptBytes))
msg = TxAddOutput
{ txAddOutputChannelId = testChannelId
- , txAddOutputSerialId = 54321
+ , txAddOutputSerialId = serialId 54321
, txAddOutputSats = Satoshi sats
, txAddOutputScript = script
}
@@ -999,10 +1006,10 @@ propTxAddOutputRoundtrip sats scriptBytes = property $ do
-- Property: TxRemoveInput roundtrip
propTxRemoveInputRoundtrip :: Word64 -> Property
-propTxRemoveInputRoundtrip serialId = property $ do
+propTxRemoveInputRoundtrip sid = property $ do
let msg = TxRemoveInput
{ txRemoveInputChannelId = testChannelId
- , txRemoveInputSerialId = serialId
+ , txRemoveInputSerialId = serialId sid
}
encoded = encodeTxRemoveInput msg
case decodeTxRemoveInput encoded of
@@ -1011,10 +1018,10 @@ propTxRemoveInputRoundtrip serialId = property $ do
-- Property: TxRemoveOutput roundtrip
propTxRemoveOutputRoundtrip :: Word64 -> Property
-propTxRemoveOutputRoundtrip serialId = property $ do
+propTxRemoveOutputRoundtrip sid = property $ do
let msg = TxRemoveOutput
{ txRemoveOutputChannelId = testChannelId
- , txRemoveOutputSerialId = serialId
+ , txRemoveOutputSerialId = serialId sid
}
encoded = encodeTxRemoveOutput msg
case decodeTxRemoveOutput encoded of
@@ -1086,11 +1093,12 @@ propTxAbortRoundtrip dataBytes = property $ do
Left _ -> False
-- Property: Stfu roundtrip
-propStfuRoundtrip :: Word8 -> Property
-propStfuRoundtrip initiator = property $ do
- let msg = Stfu
+propStfuRoundtrip :: Bool -> Property
+propStfuRoundtrip isInit = property $ do
+ let ini = if isInit then IsInitiator else NotInitiator
+ msg = Stfu
{ stfuChannelId = testChannelId
- , stfuInitiator = initiator
+ , stfuInitiator = ini
}
encoded = encodeStfu msg
case decodeStfu encoded of
@@ -1169,10 +1177,10 @@ propClosingSigRoundtrip feeSats locktime = property $ do
-- Property: UpdateAddHtlc roundtrip
propUpdateAddHtlcRoundtrip :: Word64 -> Word64 -> Word32 -> Property
-propUpdateAddHtlcRoundtrip htlcId amountMsat cltvExpiry = property $ do
+propUpdateAddHtlcRoundtrip hid amountMsat cltvExpiry = property $ do
let msg = UpdateAddHtlc
{ updateAddHtlcChannelId = testChannelId
- , updateAddHtlcId = htlcId
+ , updateAddHtlcId = htlcId hid
, updateAddHtlcAmountMsat = MilliSatoshi amountMsat
, updateAddHtlcPaymentHash = testPaymentHash
, updateAddHtlcCltvExpiry = cltvExpiry
@@ -1186,10 +1194,10 @@ propUpdateAddHtlcRoundtrip htlcId amountMsat cltvExpiry = property $ do
-- Property: UpdateFulfillHtlc roundtrip
propUpdateFulfillHtlcRoundtrip :: Word64 -> Property
-propUpdateFulfillHtlcRoundtrip htlcId = property $ do
+propUpdateFulfillHtlcRoundtrip hid = property $ do
let msg = UpdateFulfillHtlc
{ updateFulfillHtlcChannelId = testChannelId
- , updateFulfillHtlcId = htlcId
+ , updateFulfillHtlcId = htlcId hid
, updateFulfillHtlcPaymentPreimage = testPaymentPreimage
, updateFulfillHtlcTlvs = emptyTlvs
}
@@ -1200,11 +1208,11 @@ propUpdateFulfillHtlcRoundtrip htlcId = property $ do
-- Property: UpdateFailHtlc roundtrip
propUpdateFailHtlcRoundtrip :: Word64 -> [Word8] -> Property
-propUpdateFailHtlcRoundtrip htlcId reasonBytes = property $ do
+propUpdateFailHtlcRoundtrip hid reasonBytes = property $ do
let failReason = BS.pack (take 1000 reasonBytes)
msg = UpdateFailHtlc
{ updateFailHtlcChannelId = testChannelId
- , updateFailHtlcId = htlcId
+ , updateFailHtlcId = htlcId hid
, updateFailHtlcReason = failReason
, updateFailHtlcTlvs = emptyTlvs
}
@@ -1216,10 +1224,10 @@ propUpdateFailHtlcRoundtrip htlcId reasonBytes = property $ do
-- Property: UpdateFailMalformedHtlc roundtrip
propUpdateFailMalformedHtlcRoundtrip :: Word64 -> Word16 -> Property
-propUpdateFailMalformedHtlcRoundtrip htlcId failCode = property $ do
+propUpdateFailMalformedHtlcRoundtrip hid failCode = property $ do
let msg = UpdateFailMalformedHtlc
{ updateFailMalformedHtlcChannelId = testChannelId
- , updateFailMalformedHtlcId = htlcId
+ , updateFailMalformedHtlcId = htlcId hid
, updateFailMalformedHtlcSha256Onion = testPaymentHash
, updateFailMalformedHtlcFailureCode = failCode
}