Main.hs (27238B)
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}

module Main where

import Data.Bits (xor)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as B16
import qualified Lightning.Protocol.BOLT8 as BOLT8
import Test.Tasty
import Test.Tasty.HUnit
import Test.Tasty.QuickCheck (Gen, Property, choose, forAll, testProperty,
  vectorOf)

-- test helpers ----------------------------------------------------------------

-- | Extract a Just value or fail the test with the given message.
expectJust :: String -> Maybe a -> IO a
expectJust msg = \case
  Nothing -> assertFailure msg >> error "unreachable"
  Just a  -> pure a

-- | Extract a Right value or fail the test. The failure message is the given
--   label followed by ": " and the shown error.
expectRight :: Show e => String -> Either e a -> IO a
expectRight msg = \case
  Left e  -> assertFailure (msg ++ ": " ++ show e) >> error "unreachable"
  Right a -> pure a

main :: IO ()
main = defaultMain $ testGroup "ppad-bolt8" [
    handshake_tests
  , message_tests
  , framing_tests
  , partial_framing_tests
  , negative_tests
  , property_tests
  ]

-- test vectors from BOLT #8 specification -----------------------------------

-- initiator static private key
initiator_s_priv :: BS.ByteString
initiator_s_priv = hex
  "1111111111111111111111111111111111111111111111111111111111111111"

-- initiator ephemeral private key
initiator_e_priv :: BS.ByteString
initiator_e_priv = hex
  "1212121212121212121212121212121212121212121212121212121212121212"

-- responder static private key
responder_s_priv :: BS.ByteString
responder_s_priv = hex
  "2121212121212121212121212121212121212121212121212121212121212121"

-- responder static public key (known to initiator)
responder_s_pub :: BS.ByteString
responder_s_pub = hex
  "028d7500dd4c12685d1f568b4c2b5048e8534b873319f3a8daa612b469132ec7f7"

-- responder ephemeral private key
responder_e_priv :: BS.ByteString
responder_e_priv = hex
  "2222222222222222222222222222222222222222222222222222222222222222"

-- expected act 1 message
expected_act1 :: BS.ByteString
expected_act1 = hex
  "00036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f7\
  \0df6086551151f58b8afe6c195782c6a"

-- expected act 2 message
expected_act2 :: BS.ByteString
expected_act2 = hex
  "0002466d7fcae563e5cb09a0d1870bb580344804617879a14949cf22285f1bae3f27\
  \6e2470b93aac583c9ef6eafca3f730ae"

-- expected act 3 message
expected_act3 :: BS.ByteString
expected_act3 = hex
  "00b9e3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355\
  \361aa02e55a8fc28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba"

-- handshake tests -----------------------------------------------------------

handshake_tests :: TestTree
handshake_tests = testGroup "Handshake" [
    testCase "act1 matches spec vector" test_act1
  , testCase "act2 matches spec vector" test_act2
  , testCase "act3 matches spec vector" test_act3
  , testCase "full handshake round-trip" test_full_handshake
  ]

