commit 6a8fe194e7148d95adcc5c803c7bc52f5a0b60af
parent a2df1e41a98a39eea2adbd5253932ed187769416
Author: Jared Tobin <jared@jtobin.io>
Date: Mon, 12 Jan 2026 13:54:20 +0400
lib: general refactor
Diffstat:
| M | bench/Main.hs | | | 58 | +++++++++++++++++++++++----------------------------------- |
| M | bench/Weight.hs | | | 46 | ++++++++++++++++++++-------------------------- |
| M | lib/Lightning/Protocol/BOLT8.hs | | | 134 | ++++++++++++++++++++++++++++++++++++++++---------------------------------------- |
| M | test/Main.hs | | | 90 | ++++++++++++++++++++++++++++++++++--------------------------------------------- |
4 files changed, 149 insertions(+), 179 deletions(-)
diff --git a/bench/Main.hs b/bench/Main.hs
@@ -16,7 +16,7 @@ instance NFData BOLT8.Sec
instance NFData BOLT8.Error
instance NFData BOLT8.Session
instance NFData BOLT8.HandshakeState
-instance NFData BOLT8.HandshakeResult
+instance NFData BOLT8.Handshake
main :: IO ()
main = defaultMain [
@@ -43,40 +43,30 @@ keys = bgroup "keys" [
r_s_pub_bs = BOLT8.serialize_pub r_s_pub
handshake :: Benchmark
-handshake = env setup $ \ ~(i_s_sec, i_s_pub, r_s_sec, r_s_pub, act1, i_hs,
- act2, r_hs, act3) ->
+handshake = env setup $ \ ~(i_s_sec, i_s_pub, r_s_sec, r_s_pub, msg1, i_hs,
+ msg2, r_hs, msg3) ->
bgroup "handshake" [
- bench "initiator_act1" $
- nf (BOLT8.initiator_act1 i_s_sec i_s_pub r_s_pub) i_e_ent
- , bench "responder_act2" $
- nf (BOLT8.responder_act2 r_s_sec r_s_pub r_e_ent) act1
- , bench "initiator_act3" $
- nf (BOLT8.initiator_act3 i_hs) act2
- , bench "responder_finalize" $
- nf (BOLT8.responder_finalize r_hs) act3
+ bench "act1" $ nf (BOLT8.act1 i_s_sec i_s_pub r_s_pub) i_e_ent
+ , bench "act2" $ nf (BOLT8.act2 r_s_sec r_s_pub r_e_ent) msg1
+ , bench "act3" $ nf (BOLT8.act3 i_hs) msg2
+ , bench "finalize" $ nf (BOLT8.finalize r_hs) msg3
]
where
setup = do
let Just (!i_s_sec, !i_s_pub) = BOLT8.keypair i_s_ent
Just (!r_s_sec, !r_s_pub) = BOLT8.keypair r_s_ent
- Right (!act1, !i_hs) =
- BOLT8.initiator_act1 i_s_sec i_s_pub r_s_pub i_e_ent
- Right (!act2, !r_hs) =
- BOLT8.responder_act2 r_s_sec r_s_pub r_e_ent act1
- Right (!act3, _) = BOLT8.initiator_act3 i_hs act2
- pure (i_s_sec, i_s_pub, r_s_sec, r_s_pub, act1, i_hs, act2, r_hs, act3)
+ Right (!msg1, !i_hs) = BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent
+ Right (!msg2, !r_hs) = BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1
+ Right (!msg3, _) = BOLT8.act3 i_hs msg2
+ pure (i_s_sec, i_s_pub, r_s_sec, r_s_pub, msg1, i_hs, msg2, r_hs, msg3)
messages :: Benchmark
messages = env setup $ \ ~(i_sess, r_sess, ct_small, ct_large) ->
bgroup "messages" [
- bench "encrypt (32B)" $
- nf (BOLT8.encrypt_message i_sess) small_msg
- , bench "encrypt (1KB)" $
- nf (BOLT8.encrypt_message i_sess) large_msg
- , bench "decrypt (32B)" $
- nf (BOLT8.decrypt_message r_sess) ct_small
- , bench "decrypt (1KB)" $
- nf (BOLT8.decrypt_message r_sess) ct_large
+ bench "encrypt (32B)" $ nf (BOLT8.encrypt i_sess) small_msg
+ , bench "encrypt (1KB)" $ nf (BOLT8.encrypt i_sess) large_msg
+ , bench "decrypt (32B)" $ nf (BOLT8.decrypt r_sess) ct_small
+ , bench "decrypt (1KB)" $ nf (BOLT8.decrypt r_sess) ct_large
]
where
small_msg = BS.replicate 32 0x00
@@ -84,14 +74,12 @@ messages = env setup $ \ ~(i_sess, r_sess, ct_small, ct_large) ->
setup = do
let Just (!i_s_sec, !i_s_pub) = BOLT8.keypair i_s_ent
Just (!r_s_sec, !r_s_pub) = BOLT8.keypair r_s_ent
- Right (act1, i_hs) =
- BOLT8.initiator_act1 i_s_sec i_s_pub r_s_pub i_e_ent
- Right (act2, r_hs) =
- BOLT8.responder_act2 r_s_sec r_s_pub r_e_ent act1
- Right (act3, i_result) = BOLT8.initiator_act3 i_hs act2
- Right r_result = BOLT8.responder_finalize r_hs act3
- !i_sess = BOLT8.hr_session i_result
- !r_sess = BOLT8.hr_session r_result
- Right (!ct_small, _) = BOLT8.encrypt_message i_sess small_msg
- Right (!ct_large, _) = BOLT8.encrypt_message i_sess large_msg
+ Right (msg1, i_hs) = BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent
+ Right (msg2, r_hs) = BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1
+ Right (msg3, i_result) = BOLT8.act3 i_hs msg2
+ Right r_result = BOLT8.finalize r_hs msg3
+ !i_sess = BOLT8.session i_result
+ !r_sess = BOLT8.session r_result
+ Right (!ct_small, _) = BOLT8.encrypt i_sess small_msg
+ Right (!ct_large, _) = BOLT8.encrypt i_sess large_msg
pure (i_sess, r_sess, ct_small, ct_large)
diff --git a/bench/Weight.hs b/bench/Weight.hs
@@ -16,7 +16,7 @@ instance NFData BOLT8.Sec
instance NFData BOLT8.Error
instance NFData BOLT8.Session
instance NFData BOLT8.HandshakeState
-instance NFData BOLT8.HandshakeResult
+instance NFData BOLT8.Handshake
-- note that 'weigh' doesn't work properly in a repl
main :: IO ()
@@ -45,37 +45,31 @@ handshake :: Weigh ()
handshake =
let Just (!i_s_sec, !i_s_pub) = BOLT8.keypair i_s_ent
Just (!r_s_sec, !r_s_pub) = BOLT8.keypair r_s_ent
- Right (!act1, !i_hs) =
- BOLT8.initiator_act1 i_s_sec i_s_pub r_s_pub i_e_ent
- Right (!act2, !r_hs) =
- BOLT8.responder_act2 r_s_sec r_s_pub r_e_ent act1
- Right (!act3, _) = BOLT8.initiator_act3 i_hs act2
+ Right (!msg1, !i_hs) = BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent
+ Right (!msg2, !r_hs) = BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1
+ Right (!msg3, _) = BOLT8.act3 i_hs msg2
in wgroup "handshake" $ do
- func "initiator_act1"
- (BOLT8.initiator_act1 i_s_sec i_s_pub r_s_pub) i_e_ent
- func "responder_act2"
- (BOLT8.responder_act2 r_s_sec r_s_pub r_e_ent) act1
- func "initiator_act3" (BOLT8.initiator_act3 i_hs) act2
- func "responder_finalize" (BOLT8.responder_finalize r_hs) act3
+ func "act1" (BOLT8.act1 i_s_sec i_s_pub r_s_pub) i_e_ent
+ func "act2" (BOLT8.act2 r_s_sec r_s_pub r_e_ent) msg1
+ func "act3" (BOLT8.act3 i_hs) msg2
+ func "finalize" (BOLT8.finalize r_hs) msg3
messages :: Weigh ()
messages =
let Just (!i_s_sec, !i_s_pub) = BOLT8.keypair i_s_ent
Just (!r_s_sec, !r_s_pub) = BOLT8.keypair r_s_ent
- Right (act1, i_hs) =
- BOLT8.initiator_act1 i_s_sec i_s_pub r_s_pub i_e_ent
- Right (act2, r_hs) =
- BOLT8.responder_act2 r_s_sec r_s_pub r_e_ent act1
- Right (act3, i_result) = BOLT8.initiator_act3 i_hs act2
- Right r_result = BOLT8.responder_finalize r_hs act3
- !i_sess = BOLT8.hr_session i_result
- !r_sess = BOLT8.hr_session r_result
+ Right (msg1, i_hs) = BOLT8.act1 i_s_sec i_s_pub r_s_pub i_e_ent
+ Right (msg2, r_hs) = BOLT8.act2 r_s_sec r_s_pub r_e_ent msg1
+ Right (msg3, i_result) = BOLT8.act3 i_hs msg2
+ Right r_result = BOLT8.finalize r_hs msg3
+ !i_sess = BOLT8.session i_result
+ !r_sess = BOLT8.session r_result
!small_msg = BS.replicate 32 0x00
!large_msg = BS.replicate 1024 0x00
- Right (!ct_small, _) = BOLT8.encrypt_message i_sess small_msg
- Right (!ct_large, _) = BOLT8.encrypt_message i_sess large_msg
+ Right (!ct_small, _) = BOLT8.encrypt i_sess small_msg
+ Right (!ct_large, _) = BOLT8.encrypt i_sess large_msg
in wgroup "messages" $ do
- func "encrypt (32B)" (BOLT8.encrypt_message i_sess) small_msg
- func "encrypt (1KB)" (BOLT8.encrypt_message i_sess) large_msg
- func "decrypt (32B)" (BOLT8.decrypt_message r_sess) ct_small
- func "decrypt (1KB)" (BOLT8.decrypt_message r_sess) ct_large
+ func "encrypt (32B)" (BOLT8.encrypt i_sess) small_msg
+ func "encrypt (1KB)" (BOLT8.encrypt i_sess) large_msg
+ func "decrypt (32B)" (BOLT8.decrypt r_sess) ct_small
+ func "decrypt (1KB)" (BOLT8.decrypt r_sess) ct_large
diff --git a/lib/Lightning/Protocol/BOLT8.hs b/lib/Lightning/Protocol/BOLT8.hs
@@ -27,19 +27,19 @@ module Lightning.Protocol.BOLT8 (
, serialize_pub
-- * Handshake (initiator)
- , initiator_act1
- , initiator_act3
+ , act1
+ , act3
-- * Handshake (responder)
- , responder_act2
- , responder_finalize
+ , act2
+ , finalize
-- * Session
, Session
, HandshakeState
- , HandshakeResult(..)
- , encrypt_message
- , decrypt_message
+ , Handshake(..)
+ , encrypt
+ , decrypt
-- * Errors
, Error(..)
@@ -93,9 +93,9 @@ data Session = Session {
deriving Generic
-- | Result of a successful handshake.
-data HandshakeResult = HandshakeResult {
- hr_session :: !Session -- ^ session state
- , hr_remote_pk :: !Pub -- ^ authenticated remote static pubkey
+data Handshake = Handshake {
+ session :: !Session -- ^ session state
+ , remote_static :: !Pub -- ^ authenticated remote static pubkey
}
deriving Generic
@@ -288,15 +288,15 @@ init_handshake s_sec s_pub e_sec e_pub m_rs is_initiator =
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
-- >>> let eph_ent = BS.replicate 32 0x12
--- >>> case initiator_act1 i_sec i_pub r_pub eph_ent of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
+-- >>> case act1 i_sec i_pub r_pub eph_ent of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
-- 50
-initiator_act1
+act1
:: Sec -- ^ local static secret
-> Pub -- ^ local static public
-> Pub -- ^ remote static public (responder's)
-> BS.ByteString -- ^ 32 bytes entropy for ephemeral
-> Either Error (BS.ByteString, HandshakeState)
-initiator_act1 s_sec s_pub rs ent = do
+act1 s_sec s_pub rs ent = do
(e_sec, e_pub) <- note InvalidKey (keypair ent)
let !hs0 = init_handshake s_sec s_pub e_sec e_pub (Just rs) True
!e_pub_bytes = serialize_pub e_pub
@@ -322,20 +322,20 @@ initiator_act1 s_sec s_pub rs ent = do
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--- >>> let Right (act1, _) = initiator_act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--- >>> case responder_act2 r_sec r_pub (BS.replicate 32 0x22) act1 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
+-- >>> let Right (msg1, _) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
+-- >>> case act2 r_sec r_pub (BS.replicate 32 0x22) msg1 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
-- 50
-responder_act2
+act2
:: Sec -- ^ local static secret
-> Pub -- ^ local static public
-> BS.ByteString -- ^ 32 bytes entropy for ephemeral
-> BS.ByteString -- ^ Act 1 message (50 bytes)
-> Either Error (BS.ByteString, HandshakeState)
-responder_act2 s_sec s_pub ent act1 = do
- require (BS.length act1 == 50) InvalidLength
- let !version = BS.index act1 0
- !re_bytes = BS.take 33 (BS.drop 1 act1)
- !c = BS.drop 34 act1
+act2 s_sec s_pub ent msg1 = do
+ require (BS.length msg1 == 50) InvalidLength
+ let !version = BS.index msg1 0
+ !re_bytes = BS.take 33 (BS.drop 1 msg1)
+ !c = BS.drop 34 msg1
require (version == 0x00) InvalidVersion
re <- note InvalidPub (parse_pub re_bytes)
(e_sec, e_pub) <- note InvalidKey (keypair ent)
@@ -363,23 +363,23 @@ responder_act2 s_sec s_pub ent act1 = do
-- | Initiator: process Act 2 and generate Act 3 (66 bytes), completing
-- the handshake.
--
--- Returns the 66-byte Act 3 message and the session result.
+-- Returns the 66-byte Act 3 message and the handshake result.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--- >>> let Right (act1, i_hs) = initiator_act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--- >>> let Right (act2, _) = responder_act2 r_sec r_pub (BS.replicate 32 0x22) act1
--- >>> case initiator_act3 i_hs act2 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
+-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
+-- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
+-- >>> case act3 i_hs msg2 of { Right (msg, _) -> BS.length msg; Left _ -> 0 }
-- 66
-initiator_act3
+act3
:: HandshakeState -- ^ state after Act 1
-> BS.ByteString -- ^ Act 2 message (50 bytes)
- -> Either Error (BS.ByteString, HandshakeResult)
-initiator_act3 hs act2 = do
- require (BS.length act2 == 50) InvalidLength
- let !version = BS.index act2 0
- !re_bytes = BS.take 33 (BS.drop 1 act2)
- !c = BS.drop 34 act2
+ -> Either Error (BS.ByteString, Handshake)
+act3 hs msg2 = do
+ require (BS.length msg2 == 50) InvalidLength
+ let !version = BS.index msg2 0
+ !re_bytes = BS.take 33 (BS.drop 1 msg2)
+ !c = BS.drop 34 msg2
require (version == 0x00) InvalidVersion
re <- note InvalidPub (parse_pub re_bytes)
let !h1 = mix_hash (hs_h hs) re_bytes
@@ -395,7 +395,7 @@ initiator_act3 hs act2 = do
t <- note InvalidMAC (encrypt_with_ad temp_k3 0 h3 BS.empty)
let !(sk, rk) = mix_key ck2 BS.empty
!msg = BS.singleton 0x00 <> c3 <> t
- !session = Session {
+ !sess = Session {
sess_sk = sk
, sess_sn = 0
, sess_sck = ck2
@@ -404,32 +404,32 @@ initiator_act3 hs act2 = do
, sess_rck = ck2
}
rs <- note InvalidPub (hs_rs hs)
- let !result = HandshakeResult {
- hr_session = session
- , hr_remote_pk = rs
+ let !result = Handshake {
+ session = sess
+ , remote_static = rs
}
pure (msg, result)
-- | Responder: process Act 3 (66 bytes) and complete the handshake.
--
--- Returns the session result with authenticated remote static pubkey.
+-- Returns the handshake result with authenticated remote static pubkey.
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--- >>> let Right (act1, i_hs) = initiator_act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--- >>> let Right (act2, r_hs) = responder_act2 r_sec r_pub (BS.replicate 32 0x22) act1
--- >>> let Right (act3, _) = initiator_act3 i_hs act2
--- >>> case responder_finalize r_hs act3 of { Right _ -> "ok"; Left e -> show e }
+-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
+-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
+-- >>> let Right (msg3, _) = act3 i_hs msg2
+-- >>> case finalize r_hs msg3 of { Right _ -> "ok"; Left e -> show e }
-- "ok"
-responder_finalize
+finalize
:: HandshakeState -- ^ state after Act 2
-> BS.ByteString -- ^ Act 3 message (66 bytes)
- -> Either Error HandshakeResult
-responder_finalize hs act3 = do
- require (BS.length act3 == 66) InvalidLength
- let !version = BS.index act3 0
- !c = BS.take 49 (BS.drop 1 act3)
- !t = BS.drop 50 act3
+ -> Either Error Handshake
+finalize hs msg3 = do
+ require (BS.length msg3 == 66) InvalidLength
+ let !version = BS.index msg3 0
+ !c = BS.take 49 (BS.drop 1 msg3)
+ !t = BS.drop 50 msg3
require (version == 0x00) InvalidVersion
rs_bytes <- note InvalidMAC (decrypt_with_ad (hs_temp_k hs) 1 (hs_h hs) c)
rs <- note InvalidPub (parse_pub rs_bytes)
@@ -439,7 +439,7 @@ responder_finalize hs act3 = do
_ <- note InvalidMAC (decrypt_with_ad temp_k3 0 h1 t)
-- responder swaps order (receives what initiator sends)
let !(rk, sk) = mix_key ck1 BS.empty
- !session = Session {
+ !sess = Session {
sess_sk = sk
, sess_sn = 0
, sess_sck = ck1
@@ -447,9 +447,9 @@ responder_finalize hs act3 = do
, sess_rn = 0
, sess_rck = ck1
}
- !result = HandshakeResult {
- hr_session = session
- , hr_remote_pk = rs
+ !result = Handshake {
+ session = sess
+ , remote_static = rs
}
pure result
@@ -464,17 +464,17 @@ responder_finalize hs act3 = do
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--- >>> let Right (act1, i_hs) = initiator_act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--- >>> let Right (act2, r_hs) = responder_act2 r_sec r_pub (BS.replicate 32 0x22) act1
--- >>> let Right (_, i_result) = initiator_act3 i_hs act2
--- >>> let sess = hr_session i_result
--- >>> case encrypt_message sess "hello" of { Right (ct, _) -> BS.length ct; Left _ -> 0 }
+-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
+-- >>> let Right (msg2, _) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
+-- >>> let Right (_, i_result) = act3 i_hs msg2
+-- >>> let sess = session i_result
+-- >>> case encrypt sess "hello" of { Right (ct, _) -> BS.length ct; Left _ -> 0 }
-- 39
-encrypt_message
+encrypt
:: Session
-> BS.ByteString -- ^ plaintext (max 65535 bytes)
-> Either Error (BS.ByteString, Session)
-encrypt_message sess pt = do
+encrypt sess pt = do
let !len = BS.length pt
require (len <= 65535) InvalidLength
let !len_bytes = encode_be16 (fi len)
@@ -498,18 +498,18 @@ encrypt_message sess pt = do
--
-- >>> let Just (i_sec, i_pub) = keypair (BS.replicate 32 0x11)
-- >>> let Just (r_sec, r_pub) = keypair (BS.replicate 32 0x21)
--- >>> let Right (act1, i_hs) = initiator_act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
--- >>> let Right (act2, r_hs) = responder_act2 r_sec r_pub (BS.replicate 32 0x22) act1
--- >>> let Right (act3, i_result) = initiator_act3 i_hs act2
--- >>> let Right r_result = responder_finalize r_hs act3
--- >>> let Right (ct, _) = encrypt_message (hr_session i_result) "hello"
--- >>> case decrypt_message (hr_session r_result) ct of { Right (pt, _) -> pt; Left _ -> "fail" }
+-- >>> let Right (msg1, i_hs) = act1 i_sec i_pub r_pub (BS.replicate 32 0x12)
+-- >>> let Right (msg2, r_hs) = act2 r_sec r_pub (BS.replicate 32 0x22) msg1
+-- >>> let Right (msg3, i_result) = act3 i_hs msg2
+-- >>> let Right r_result = finalize r_hs msg3
+-- >>> let Right (ct, _) = encrypt (session i_result) "hello"
+-- >>> case decrypt (session r_result) ct of { Right (pt, _) -> pt; Left _ -> "fail" }
-- "hello"
-decrypt_message
+decrypt
:: Session
-> BS.ByteString -- ^ encrypted packet
-> Either Error (BS.ByteString, Session)
-decrypt_message sess packet = do
+decrypt sess packet = do
require (BS.length packet >= 34) InvalidLength
let !lc = BS.take 18 packet
!rest = BS.drop 18 packet
diff --git a/test/Main.hs b/test/Main.hs
@@ -73,7 +73,7 @@ test_act1 :: Assertion
test_act1 = do
let Just (i_s_sec, i_s_pub) = BOLT8.keypair initiator_s_priv
Just rs = BOLT8.parse_pub responder_s_pub
- case BOLT8.initiator_act1 i_s_sec i_s_pub rs initiator_e_priv of
+ case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
Left err -> assertFailure $ "act1 failed: " ++ show err
Right (act1_msg, _hs) -> act1_msg @?= expected_act1
@@ -83,14 +83,12 @@ test_act2 = do
Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
Just rs = BOLT8.parse_pub responder_s_pub
- -- initiator generates act1
- case BOLT8.initiator_act1 i_s_sec i_s_pub rs initiator_e_priv of
+ case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (act1_msg, _) -> do
- -- responder processes act1 and generates act2
- case BOLT8.responder_act2 r_s_sec r_s_pub responder_e_priv act1_msg of
+ Right (msg1, _) -> do
+ case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (act2_msg, _) -> act2_msg @?= expected_act2
+ Right (msg2, _) -> msg2 @?= expected_act2
test_act3 :: Assertion
test_act3 = do
@@ -98,18 +96,15 @@ test_act3 = do
Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
Just rs = BOLT8.parse_pub responder_s_pub
- -- initiator generates act1
- case BOLT8.initiator_act1 i_s_sec i_s_pub rs initiator_e_priv of
+ case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (act1_msg, i_hs) -> do
- -- responder processes act1 and generates act2
- case BOLT8.responder_act2 r_s_sec r_s_pub responder_e_priv act1_msg of
+ Right (msg1, i_hs) -> do
+ case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (act2_msg, _) -> do
- -- initiator processes act2 and generates act3
- case BOLT8.initiator_act3 i_hs act2_msg of
+ Right (msg2, _) -> do
+ case BOLT8.act3 i_hs msg2 of
Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (act3_msg, _) -> act3_msg @?= expected_act3
+ Right (msg3, _) -> msg3 @?= expected_act3
test_full_handshake :: Assertion
test_full_handshake = do
@@ -117,25 +112,20 @@ test_full_handshake = do
Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
Just rs = BOLT8.parse_pub responder_s_pub
- -- Act 1: initiator generates
- case BOLT8.initiator_act1 i_s_sec i_s_pub rs initiator_e_priv of
+ case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (act1_msg, i_hs) -> do
- -- Act 2: responder processes act1, generates act2
- case BOLT8.responder_act2 r_s_sec r_s_pub responder_e_priv act1_msg of
+ Right (msg1, i_hs) -> do
+ case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (act2_msg, r_hs) -> do
- -- Act 3: initiator processes act2, generates act3
- case BOLT8.initiator_act3 i_hs act2_msg of
+ Right (msg2, r_hs) -> do
+ case BOLT8.act3 i_hs msg2 of
Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (act3_msg, i_result) -> do
- -- Responder finalizes
- case BOLT8.responder_finalize r_hs act3_msg of
+ Right (msg3, i_result) -> do
+ case BOLT8.finalize r_hs msg3 of
Left err -> assertFailure $ "finalize failed: " ++ show err
Right r_result -> do
- -- Verify remote pubkeys match
- BOLT8.hr_remote_pk i_result @?= r_s_pub
- BOLT8.hr_remote_pk r_result @?= i_s_pub
+ BOLT8.remote_static i_result @?= r_s_pub
+ BOLT8.remote_static r_result @?= i_s_pub
-- message encryption tests --------------------------------------------------
@@ -192,25 +182,25 @@ get_initiator_session = do
Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
Just rs = BOLT8.parse_pub responder_s_pub
- case BOLT8.initiator_act1 i_s_sec i_s_pub rs initiator_e_priv of
+ case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
Left err -> fail $ "act1 failed: " ++ show err
- Right (act1_msg, i_hs) ->
- case BOLT8.responder_act2 r_s_sec r_s_pub responder_e_priv act1_msg of
+ Right (msg1, i_hs) ->
+ case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
Left err -> fail $ "act2 failed: " ++ show err
- Right (act2_msg, _) ->
- case BOLT8.initiator_act3 i_hs act2_msg of
+ Right (msg2, _) ->
+ case BOLT8.act3 i_hs msg2 of
Left err -> fail $ "act3 failed: " ++ show err
- Right (_, result) -> pure (BOLT8.hr_session result)
+ Right (_, result) -> pure (BOLT8.session result)
-- encrypt N messages, return Nth ciphertext
encrypt_n :: Int -> BOLT8.Session -> IO BS.ByteString
encrypt_n n sess0 = go 0 sess0
where
go i sess
- | i == n = case BOLT8.encrypt_message sess hello of
+ | i == n = case BOLT8.encrypt sess hello of
Left err -> fail $ "encrypt failed at " ++ show i ++ ": " ++ show err
Right (ct, _) -> pure ct
- | otherwise = case BOLT8.encrypt_message sess hello of
+ | otherwise = case BOLT8.encrypt sess hello of
Left err -> fail $ "encrypt failed at " ++ show i ++ ": " ++ show err
Right (_, sess') -> go (i + 1) sess'
@@ -256,26 +246,24 @@ test_decrypt_roundtrip = do
Just (r_s_sec, r_s_pub) = BOLT8.keypair responder_s_priv
Just rs = BOLT8.parse_pub responder_s_pub
- -- Complete handshake
- case BOLT8.initiator_act1 i_s_sec i_s_pub rs initiator_e_priv of
+ case BOLT8.act1 i_s_sec i_s_pub rs initiator_e_priv of
Left err -> assertFailure $ "act1 failed: " ++ show err
- Right (act1_msg, i_hs) ->
- case BOLT8.responder_act2 r_s_sec r_s_pub responder_e_priv act1_msg of
+ Right (msg1, i_hs) ->
+ case BOLT8.act2 r_s_sec r_s_pub responder_e_priv msg1 of
Left err -> assertFailure $ "act2 failed: " ++ show err
- Right (act2_msg, r_hs) ->
- case BOLT8.initiator_act3 i_hs act2_msg of
+ Right (msg2, r_hs) ->
+ case BOLT8.act3 i_hs msg2 of
Left err -> assertFailure $ "act3 failed: " ++ show err
- Right (act3_msg, i_result) ->
- case BOLT8.responder_finalize r_hs act3_msg of
+ Right (msg3, i_result) ->
+ case BOLT8.finalize r_hs msg3 of
Left err -> assertFailure $ "finalize failed: " ++ show err
Right r_result -> do
- let i_sess = BOLT8.hr_session i_result
- r_sess = BOLT8.hr_session r_result
- -- Initiator sends to responder
- case BOLT8.encrypt_message i_sess hello of
+ let i_sess = BOLT8.session i_result
+ r_sess = BOLT8.session r_result
+ case BOLT8.encrypt i_sess hello of
Left err -> assertFailure $ "encrypt failed: " ++ show err
Right (ct, _) ->
- case BOLT8.decrypt_message r_sess ct of
+ case BOLT8.decrypt r_sess ct of
Left err ->
assertFailure $ "decrypt failed: " ++ show err
Right (pt, _) -> pt @?= hello