bolt8

Encrypted and authenticated transport, per BOLT #8.
git clone git://git.ppad.tech/bolt8.git
Log | Files | Refs | README | LICENSE

commit 5e83d290b828178f360e53a1349f46bae09db625
parent 991f84308a8597540ddbceb608b189ddfab42353
Author: Jared Tobin <jared@jtobin.io>
Date:   Sun, 25 Jan 2026 10:22:55 +0400

test: eliminate non-exhaustive pattern matches

Add expectJust/expectRight helpers that fail tests with clear messages
instead of using partial 'let Just' patterns. Make flip_byte total by
returning IO. All 22 tests pass with no pattern match warnings.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Diffstat:
Aplans/IMPL5.md | 22++++++++++++++++++++++
Mtest/Main.hs | 625++++++++++++++++++++++++++++++++++++-------------------------------------------
2 files changed, 305 insertions(+), 342 deletions(-)

diff --git a/plans/IMPL5.md b/plans/IMPL5.md @@ -0,0 +1,22 @@ +# IMPL5: Eliminate non-exhaustive test pattern matches + +## Goal +Remove non-exhaustive pattern match warnings when running tests. + +## Steps +1) Identify the warning locations by grepping for partial pattern + matches in tests (e.g., `let Just`, `Right`, or direct field access). +2) Replace brittle matches with total helpers: + - Use small helper functions like `expectRight`/`expectJust` + that fail the test with an assertion message. + - Avoid `error` and unchecked indexing. +3) For repeated handshake setup, add a helper returning either + `Assertion` failure or the needed values, keeping each test + total and readable. +4) Ensure all tests handle the full set of constructors explicitly + (`Left`/`Right`, `Nothing`/`Just`, `FrameResult` variants). +5) Re-run `cabal test` to confirm warnings are gone. + +## Notes +- Keep helpers local to `test/Main.hs`. +- Preserve existing test coverage and intent. diff --git a/test/Main.hs b/test/Main.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} module Main where @@ -9,6 +10,20 @@ import qualified Lightning.Protocol.BOLT8 as BOLT8 import Test.Tasty import Test.Tasty.HUnit +-- test helpers ---------------------------------------------------------------- + +-- | Extract a Just value or fail the test. +expectJust :: String -> Maybe a -> IO a +expectJust msg = \case + Nothing -> assertFailure msg >> error "unreachable" + Just a -> pure a + +-- | Extract a Right value or fail the test. +expectRight :: Show e => String -> Either e a -> IO a +expectRight msg = \case + Left e -> assertFailure (msg ++ ": " ++ show e) >> error "unreachable" + Right a -> pure a + main :: IO () main = defaultMain $ testGroup "ppad-bolt8" [ handshake_tests @@ -75,61 +90,55 @@ handshake_tests = testGroup "Handshake" [ test_act1 :: Assertion test_act1 = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (act1_msg, _hs) -> act1_msg @?= expected_act1 + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (act1_msg, _) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + act1_msg @?= expected_act1 test_act2 :: Assertion test_act2 = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, _) -> do - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, _) -> msg2 @?= expected_act2 + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, _) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, _) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + msg2 @?= expected_act2 test_act3 :: Assertion test_act3 = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> do - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, _) -> do - case BOLT8.act3 i_hs msg2 of - Left err -> assertFailure $ "act3 failed: " ++ show err - Right (msg3, _) -> msg3 @?= expected_act3 + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, _) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + (msg3, _) <- expectRight "act3" (BOLT8.act3 i_hs msg2) + msg3 @?= expected_act3 test_full_handshake :: Assertion test_full_handshake = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> do - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, r_hs) -> do - case BOLT8.act3 i_hs msg2 of - Left err -> assertFailure $ "act3 failed: " ++ show err - Right (msg3, i_result) -> do - case BOLT8.finalize r_hs msg3 of - Left err -> assertFailure $ "finalize failed: " ++ show err - Right r_result -> do - BOLT8.remote_static i_result @?= r_s_pub - BOLT8.remote_static r_result @?= i_s_pub + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2) + r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3) + BOLT8.remote_static i_result @?= r_s_pub + BOLT8.remote_static r_result @?= i_s_pub -- message encryption tests -------------------------------------------------- @@ -182,19 +191,17 @@ expected_msg_1001 = hex -- helper to get initiator session after handshake get_initiator_session :: IO BOLT8.Session get_initiator_session = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> fail $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> fail $ "act2 failed: " ++ show err - Right (msg2, _) -> - case BOLT8.act3 i_hs msg2 of - Left err -> fail $ "act3 failed: " ++ show err - Right (_, result) -> pure (BOLT8.session result) + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, _) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + (_, result) <- expectRight "act3" (BOLT8.act3 i_hs msg2) + pure (BOLT8.session result) -- encrypt N messages, return Nth ciphertext encrypt_n :: Int -> BOLT8.Session -> IO BS.ByteString @@ -246,31 +253,22 @@ test_message_1001 = do test_decrypt_roundtrip :: Assertion test_decrypt_roundtrip = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, r_hs) -> - case BOLT8.act3 i_hs msg2 of - Left err -> assertFailure $ "act3 failed: " ++ show err - Right (msg3, i_result) -> - case BOLT8.finalize r_hs msg3 of - Left err -> assertFailure $ "finalize failed: " ++ show err - Right r_result -> do - let i_sess = BOLT8.session i_result - r_sess = BOLT8.session r_result - case BOLT8.encrypt i_sess hello of - Left err -> assertFailure $ "encrypt failed: " ++ show err - Right (ct, _) -> - case BOLT8.decrypt r_sess ct of - Left err -> - assertFailure $ "decrypt failed: " ++ show err - Right (pt, _) -> pt @?= hello + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2) + r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3) + let i_sess = BOLT8.session i_result + r_sess = BOLT8.session r_result + (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello) + (pt, _) <- expectRight "decrypt" (BOLT8.decrypt r_sess ct) + pt @?= hello -- framing tests ------------------------------------------------------------- @@ -283,112 +281,78 @@ framing_tests = testGroup "Packet Framing" [ test_decrypt_trailing :: Assertion test_decrypt_trailing = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, r_hs) -> - case BOLT8.act3 i_hs msg2 of - Left err -> assertFailure $ "act3 failed: " ++ show err - Right (msg3, i_result) -> - case BOLT8.finalize r_hs msg3 of - Left err -> assertFailure $ "finalize failed: " ++ show err - Right r_result -> do - let i_sess = BOLT8.session i_result - r_sess = BOLT8.session r_result - case BOLT8.encrypt i_sess hello of - Left err -> assertFailure $ "encrypt failed: " ++ show err - Right (ct, _) -> do - -- append trailing bytes - let ct_with_trailing = ct <> "extra" - case BOLT8.decrypt r_sess ct_with_trailing of - Left BOLT8.InvalidLength -> pure () - Left err -> - assertFailure $ "expected InvalidLength, got: " - ++ show err - Right _ -> - assertFailure "decrypt should reject trailing bytes" + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2) + r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3) + let i_sess = BOLT8.session i_result + r_sess = BOLT8.session r_result + (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello) + -- append trailing bytes + let ct_with_trailing = ct <> "extra" + case BOLT8.decrypt r_sess ct_with_trailing of + Left BOLT8.InvalidLength -> pure () + Left err -> assertFailure $ "expected InvalidLength, got: " ++ show err + Right _ -> assertFailure "decrypt should reject trailing bytes" test_decrypt_frame_remainder :: Assertion test_decrypt_frame_remainder = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, r_hs) -> - case BOLT8.act3 i_hs msg2 of - Left err -> assertFailure $ "act3 failed: " ++ show err - Right (msg3, i_result) -> - case BOLT8.finalize r_hs msg3 of - Left err -> assertFailure $ "finalize failed: " ++ show err - Right r_result -> do - let i_sess = BOLT8.session i_result - r_sess = BOLT8.session r_result - case BOLT8.encrypt i_sess hello of - Left err -> assertFailure $ "encrypt failed: " ++ show err - Right (ct, _) -> do - let trailing = "remainder" - ct_with_trailing = ct <> trailing - case BOLT8.decrypt_frame r_sess ct_with_trailing of - Left err -> - assertFailure $ "decrypt_frame failed: " ++ show err - Right (pt, remainder, _) -> do - pt @?= hello - remainder @?= trailing + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2) + r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3) + let i_sess = BOLT8.session i_result + r_sess = BOLT8.session r_result + (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello) + let trailing = "remainder" + ct_with_trailing = ct <> trailing + (pt, remainder, _) <- expectRight "decrypt_frame" + (BOLT8.decrypt_frame r_sess ct_with_trailing) + pt @?= hello + remainder @?= trailing test_decrypt_frame_multi :: Assertion test_decrypt_frame_multi = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, r_hs) -> - case BOLT8.act3 i_hs msg2 of - Left err -> assertFailure $ "act3 failed: " ++ show err - Right (msg3, i_result) -> - case BOLT8.finalize r_hs msg3 of - Left err -> assertFailure $ "finalize failed: " ++ show err - Right r_result -> do - let i_sess = BOLT8.session i_result - r_sess = BOLT8.session r_result - -- encrypt two messages - case BOLT8.encrypt i_sess "first" of - Left err -> assertFailure $ "encrypt 1 failed: " ++ show err - Right (ct1, i_sess') -> - case BOLT8.encrypt i_sess' "second" of - Left err -> - assertFailure $ "encrypt 2 failed: " ++ show err - Right (ct2, _) -> do - -- concatenate frames - let buffer = ct1 <> ct2 - -- decrypt first frame - case BOLT8.decrypt_frame r_sess buffer of - Left err -> - assertFailure $ "frame 1 failed: " ++ show err - Right (pt1, rest, r_sess') -> do - pt1 @?= "first" - -- decrypt second frame from remainder - case BOLT8.decrypt_frame r_sess' rest of - Left err -> - assertFailure $ "frame 2 failed: " ++ show err - Right (pt2, rest2, _) -> do - pt2 @?= "second" - rest2 @?= BS.empty + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2) + r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3) + let i_sess = BOLT8.session i_result + r_sess = BOLT8.session r_result + -- encrypt two messages + (ct1, i_sess') <- expectRight "encrypt 1" (BOLT8.encrypt i_sess "first") + (ct2, _) <- expectRight "encrypt 2" (BOLT8.encrypt i_sess' "second") + -- concatenate frames + let buffer = ct1 <> ct2 + -- decrypt first frame + (pt1, rest, r_sess') <- expectRight "frame 1" + (BOLT8.decrypt_frame r_sess buffer) + pt1 @?= "first" + -- decrypt second frame from remainder + (pt2, rest2, _) <- expectRight "frame 2" (BOLT8.decrypt_frame r_sess' rest) + pt2 @?= "second" + rest2 @?= BS.empty -- partial framing tests ----------------------------------------------------- @@ -401,97 +365,78 @@ partial_framing_tests = testGroup "Partial Framing" [ test_partial_short_buffer :: Assertion test_partial_short_buffer = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, r_hs) -> - case BOLT8.act3 i_hs msg2 of - Left err -> assertFailure $ "act3 failed: " ++ show err - Right (msg3, _) -> - case BOLT8.finalize r_hs msg3 of - Left err -> assertFailure $ "finalize failed: " ++ show err - Right r_result -> do - let r_sess = BOLT8.session r_result - short_buf = BS.replicate 10 0x00 - case BOLT8.decrypt_frame_partial r_sess short_buf of - BOLT8.NeedMore n -> n @?= 8 - BOLT8.FrameOk {} -> - assertFailure "expected NeedMore, got FrameOk" - BOLT8.FrameError err -> - assertFailure $ "expected NeedMore, got: " ++ show err + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + (msg3, _) <- expectRight "act3" (BOLT8.act3 i_hs msg2) + r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3) + let r_sess = BOLT8.session r_result + short_buf = BS.replicate 10 0x00 + case BOLT8.decrypt_frame_partial r_sess short_buf of + BOLT8.NeedMore n -> n @?= 8 + BOLT8.FrameOk {} -> assertFailure "expected NeedMore, got FrameOk" + BOLT8.FrameError err -> + assertFailure $ "expected NeedMore, got: " ++ show err test_partial_body :: Assertion test_partial_body = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, r_hs) -> - case BOLT8.act3 i_hs msg2 of - Left err -> assertFailure $ "act3 failed: " ++ show err - Right (msg3, i_result) -> - case BOLT8.finalize r_hs msg3 of - Left err -> assertFailure $ "finalize failed: " ++ show err - Right r_result -> do - let i_sess = BOLT8.session i_result - r_sess = BOLT8.session r_result - case BOLT8.encrypt i_sess hello of - Left err -> assertFailure $ "encrypt failed: " ++ show err - Right (ct, _) -> do - -- take only length header (18 bytes) + 5 bytes of body - let partial = BS.take 23 ct - case BOLT8.decrypt_frame_partial r_sess partial of - BOLT8.NeedMore n -> do - -- "hello" = 5 bytes, so body = 5 + 16 = 21 - -- we have 5 bytes of body, need 16 more - n @?= 16 - BOLT8.FrameOk {} -> - assertFailure "expected NeedMore, got FrameOk" - BOLT8.FrameError err -> - assertFailure $ "expected NeedMore, got: " ++ show err + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2) + r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3) + let i_sess = BOLT8.session i_result + r_sess = BOLT8.session r_result + (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello) + -- take only length header (18 bytes) + 5 bytes of body + let partial = BS.take 23 ct + case BOLT8.decrypt_frame_partial r_sess partial of + BOLT8.NeedMore n -> do + -- "hello" = 5 bytes, so body = 5 + 16 = 21 + -- we have 5 bytes of body, need 16 more + n @?= 16 + BOLT8.FrameOk {} -> assertFailure "expected NeedMore, got FrameOk" + BOLT8.FrameError err -> + assertFailure $ "expected NeedMore, got: " ++ show err test_partial_full_frame :: Assertion test_partial_full_frame = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, r_hs) -> - case BOLT8.act3 i_hs msg2 of - Left err -> assertFailure $ "act3 failed: " ++ show err - Right (msg3, i_result) -> - case BOLT8.finalize r_hs msg3 of - Left err -> assertFailure $ "finalize failed: " ++ show err - Right r_result -> do - let i_sess = BOLT8.session i_result - r_sess = BOLT8.session r_result - case BOLT8.encrypt i_sess hello of - Left err -> assertFailure $ "encrypt failed: " ++ show err - Right (ct, _) -> do - let trailing = "extra" - buf = ct <> trailing - case BOLT8.decrypt_frame_partial r_sess buf of - BOLT8.FrameOk pt remainder _ -> do - pt @?= hello - remainder @?= trailing - BOLT8.NeedMore n -> - assertFailure $ "expected FrameOk, got NeedMore " - ++ show n - BOLT8.FrameError err -> - assertFailure $ "expected FrameOk, got: " ++ show err + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2) + r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3) + let i_sess = BOLT8.session i_result + r_sess = BOLT8.session r_result + (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello) + let trailing = "extra" + buf = ct <> trailing + case BOLT8.decrypt_frame_partial r_sess buf of + BOLT8.FrameOk pt remainder _ -> do + pt @?= hello + remainder @?= trailing + BOLT8.NeedMore n -> + assertFailure $ "expected FrameOk, got NeedMore " ++ show n + BOLT8.FrameError err -> + assertFailure $ "expected FrameOk, got: " ++ show err -- negative tests ------------------------------------------------------------ @@ -506,22 +451,23 @@ negative_tests = testGroup "Negative Tests" [ test_act2_wrong_version :: Assertion test_act2_wrong_version = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, _) -> do - let bad_msg1 = BS.cons 0x01 (BS.drop 1 msg1) - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv bad_msg1 of - Left BOLT8.InvalidVersion -> pure () - Left err -> assertFailure $ "expected InvalidVersion, got: " ++ show err - Right _ -> assertFailure "expected rejection, got success" + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, _) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv) + let bad_msg1 = BS.cons 0x01 (BS.drop 1 msg1) + case BOLT8.act2 r_s_sec r_s_pub responder_e_priv bad_msg1 of + Left BOLT8.InvalidVersion -> pure () + Left err -> assertFailure $ "expected InvalidVersion, got: " ++ show err + Right _ -> assertFailure "expected rejection, got success" test_act2_wrong_length :: Assertion test_act2_wrong_length = do - let Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - short_msg = BS.replicate 49 0x00 + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + let short_msg = BS.replicate 49 0x00 case BOLT8.act2 r_s_sec r_s_pub responder_e_priv short_msg of Left BOLT8.InvalidLength -> pure () Left err -> assertFailure $ "expected InvalidLength, got: " ++ show err @@ -529,80 +475,75 @@ test_act2_wrong_length = do test_act3_invalid_mac :: Assertion test_act3_invalid_mac = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, _) -> do - let bad_msg2 = flip_byte 40 msg2 - case BOLT8.act3 i_hs bad_msg2 of - Left BOLT8.InvalidMAC -> pure () - Left err -> - assertFailure $ "expected InvalidMAC, got: " ++ show err - Right _ -> assertFailure "expected rejection, got success" + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, _) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + bad_msg2 <- flip_byte 40 msg2 + case BOLT8.act3 i_hs bad_msg2 of + Left BOLT8.InvalidMAC -> pure () + Left err -> assertFailure $ "expected InvalidMAC, got: " ++ show err + Right _ -> assertFailure "expected rejection, got success" test_finalize_invalid_mac :: Assertion test_finalize_invalid_mac = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, r_hs) -> - case BOLT8.act3 i_hs msg2 of - Left err -> assertFailure $ "act3 failed: " ++ show err - Right (msg3, _) -> do - let bad_msg3 = flip_byte 20 msg3 - case BOLT8.finalize r_hs bad_msg3 of - Left BOLT8.InvalidMAC -> pure () - Left err -> - assertFailure $ "expected InvalidMAC, got: " ++ show err - Right _ -> assertFailure "expected rejection, got success" + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + (msg3, _) <- expectRight "act3" (BOLT8.act3 i_hs msg2) + bad_msg3 <- flip_byte 20 msg3 + case BOLT8.finalize r_hs bad_msg3 of + Left BOLT8.InvalidMAC -> pure () + Left err -> assertFailure $ "expected InvalidMAC, got: " ++ show err + Right _ -> assertFailure "expected rejection, got success" test_decrypt_short_packet :: Assertion test_decrypt_short_packet = do - let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv - Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv - Just rs = BOLT8.parse_pub responder_s_pub - case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of - Left err -> assertFailure $ "act1 failed: " ++ show err - Right (msg1, i_hs) -> - case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of - Left err -> assertFailure $ "act2 failed: " ++ show err - Right (msg2, r_hs) -> - case BOLT8.act3 i_hs msg2 of - Left err -> assertFailure $ "act3 failed: " ++ show err - Right (msg3, _) -> - case BOLT8.finalize r_hs msg3 of - Left err -> assertFailure $ "finalize failed: " ++ show err - Right r_result -> do - let r_sess = BOLT8.session r_result - short_packet = BS.replicate 17 0x00 - case BOLT8.decrypt r_sess short_packet of - Left BOLT8.InvalidLength -> pure () - Left err -> - assertFailure $ "expected InvalidLength, got: " ++ show err - Right _ -> assertFailure "expected rejection, got success" + (i_s_sec, i_s_pub) <- expectJust "initiator keypair" + (BOLT8.keypair initiator_s_priv) + (r_s_sec, r_s_pub) <- expectJust "responder keypair" + (BOLT8.keypair responder_s_priv) + rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub) + (msg1, i_hs) <- expectRight "act1" (BOLT8.act1 i_s_sec i_s_pub rs + initiator_e_priv) + (msg2, r_hs) <- expectRight "act2" (BOLT8.act2 r_s_sec r_s_pub responder_e_priv + msg1) + (msg3, _) <- expectRight "act3" (BOLT8.act3 i_hs msg2) + r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3) + let r_sess = BOLT8.session r_result + short_packet = BS.replicate 17 0x00 + case BOLT8.decrypt r_sess short_packet of + Left BOLT8.InvalidLength -> pure () + Left err -> assertFailure $ "expected InvalidLength, got: " ++ show err + Right _ -> assertFailure "expected rejection, got success" -- flip one byte in a bytestring at given index -flip_byte :: Int -> BS.ByteString -> BS.ByteString +flip_byte :: Int -> BS.ByteString -> IO BS.ByteString flip_byte i bs - | i < 0 || i >= BS.length bs = error "flip_byte: index out of bounds" + | i < 0 || i >= BS.length bs = + assertFailure "flip_byte: index out of bounds" >> pure bs | otherwise = let (pre, post) = BS.splitAt i bs b = BS.index post 0 - in pre <> BS.cons (b `xor` 0xff) (BS.drop 1 post) + in pure (pre <> BS.cons (b `xor` 0xff) (BS.drop 1 post)) -- utilities ----------------------------------------------------------------- +-- Safe hex decode for test vectors (only called at top level with known-good +-- literals). This uses error since it's for compile-time constants, not runtime +-- input; wrapping in IO would break the test vector declarations. hex :: BS.ByteString -> BS.ByteString hex bs = case B16.decode bs of - Nothing -> error "invalid hex" - Just r -> r + Nothing -> error "hex: invalid test vector literal" + Just r -> r