commit 5e83d290b828178f360e53a1349f46bae09db625
parent 991f84308a8597540ddbceb608b189ddfab42353
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 25 Jan 2026 10:22:55 +0400
test: eliminate non-exhaustive pattern matches
Add expectJust/expectRight helpers that fail tests with clear messages
instead of using partial 'let Just' patterns. Make flip_byte total by
returning IO. All 22 tests pass with no pattern match warnings.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
| A | plans/IMPL5.md | | | 22 | ++++++++++++++++++++++ |
| M | test/Main.hs | | | 625 | ++++++++++++++++++++++++++++++++++++------------------------------------------- |
2 files changed, 305 insertions(+), 342 deletions(-)
diff --git a/plans/IMPL5.md b/plans/IMPL5.md
@@ -0,0 +1,22 @@
+# IMPL5: Eliminate non-exhaustive test pattern matches
+
+## Goal
+Remove non-exhaustive pattern match warnings when running tests.
+
+## Steps
+1) Identify the warning locations by grepping for partial pattern
+ matches in tests (e.g., `let Just`, `Right`, or direct field access).
+2) Replace brittle matches with total helpers:
+ - Use small helper functions like `expectRight`/`expectJust`
+ that fail the test with an assertion message.
+ - Avoid `error` and unchecked indexing.
+3) For repeated handshake setup, add a helper returning either
+ `Assertion` failure or the needed values, keeping each test
+ total and readable.
+4) Ensure all tests handle the full set of constructors explicitly
+ (`Left`/`Right`, `Nothing`/`Just`, `FrameResult` variants).
+5) Re-run `cabal test` to confirm warnings are gone.
+
+## Notes
+- Keep helpers local to `test/Main.hs`.
+- Preserve existing test coverage and intent.
diff --git a/test/Main.hs b/test/Main.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Main where
@@ -9,6 +10,20 @@ import qualified Lightning.Protocol.BOLT8 as BOLT8
import Test.Tasty
import Test.Tasty.HUnit
+-- test helpers ----------------------------------------------------------------
+
+-- | Extract a Just value or fail the test.
+expectJust :: String -> Maybe a -> IO a
+expectJust msg = \case
+ Nothing -> assertFailure msg >> error "unreachable"
+ Just a -> pure a
+
+-- | Extract a Right value or fail the test.
+expectRight :: Show e => String -> Either e a -> IO a
+expectRight msg = \case
+ Left e -> assertFailure (msg ++ ": " ++ show e) >> error "unreachable"
+ Right a -> pure a
+
main :: IO ()
main = defaultMain $ testGroup "ppad-bolt8" [
handshake_tests
@@ -75,61 +90,55 @@ handshake_tests = testGroup "Handshake" [
test_act1 :: Assertion
test_act1 = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (act1_msg, _hs) -> act1_msg @?= expected_act1
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (act1_msg, _) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ act1_msg @?= expected_act1
test_act2 :: Assertion
test_act2 = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
-
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, _) -> do
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, _) -> msg2 @?= expected_act2
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, _) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, _) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ msg2 @?= expected_act2
test_act3 :: Assertion
test_act3 = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
-
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, i_hs) -> do
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, _) -> do
- case BOLT8.act3 i_hs msg2 of
- Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (msg3, _) -> msg3 @?= expected_act3
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, _) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ (msg3, _) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
+ msg3 @?= expected_act3
test_full_handshake :: Assertion
test_full_handshake = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
-
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, i_hs) -> do
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, r_hs) -> do
- case BOLT8.act3 i_hs msg2 of
- Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (msg3, i_result) -> do
- case BOLT8.finalize r_hs msg3 of
- Left err -> assertFailure $ "finalize failed: " ++ show err
- Right r_result -> do
- BOLT8.remote_static i_result @?= r_s_pub
- BOLT8.remote_static r_result @?= i_s_pub
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
+ r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3)
+ BOLT8.remote_static i_result @?= r_s_pub
+ BOLT8.remote_static r_result @?= i_s_pub
-- message encryption tests --------------------------------------------------
@@ -182,19 +191,17 @@ expected_msg_1001 = hex
-- helper to get initiator session after handshake
get_initiator_session :: IO BOLT8.Session
get_initiator_session = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
-
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> fail $ "act1 failed: " ++ show err
- Right (msg1, i_hs) ->
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> fail $ "act2 failed: " ++ show err
- Right (msg2, _) ->
- case BOLT8.act3 i_hs msg2 of
- Left err -> fail $ "act3 failed: " ++ show err
- Right (_, result) -> pure (BOLT8.session result)
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, _) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ (_, result) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
+ pure (BOLT8.session result)
-- encrypt N messages, return Nth ciphertext
encrypt_n :: Int -> BOLT8.Session -> IO BS.ByteString
@@ -246,31 +253,22 @@ test_message_1001 = do
test_decrypt_roundtrip :: Assertion
test_decrypt_roundtrip = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
-
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, i_hs) ->
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, r_hs) ->
- case BOLT8.act3 i_hs msg2 of
- Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (msg3, i_result) ->
- case BOLT8.finalize r_hs msg3 of
- Left err -> assertFailure $ "finalize failed: " ++ show err
- Right r_result -> do
- let i_sess = BOLT8.session i_result
- r_sess = BOLT8.session r_result
- case BOLT8.encrypt i_sess hello of
- Left err -> assertFailure $ "encrypt failed: " ++ show err
- Right (ct, _) ->
- case BOLT8.decrypt r_sess ct of
- Left err ->
- assertFailure $ "decrypt failed: " ++ show err
- Right (pt, _) -> pt @?= hello
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
+ r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3)
+ let i_sess = BOLT8.session i_result
+ r_sess = BOLT8.session r_result
+ (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello)
+ (pt, _) <- expectRight "decrypt" (BOLT8.decrypt r_sess ct)
+ pt @?= hello
-- framing tests -------------------------------------------------------------
@@ -283,112 +281,78 @@ framing_tests = testGroup "Packet Framing" [
test_decrypt_trailing :: Assertion
test_decrypt_trailing = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
-
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, i_hs) ->
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, r_hs) ->
- case BOLT8.act3 i_hs msg2 of
- Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (msg3, i_result) ->
- case BOLT8.finalize r_hs msg3 of
- Left err -> assertFailure $ "finalize failed: " ++ show err
- Right r_result -> do
- let i_sess = BOLT8.session i_result
- r_sess = BOLT8.session r_result
- case BOLT8.encrypt i_sess hello of
- Left err -> assertFailure $ "encrypt failed: " ++ show err
- Right (ct, _) -> do
- -- append trailing bytes
- let ct_with_trailing = ct <> "extra"
- case BOLT8.decrypt r_sess ct_with_trailing of
- Left BOLT8.InvalidLength -> pure ()
- Left err ->
- assertFailure $ "expected InvalidLength, got: "
- ++ show err
- Right _ ->
- assertFailure "decrypt should reject trailing bytes"
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
+ r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3)
+ let i_sess = BOLT8.session i_result
+ r_sess = BOLT8.session r_result
+ (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello)
+ -- append trailing bytes
+ let ct_with_trailing = ct <> "extra"
+ case BOLT8.decrypt r_sess ct_with_trailing of
+ Left BOLT8.InvalidLength -> pure ()
+ Left err -> assertFailure $ "expected InvalidLength, got: " ++ show err
+ Right _ -> assertFailure "decrypt should reject trailing bytes"
test_decrypt_frame_remainder :: Assertion
test_decrypt_frame_remainder = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
-
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, i_hs) ->
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, r_hs) ->
- case BOLT8.act3 i_hs msg2 of
- Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (msg3, i_result) ->
- case BOLT8.finalize r_hs msg3 of
- Left err -> assertFailure $ "finalize failed: " ++ show err
- Right r_result -> do
- let i_sess = BOLT8.session i_result
- r_sess = BOLT8.session r_result
- case BOLT8.encrypt i_sess hello of
- Left err -> assertFailure $ "encrypt failed: " ++ show err
- Right (ct, _) -> do
- let trailing = "remainder"
- ct_with_trailing = ct <> trailing
- case BOLT8.decrypt_frame r_sess ct_with_trailing of
- Left err ->
- assertFailure $ "decrypt_frame failed: " ++ show err
- Right (pt, remainder, _) -> do
- pt @?= hello
- remainder @?= trailing
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
+ r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3)
+ let i_sess = BOLT8.session i_result
+ r_sess = BOLT8.session r_result
+ (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello)
+ let trailing = "remainder"
+ ct_with_trailing = ct <> trailing
+ (pt, remainder, _) <- expectRight "decrypt_frame"
+ (BOLT8.decrypt_frame r_sess ct_with_trailing)
+ pt @?= hello
+ remainder @?= trailing
test_decrypt_frame_multi :: Assertion
test_decrypt_frame_multi = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
-
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, i_hs) ->
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, r_hs) ->
- case BOLT8.act3 i_hs msg2 of
- Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (msg3, i_result) ->
- case BOLT8.finalize r_hs msg3 of
- Left err -> assertFailure $ "finalize failed: " ++ show err
- Right r_result -> do
- let i_sess = BOLT8.session i_result
- r_sess = BOLT8.session r_result
- -- encrypt two messages
- case BOLT8.encrypt i_sess "first" of
- Left err -> assertFailure $ "encrypt 1 failed: " ++ show err
- Right (ct1, i_sess') ->
- case BOLT8.encrypt i_sess' "second" of
- Left err ->
- assertFailure $ "encrypt 2 failed: " ++ show err
- Right (ct2, _) -> do
- -- concatenate frames
- let buffer = ct1 <> ct2
- -- decrypt first frame
- case BOLT8.decrypt_frame r_sess buffer of
- Left err ->
- assertFailure $ "frame 1 failed: " ++ show err
- Right (pt1, rest, r_sess') -> do
- pt1 @?= "first"
- -- decrypt second frame from remainder
- case BOLT8.decrypt_frame r_sess' rest of
- Left err ->
- assertFailure $ "frame 2 failed: " ++ show err
- Right (pt2, rest2, _) -> do
- pt2 @?= "second"
- rest2 @?= BS.empty
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
+ r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3)
+ let i_sess = BOLT8.session i_result
+ r_sess = BOLT8.session r_result
+ -- encrypt two messages
+ (ct1, i_sess') <- expectRight "encrypt 1" (BOLT8.encrypt i_sess "first")
+ (ct2, _) <- expectRight "encrypt 2" (BOLT8.encrypt i_sess' "second")
+ -- concatenate frames
+ let buffer = ct1 <> ct2
+ -- decrypt first frame
+ (pt1, rest, r_sess') <- expectRight "frame 1"
+ (BOLT8.decrypt_frame r_sess buffer)
+ pt1 @?= "first"
+ -- decrypt second frame from remainder
+ (pt2, rest2, _) <- expectRight "frame 2" (BOLT8.decrypt_frame r_sess' rest)
+ pt2 @?= "second"
+ rest2 @?= BS.empty
-- partial framing tests -----------------------------------------------------
@@ -401,97 +365,78 @@ partial_framing_tests = testGroup "Partial Framing" [
test_partial_short_buffer :: Assertion
test_partial_short_buffer = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, i_hs) ->
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, r_hs) ->
- case BOLT8.act3 i_hs msg2 of
- Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (msg3, _) ->
- case BOLT8.finalize r_hs msg3 of
- Left err -> assertFailure $ "finalize failed: " ++ show err
- Right r_result -> do
- let r_sess = BOLT8.session r_result
- short_buf = BS.replicate 10 0x00
- case BOLT8.decrypt_frame_partial r_sess short_buf of
- BOLT8.NeedMore n -> n @?= 8
- BOLT8.FrameOk {} ->
- assertFailure "expected NeedMore, got FrameOk"
- BOLT8.FrameError err ->
- assertFailure $ "expected NeedMore, got: " ++ show err
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ (msg3, _) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
+ r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3)
+ let r_sess = BOLT8.session r_result
+ short_buf = BS.replicate 10 0x00
+ case BOLT8.decrypt_frame_partial r_sess short_buf of
+ BOLT8.NeedMore n -> n @?= 8
+ BOLT8.FrameOk {} -> assertFailure "expected NeedMore, got FrameOk"
+ BOLT8.FrameError err ->
+ assertFailure $ "expected NeedMore, got: " ++ show err
test_partial_body :: Assertion
test_partial_body = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, i_hs) ->
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, r_hs) ->
- case BOLT8.act3 i_hs msg2 of
- Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (msg3, i_result) ->
- case BOLT8.finalize r_hs msg3 of
- Left err -> assertFailure $ "finalize failed: " ++ show err
- Right r_result -> do
- let i_sess = BOLT8.session i_result
- r_sess = BOLT8.session r_result
- case BOLT8.encrypt i_sess hello of
- Left err -> assertFailure $ "encrypt failed: " ++ show err
- Right (ct, _) -> do
- -- take only length header (18 bytes) + 5 bytes of body
- let partial = BS.take 23 ct
- case BOLT8.decrypt_frame_partial r_sess partial of
- BOLT8.NeedMore n -> do
- -- "hello" = 5 bytes, so body = 5 + 16 = 21
- -- we have 5 bytes of body, need 16 more
- n @?= 16
- BOLT8.FrameOk {} ->
- assertFailure "expected NeedMore, got FrameOk"
- BOLT8.FrameError err ->
- assertFailure $ "expected NeedMore, got: " ++ show err
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
+ r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3)
+ let i_sess = BOLT8.session i_result
+ r_sess = BOLT8.session r_result
+ (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello)
+ -- take only length header (18 bytes) + 5 bytes of body
+ let partial = BS.take 23 ct
+ case BOLT8.decrypt_frame_partial r_sess partial of
+ BOLT8.NeedMore n -> do
+ -- "hello" = 5 bytes, so body = 5 + 16 = 21
+ -- we have 5 bytes of body, need 16 more
+ n @?= 16
+ BOLT8.FrameOk {} -> assertFailure "expected NeedMore, got FrameOk"
+ BOLT8.FrameError err ->
+ assertFailure $ "expected NeedMore, got: " ++ show err
test_partial_full_frame :: Assertion
test_partial_full_frame = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, i_hs) ->
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, r_hs) ->
- case BOLT8.act3 i_hs msg2 of
- Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (msg3, i_result) ->
- case BOLT8.finalize r_hs msg3 of
- Left err -> assertFailure $ "finalize failed: " ++ show err
- Right r_result -> do
- let i_sess = BOLT8.session i_result
- r_sess = BOLT8.session r_result
- case BOLT8.encrypt i_sess hello of
- Left err -> assertFailure $ "encrypt failed: " ++ show err
- Right (ct, _) -> do
- let trailing = "extra"
- buf = ct <> trailing
- case BOLT8.decrypt_frame_partial r_sess buf of
- BOLT8.FrameOk pt remainder _ -> do
- pt @?= hello
- remainder @?= trailing
- BOLT8.NeedMore n ->
- assertFailure $ "expected FrameOk, got NeedMore "
- ++ show n
- BOLT8.FrameError err ->
- assertFailure $ "expected FrameOk, got: " ++ show err
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
+ r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3)
+ let i_sess = BOLT8.session i_result
+ r_sess = BOLT8.session r_result
+ (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello)
+ let trailing = "extra"
+ buf = ct <> trailing
+ case BOLT8.decrypt_frame_partial r_sess buf of
+ BOLT8.FrameOk pt remainder _ -> do
+ pt @?= hello
+ remainder @?= trailing
+ BOLT8.NeedMore n ->
+ assertFailure $ "expected FrameOk, got NeedMore " ++ show n
+ BOLT8.FrameError err ->
+ assertFailure $ "expected FrameOk, got: " ++ show err
-- negative tests ------------------------------------------------------------
@@ -506,22 +451,23 @@ negative_tests = testGroup "Negative Tests" [
test_act2_wrong_version :: Assertion
test_act2_wrong_version = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, _) -> do
- let bad_msg1 = BS.cons 0x01 (BS.drop 1 msg1)
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv bad_msg1 of
- Left BOLT8.InvalidVersion -> pure ()
- Left err -> assertFailure $ "expected InvalidVersion, got: " ++ show err
- Right _ -> assertFailure "expected rejection, got success"
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, _) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv)
+ let bad_msg1 = BS.cons 0x01 (BS.drop 1 msg1)
+ case BOLT8.act2 r_s_sec r_s_pub responder_e_priv bad_msg1 of
+ Left BOLT8.InvalidVersion -> pure ()
+ Left err -> assertFailure $ "expected InvalidVersion, got: " ++ show err
+ Right _ -> assertFailure "expected rejection, got success"
test_act2_wrong_length :: Assertion
test_act2_wrong_length = do
- let Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- short_msg = BS.replicate 49 0x00
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ let short_msg = BS.replicate 49 0x00
case BOLT8.act2 r_s_sec r_s_pub responder_e_priv short_msg of
Left BOLT8.InvalidLength -> pure ()
Left err -> assertFailure $ "expected InvalidLength, got: " ++ show err
@@ -529,80 +475,75 @@ test_act2_wrong_length = do
test_act3_invalid_mac :: Assertion
test_act3_invalid_mac = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, i_hs) ->
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, _) -> do
- let bad_msg2 = flip_byte 40 msg2
- case BOLT8.act3 i_hs bad_msg2 of
- Left BOLT8.InvalidMAC -> pure ()
- Left err ->
- assertFailure $ "expected InvalidMAC, got: " ++ show err
- Right _ -> assertFailure "expected rejection, got success"
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, _) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ bad_msg2 <- flip_byte 40 msg2
+ case BOLT8.act3 i_hs bad_msg2 of
+ Left BOLT8.InvalidMAC -> pure ()
+ Left err -> assertFailure $ "expected InvalidMAC, got: " ++ show err
+ Right _ -> assertFailure "expected rejection, got success"
test_finalize_invalid_mac :: Assertion
test_finalize_invalid_mac = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, i_hs) ->
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, r_hs) ->
- case BOLT8.act3 i_hs msg2 of
- Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (msg3, _) -> do
- let bad_msg3 = flip_byte 20 msg3
- case BOLT8.finalize r_hs bad_msg3 of
- Left BOLT8.InvalidMAC -> pure ()
- Left err ->
- assertFailure $ "expected InvalidMAC, got: " ++ show err
- Right _ -> assertFailure "expected rejection, got success"
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ (msg3, _) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
+ bad_msg3 <- flip_byte 20 msg3
+ case BOLT8.finalize r_hs bad_msg3 of
+ Left BOLT8.InvalidMAC -> pure ()
+ Left err -> assertFailure $ "expected InvalidMAC, got: " ++ show err
+ Right _ -> assertFailure "expected rejection, got success"
test_decrypt_short_packet :: Assertion
test_decrypt_short_packet = do
- let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
- Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
- Just rs = BOLT8.parse_pub responder_s_pub
- case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
- Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (msg1, i_hs) ->
- case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
- Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (msg2, r_hs) ->
- case BOLT8.act3 i_hs msg2 of
- Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (msg3, _) ->
- case BOLT8.finalize r_hs msg3 of
- Left err -> assertFailure $ "finalize failed: " ++ show err
- Right r_result -> do
- let r_sess = BOLT8.session r_result
- short_packet = BS.replicate 17 0x00
- case BOLT8.decrypt r_sess short_packet of
- Left BOLT8.InvalidLength -> pure ()
- Left err ->
- assertFailure $ "expected InvalidLength, got: " ++ show err
- Right _ -> assertFailure "expected rejection, got success"
+ (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
+ (BOLT8.keypair initiator_s_priv)
+ (r_s_sec, r_s_pub) <- expectJust "responder keypair"
+ (BOLT8.keypair responder_s_priv)
+ rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
+ (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs
+ initiator_e_priv)
+ (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv
+ msg1)
+ (msg3, _) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
+ r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3)
+ let r_sess = BOLT8.session r_result
+ short_packet = BS.replicate 17 0x00
+ case BOLT8.decrypt r_sess short_packet of
+ Left BOLT8.InvalidLength -> pure ()
+ Left err -> assertFailure $ "expected InvalidLength, got: " ++ show err
+ Right _ -> assertFailure "expected rejection, got success"
-- flip one byte in a bytestring at given index
-flip_byte :: Int -> BS.ByteString -> BS.ByteString
+flip_byte :: Int -> BS.ByteString -> IO BS.ByteString
flip_byte i bs
- | i < 0 || i >= BS.length bs = error "flip_byte: index out of bounds"
+ | i < 0 || i >= BS.length bs =
+ assertFailure "flip_byte: index out of bounds" >> pure bs
| otherwise =
let (pre, post) = BS.splitAt i bs
b = BS.index post 0
- in pre <> BS.cons (b `xor` 0xff) (BS.drop 1 post)
+ in pure (pre <> BS.cons (b `xor` 0xff) (BS.drop 1 post))
-- utilities -----------------------------------------------------------------
+-- Safe hex decode for test vectors (only called at top level with known-good
+-- literals). This uses error since it's for compile-time constants, not runtime
+-- input; wrapping in IO would break the test vector declarations.
hex :: BS.ByteString -> BS.ByteString
hex bs = case B16.decode bs of
- Nothing -> error "invalid hex"
- Just r -> r
+ Nothing -> error "hex: invalid test vector literal"
+ Just r -> r