test_act1 :: Assertion
test_act1 = do
  (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
    (BOLT8.keypair initiator_s_priv)
  rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
  (act1_msg, _) <- expectRight "act1"
    (BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv)
  act1_msg @?= expected_act1

test_act2 :: Assertion
test_act2 = do
  (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
    (BOLT8.keypair initiator_s_priv)
  (r_s_sec, r_s_pub) <- expectJust "responder keypair"
    (BOLT8.keypair responder_s_priv)
  rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
  (msg1, _) <- expectRight "act1"
    (BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv)
  (msg2, _) <- expectRight "act2"
    (BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1)
  msg2 @?= expected_act2

test_act3 :: Assertion
test_act3 = do
  (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
    (BOLT8.keypair initiator_s_priv)
  (r_s_sec, r_s_pub) <- expectJust "responder keypair"
    (BOLT8.keypair responder_s_priv)
  rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
  (msg1, i_hs) <- expectRight "act1"
    (BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv)
  (msg2, _) <- expectRight "act2"
    (BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1)
  (msg3, _) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
  msg3 @?= expected_act3

test_full_handshake :: Assertion
test_full_handshake = do
  (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
    (BOLT8.keypair initiator_s_priv)
  (r_s_sec, r_s_pub) <- expectJust "responder keypair"
    (BOLT8.keypair responder_s_priv)
  rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
  (msg1, i_hs) <- expectRight "act1"
    (BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv)
  (msg2, r_hs) <- expectRight "act2"
    (BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1)
  (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
  r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3)
  -- each side must end up knowing the other's static public key
  BOLT8.remote_static i_result @?= r_s_pub
  BOLT8.remote_static r_result @?= i_s_pub

-- message encryption tests --------------------------------------------------

message_tests :: TestTree
message_tests = testGroup "Message Encryption" [
    testCase "message 0 matches spec" test_message_0
  , testCase "message 1 matches spec" test_message_1
  , testCase "message 500 matches spec" test_message_500
  , testCase "message 501 matches spec" test_message_501
  , testCase "message 1000 matches spec" test_message_1000
  , testCase "message 1001 matches spec" test_message_1001
  , testCase "decrypt round-trip" test_decrypt_roundtrip
  ]

-- "hello" = 0x68656c6c6f
hello :: BS.ByteString
hello = "hello"

-- expected encrypted messages
expected_msg_0 :: BS.ByteString
expected_msg_0 = hex
  "cf2b30ddf0cf3f80e7c35a6e6730b59fe802473180f396d88a8fb0db8cbcf25d\
  \2f214cf9ea1d95"

expected_msg_1 :: BS.ByteString
expected_msg_1 = hex
  "72887022101f0b6753e0c7de21657d35a4cb2a1f5cde2650528bbc8f837d0f0d\
  \7ad833b1a256a1"

expected_msg_500 :: BS.ByteString
expected_msg_500 = hex
  "178cb9d7387190fa34db9c2d50027d21793c9bc2d40b1e14dcf30ebeeeb220f4\
  \8364f7a4c68bf8"

expected_msg_501 :: BS.ByteString
expected_msg_501 = hex
  "1b186c57d44eb6de4c057c49940d79bb838a145cb528d6e8fd26dbe50a60ca2c\
  \104b56b60e45bd"

expected_msg_1000 :: BS.ByteString
expected_msg_1000 = hex
  "4a2f3cc3b5e78ddb83dcb426d9863d9d9a723b0337c89dd0b005d89f8d3c05c5\
  \2b76b29b740f09"

expected_msg_1001 :: BS.ByteString
expected_msg_1001 = hex
  "2ecd8c8a5629d0d02ab457a0fdd0f7b90a192cd46be5ecb6ca570bfc5e268338\
  \b1a16cf4ef2d36"

-- | Run acts 1-3 with the spec vectors and return the initiator's transport
--   session. Deliberately does not run the responder's 'BOLT8.finalize', so
--   the message-vector tests depend only on the initiator side.
get_initiator_session :: IO BOLT8.Session
get_initiator_session = do
  (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
    (BOLT8.keypair initiator_s_priv)
  (r_s_sec, r_s_pub) <- expectJust "responder keypair"
    (BOLT8.keypair responder_s_priv)
  rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
  (msg1, i_hs) <- expectRight "act1"
    (BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv)
  (msg2, _) <- expectRight "act2"
    (BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1)
  (_, result) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
  pure (BOLT8.session result)

-- | Run the complete handshake (acts 1-3 plus the responder's finalize) with
--   the spec vectors. Returns the (initiator, responder) transport sessions.
--   Shared by every test that only exercises post-handshake behaviour.
handshake_sessions :: IO (BOLT8.Session, BOLT8.Session)
handshake_sessions = do
  (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
    (BOLT8.keypair initiator_s_priv)
  (r_s_sec, r_s_pub) <- expectJust "responder keypair"
    (BOLT8.keypair responder_s_priv)
  rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
  (msg1, i_hs) <- expectRight "act1"
    (BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv)
  (msg2, r_hs) <- expectRight "act2"
    (BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1)
  (msg3, i_result) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
  r_result <- expectRight "finalize" (BOLT8.finalize r_hs msg3)
  pure (BOLT8.session i_result, BOLT8.session r_result)

-- | Encrypt 'hello' repeatedly, threading the session through each call.
--   Returns the ciphertext of message number @n@ (0-based).
--
--   Fails the test if any encryption fails, or if @n@ is negative. A negative
--   @n@ would otherwise never match the loop counter.
encrypt_n :: Int -> BOLT8.Session -> IO BS.ByteString
encrypt_n n sess0
  | n < 0 =
      assertFailure "encrypt_n: negative message index" >> error "unreachable"
  | otherwise = go 0 sess0
  where
    go i sess = do
      (ct, sess') <- expectRight ("encrypt failed at " ++ show i)
        (BOLT8.encrypt sess hello)
      if i == n then pure ct else go (i + 1) sess'

-- | Check that message number @n@ (0-based) sent by the spec initiator
--   encrypts to the expected vector.
check_message :: Int -> BS.ByteString -> Assertion
check_message n expected = do
  sess <- get_initiator_session
  ct <- encrypt_n n sess
  ct @?= expected

test_message_0 :: Assertion
test_message_0 = check_message 0 expected_msg_0

test_message_1 :: Assertion
test_message_1 = check_message 1 expected_msg_1

test_message_500 :: Assertion
test_message_500 = check_message 500 expected_msg_500

test_message_501 :: Assertion
test_message_501 = check_message 501 expected_msg_501

test_message_1000 :: Assertion
test_message_1000 = check_message 1000 expected_msg_1000

test_message_1001 :: Assertion
test_message_1001 = check_message 1001 expected_msg_1001

test_decrypt_roundtrip :: Assertion
test_decrypt_roundtrip = do
  (i_sess, r_sess) <- handshake_sessions
  (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello)
  (pt, _) <- expectRight "decrypt" (BOLT8.decrypt r_sess ct)
  pt @?= hello

-- framing tests -------------------------------------------------------------

framing_tests :: TestTree
framing_tests = testGroup "Packet Framing" [
    testCase "decrypt rejects trailing bytes" test_decrypt_trailing
  , testCase "decrypt_frame returns remainder" test_decrypt_frame_remainder
  , testCase "decrypt_frame handles multiple frames" test_decrypt_frame_multi
  ]

test_decrypt_trailing :: Assertion
test_decrypt_trailing = do
  (i_sess, r_sess) <- handshake_sessions
  (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello)
  -- append trailing bytes; 'decrypt' must consume exactly one packet
  let ct_with_trailing = ct <> "extra"
  case BOLT8.decrypt r_sess ct_with_trailing of
    Left BOLT8.InvalidLength -> pure ()
    Left err -> assertFailure $ "expected InvalidLength, got: " ++ show err
    Right _ -> assertFailure "decrypt should reject trailing bytes"

test_decrypt_frame_remainder :: Assertion
test_decrypt_frame_remainder = do
  (i_sess, r_sess) <- handshake_sessions
  (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello)
  let trailing = "remainder"
      ct_with_trailing = ct <> trailing
  (pt, remainder, _) <- expectRight "decrypt_frame"
    (BOLT8.decrypt_frame r_sess ct_with_trailing)
  pt @?= hello
  remainder @?= trailing

test_decrypt_frame_multi :: Assertion
test_decrypt_frame_multi = do
  (i_sess, r_sess) <- handshake_sessions
  -- encrypt two messages
  (ct1, i_sess') <- expectRight "encrypt 1" (BOLT8.encrypt i_sess "first")
  (ct2, _) <- expectRight "encrypt 2" (BOLT8.encrypt i_sess' "second")
  -- concatenate frames
  let buffer = ct1 <> ct2
  -- decrypt first frame
  (pt1, rest, r_sess') <- expectRight "frame 1"
    (BOLT8.decrypt_frame r_sess buffer)
  pt1 @?= "first"
  -- decrypt second frame from remainder
  (pt2, rest2, _) <- expectRight "frame 2" (BOLT8.decrypt_frame r_sess' rest)
  pt2 @?= "second"
  rest2 @?= BS.empty

-- partial framing tests -----------------------------------------------------

partial_framing_tests :: TestTree
partial_framing_tests = testGroup "Partial Framing" [
    testCase "short buffer returns NeedMore" test_partial_short_buffer
  , testCase "partial body returns NeedMore" test_partial_body
  , testCase "full frame returns FrameOk" test_partial_full_frame
  ]

test_partial_short_buffer :: Assertion
test_partial_short_buffer = do
  (_, r_sess) <- handshake_sessions
  let short_buf = BS.replicate 10 0x00
  case BOLT8.decrypt_frame_partial r_sess short_buf of
    -- 10 of the 18 length-header bytes are present, so 8 are missing
    BOLT8.NeedMore n -> n @?= 8
    BOLT8.FrameOk {} -> assertFailure "expected NeedMore, got FrameOk"
    BOLT8.FrameError err ->
      assertFailure $ "expected NeedMore, got: " ++ show err

test_partial_body :: Assertion
test_partial_body = do
  (i_sess, r_sess) <- handshake_sessions
  (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello)
  -- take only length header (18 bytes) + 5 bytes of body
  let partial = BS.take 23 ct
  case BOLT8.decrypt_frame_partial r_sess partial of
    BOLT8.NeedMore n ->
      -- "hello" = 5 bytes, so body = 5 + 16 = 21
      -- we have 5 bytes of body, need 16 more
      n @?= 16
    BOLT8.FrameOk {} -> assertFailure "expected NeedMore, got FrameOk"
    BOLT8.FrameError err ->
      assertFailure $ "expected NeedMore, got: " ++ show err

test_partial_full_frame :: Assertion
test_partial_full_frame = do
  (i_sess, r_sess) <- handshake_sessions
  (ct, _) <- expectRight "encrypt" (BOLT8.encrypt i_sess hello)
  let trailing = "extra"
      buf = ct <> trailing
  case BOLT8.decrypt_frame_partial r_sess buf of
    BOLT8.FrameOk pt remainder _ -> do
      pt @?= hello
      remainder @?= trailing
    BOLT8.NeedMore n ->
      assertFailure $ "expected FrameOk, got NeedMore " ++ show n
    BOLT8.FrameError err ->
      assertFailure $ "expected FrameOk, got: " ++ show err

-- negative tests ------------------------------------------------------------

negative_tests :: TestTree
negative_tests = testGroup "Negative Tests" [
    testCase "act2 rejects wrong version" test_act2_wrong_version
  , testCase "act2 rejects wrong length" test_act2_wrong_length
  , testCase "act3 rejects invalid MAC" test_act3_invalid_mac
  , testCase "finalize rejects invalid MAC" test_finalize_invalid_mac
  , testCase "decrypt rejects short packet" test_decrypt_short_packet
  ]

test_act2_wrong_version :: Assertion
test_act2_wrong_version = do
  (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
    (BOLT8.keypair initiator_s_priv)
  (r_s_sec, r_s_pub) <- expectJust "responder keypair"
    (BOLT8.keypair responder_s_priv)
  rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
  (msg1, _) <- expectRight "act1"
    (BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv)
  -- replace the leading version byte with 0x01
  let bad_msg1 = BS.cons 0x01 (BS.drop 1 msg1)
  case BOLT8.act2 r_s_sec r_s_pub responder_e_priv bad_msg1 of
    Left BOLT8.InvalidVersion -> pure ()
    Left err -> assertFailure $ "expected InvalidVersion, got: " ++ show err
    Right _ -> assertFailure "expected rejection, got success"

test_act2_wrong_length :: Assertion
test_act2_wrong_length = do
  (r_s_sec, r_s_pub) <- expectJust "responder keypair"
    (BOLT8.keypair responder_s_priv)
  let short_msg = BS.replicate 49 0x00
  case BOLT8.act2 r_s_sec r_s_pub responder_e_priv short_msg of
    Left BOLT8.InvalidLength -> pure ()
    Left err -> assertFailure $ "expected InvalidLength, got: " ++ show err
    Right _ -> assertFailure "expected rejection, got success"

test_act3_invalid_mac :: Assertion
test_act3_invalid_mac = do
  (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
    (BOLT8.keypair initiator_s_priv)
  (r_s_sec, r_s_pub) <- expectJust "responder keypair"
    (BOLT8.keypair responder_s_priv)
  rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
  (msg1, i_hs) <- expectRight "act1"
    (BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv)
  (msg2, _) <- expectRight "act2"
    (BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1)
  bad_msg2 <- flip_byte 40 msg2
  case BOLT8.act3 i_hs bad_msg2 of
    Left BOLT8.InvalidMAC -> pure ()
    Left err -> assertFailure $ "expected InvalidMAC, got: " ++ show err
    Right _ -> assertFailure "expected rejection, got success"

test_finalize_invalid_mac :: Assertion
test_finalize_invalid_mac = do
  (i_s_sec, i_s_pub) <- expectJust "initiator keypair"
    (BOLT8.keypair initiator_s_priv)
  (r_s_sec, r_s_pub) <- expectJust "responder keypair"
    (BOLT8.keypair responder_s_priv)
  rs <- expectJust "responder pub" (BOLT8.parse_pub responder_s_pub)
  (msg1, i_hs) <- expectRight "act1"
    (BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv)
  (msg2, r_hs) <- expectRight "act2"
    (BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1)
  (msg3, _) <- expectRight "act3" (BOLT8.act3 i_hs msg2)
  bad_msg3 <- flip_byte 20 msg3
  case BOLT8.finalize r_hs bad_msg3 of
    Left BOLT8.InvalidMAC -> pure ()
    Left err -> assertFailure $ "expected InvalidMAC, got: " ++ show err
    Right _ -> assertFailure "expected rejection, got success"

test_decrypt_short_packet :: Assertion
test_decrypt_short_packet = do
  (_, r_sess) <- handshake_sessions
  let short_packet = BS.replicate 17 0x00
  case BOLT8.decrypt r_sess short_packet of
    Left BOLT8.InvalidLength -> pure ()
    Left err -> assertFailure $ "expected InvalidLength, got: " ++ show err
    Right _ -> assertFailure "expected rejection, got success"

-- | Flip every bit of the byte at index @i@ (0-based) in a bytestring.
--   Fails the test if @i@ is out of bounds. Total: no partial indexing.
flip_byte :: Int -> BS.ByteString -> IO BS.ByteString
flip_byte i bs = case BS.uncons post of
    Just (b, rest) | i >= 0 -> pure (pre <> BS.cons (b `xor` 0xff) rest)
    _ -> assertFailure "flip_byte: index out of bounds" >> pure bs
  where
    (pre, post) = BS.splitAt i bs

-- utilities -----------------------------------------------------------------

-- | Hex-decode a test-vector literal.
--
--   This is partial: it calls 'error' on malformed input. It is only applied
--   to fixed, known-good literals. Keeping it pure lets the vectors stay plain
--   top-level values instead of IO actions.
hex :: BS.ByteString -> BS.ByteString
hex bs = case B16.decode bs of
  Nothing -> error "hex: invalid test vector literal"
  Just r  -> r

-- property tests --------------------------------------------------------------

property_tests :: TestTree
property_tests = testGroup "Properties" [
    testProperty "handshake round-trip" prop_handshake_roundtrip
  , testProperty "encrypt/decrypt round-trip" prop_encrypt_decrypt_roundtrip
  , testProperty "decrypt_frame consumes one frame" prop_frame_consumes_one
  , testProperty "decrypt_frame_partial NeedMore on short"
      prop_partial_needmore_short
  ]

-- generators ------------------------------------------------------------------

-- | Generate 32 bytes of entropy that 'BOLT8.keypair' accepts, retrying
--   until a valid keypair results.
genValidEntropy :: Gen BS.ByteString
genValidEntropy = do
  bytes <- BS.pack <$> vectorOf 32 (choose (0, 255))
  case BOLT8.keypair bytes of
    Just _  -> pure bytes
    Nothing -> genValidEntropy

-- | Generate a payload of 0..256 random bytes.
genPayload :: Gen BS.ByteString
genPayload = do
  len <- choose (0, 256)
  BS.pack <$> vectorOf len (choose (0, 255))

-- | Perform a full handshake with given static key entropy.
-- Uses fixed ephemeral keys for determinism.
doHandshake
  :: BS.ByteString
  -> BS.ByteString
  -> Maybe (BOLT8.Session, BOLT8.Session)
doHandshake i_entropy r_entropy = do
  (i_s_sec, i_s_pub) <- BOLT8.keypair i_entropy
  (r_s_sec, r_s_pub) <- BOLT8.keypair r_entropy
  let i_e = BS.replicate 32 0x12
      r_e = BS.replicate 32 0x22
  (msg1, i_hs) <- eitherToMaybe $ BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e
  (msg2, r_hs) <- eitherToMaybe $ BOLT8.act2 r_s_sec r_s_pub r_e msg1
  (msg3, i_res) <- eitherToMaybe $ BOLT8.act3 i_hs msg2
  r_res <- eitherToMaybe $ BOLT8.finalize r_hs msg3
  -- a handshake only counts as successful if each side has learned the
  -- other's static public key
  if BOLT8.remote_static i_res == r_s_pub
       && BOLT8.remote_static r_res == i_s_pub
    then pure (BOLT8.session i_res, BOLT8.session r_res)
    else Nothing

-- | Discard the error of an 'Either', keeping only a 'Right' value.
eitherToMaybe :: Either e a -> Maybe a
eitherToMaybe = either (const Nothing) Just

-- properties ------------------------------------------------------------------

-- | Handshake succeeds for valid keys and sessions are consistent.
--   Consistency means a message encrypted by either side decrypts to the
--   same payload on the other side.
prop_handshake_roundtrip :: Property
prop_handshake_roundtrip = forAll genValidEntropy $ \i_ent ->
  forAll genValidEntropy $ \r_ent ->
    case doHandshake i_ent r_ent of
      Nothing -> False
      Just (i_sess, r_sess) ->
        delivers i_sess r_sess && delivers r_sess i_sess
  where
    payload :: BS.ByteString
    payload = "handshake round-trip"
    -- True iff a message sealed with @tx@ opens with @rx@ to the payload
    delivers tx rx = case BOLT8.encrypt tx payload of
      Left _ -> False
      Right (ct, _) -> case BOLT8.decrypt rx ct of
        Left _ -> False
        Right (pt, _) -> pt == payload

-- | Encrypt then decrypt yields original payload.
prop_encrypt_decrypt_roundtrip :: Property
prop_encrypt_decrypt_roundtrip = forAll genPayload $ \payload ->
  case doHandshake initiator_s_priv responder_s_priv of
    Nothing -> False
    Just (i_sess, r_sess) ->
      case BOLT8.encrypt i_sess payload of
        Left _ -> False
        Right (ct, _) ->
          case BOLT8.decrypt r_sess ct of
            Left _ -> False
            Right (pt, _) -> pt == payload

-- | decrypt_frame consumes exactly one frame and returns remainder.
prop_frame_consumes_one :: Property
prop_frame_consumes_one = forAll genPayload $ \p1 ->
  forAll genPayload $ \p2 -> splitsCleanly p1 p2 == Just True
  where
    -- any Left along the way makes the whole pipeline Nothing, i.e. failure
    hush :: Either e a -> Maybe a
    hush = either (const Nothing) Just

    -- Seal two payloads back to back, then peel them off one frame at a time.
    -- The second frame is only decoded once the first plaintext matches.
    splitsCleanly first second = do
      (i_sess, r_sess) <- doHandshake initiator_s_priv responder_s_priv
      (ct1, i_sess') <- hush (BOLT8.encrypt i_sess first)
      (ct2, _) <- hush (BOLT8.encrypt i_sess' second)
      (pt1, rest, r_sess') <- hush (BOLT8.decrypt_frame r_sess (ct1 <> ct2))
      if pt1 /= first
        then pure False
        else do
          (pt2, leftover, _) <- hush (BOLT8.decrypt_frame r_sess' rest)
          pure (pt2 == second && BS.null leftover)

-- | decrypt_frame_partial returns NeedMore when buffer < 18 bytes.
--   The reported shortfall must be exactly the missing header bytes.
prop_partial_needmore_short :: Property
prop_partial_needmore_short = forAll (choose (0, 17)) $ \len ->
  maybe False (reportsShortfall len . snd)
    (doHandshake initiator_s_priv responder_s_priv)
  where
    reportsShortfall len r_sess =
      case BOLT8.decrypt_frame_partial r_sess (BS.replicate len 0x00) of
        BOLT8.NeedMore n -> n == 18 - len
        _ -> False