commit a162174c0e9bcf74acda20130e549f19e123b13e
parent 20ce103a3b8d41c1049732b4d16ca799def35f03
Author: Jared Tobin <jared@jtobin.io>
Date: Mon, 14 Oct 2024 19:35:09 +0400
test: basic schnorr stuff
Diffstat:
2 files changed, 101 insertions(+), 56 deletions(-)
diff --git a/lib/Crypto/Curve/Secp256k1.hs b/lib/Crypto/Curve/Secp256k1.hs
@@ -431,34 +431,38 @@ mul_safe p n
-- parsing --------------------------------------------------------------------
--- | Parse hex-encoded compressed or uncompressed point.
+-- | Parse hex-encoded compressed or uncompressed point, or BIP0340
+-- public key.
parse_point :: BS.ByteString -> Maybe Projective
parse_point (B16.decode -> ebs) = case ebs of
Left _ -> Nothing
- Right bs -> case BS.uncons bs of
- Nothing -> Nothing
- Just (fi -> h, t) ->
- let (roll -> x, etc) = BS.splitAt (fi _CURVE_Q_BYTES) t
- len = BS.length bs
- in if len == 33 && (h == 0x02 || h == 0x03) -- compressed
- then if not (fe x)
- then Nothing
- else do
- y <- modsqrt (weierstrass x)
- let yodd = I.integerTestBit y 0
- hodd = I.integerTestBit h 0
- pure $
- if hodd /= yodd
- then Projective x (modP (negate y)) 1
- else Projective x y 1
- else
- if len == 65 && h == 0x04 -- uncompressed
- then let (roll -> y, _) = BS.splitAt (fi _CURVE_Q_BYTES) etc
- p = Projective x y 1
- in if valid p
- then Just p
- else Nothing
- else Nothing
+ Right bs
+ | BS.length bs == 32 -> -- bip0340 public key
+ fmap projective (lift (roll bs))
+ | otherwise -> case BS.uncons bs of
+ Nothing -> Nothing
+ Just (fi -> h, t) ->
+ let (roll -> x, etc) = BS.splitAt (fi _CURVE_Q_BYTES) t
+ len = BS.length bs
+ in if len == 33 && (h == 0x02 || h == 0x03) -- compressed
+ then if not (fe x)
+ then Nothing
+ else do
+ y <- modsqrt (weierstrass x)
+ let yodd = I.integerTestBit y 0
+ hodd = I.integerTestBit h 0
+ pure $
+ if hodd /= yodd
+ then Projective x (modP (negate y)) 1
+ else Projective x y 1
+ else
+ if len == 65 && h == 0x04 -- uncompressed
+ then let (roll -> y, _) = BS.splitAt (fi _CURVE_Q_BYTES) etc
+ p = Projective x y 1
+ in if valid p
+ then Just p
+ else Nothing
+ else Nothing
-- big-endian bytestring decoding
roll :: BS.ByteString -> Integer
@@ -682,8 +686,8 @@ sign_schnorr d' m a
k' = modQ (roll rand)
- in if k' == 0
- then error "ppad-secp256k1 (sign_schnorr): invalid k" -- negligible
+ in if k' == 0 -- negligible probability
+ then error "ppad-secp256k1 (sign_schnorr): invalid k"
else
let Affine x_r y_r = affine (mul _CURVE_G k')
k | y_r `rem` 2 == 0 = k'
@@ -711,9 +715,9 @@ modexp b e m
let t = if B.testBit e 0 then b `mod` m else 1
in t * modexp ((b * b) `mod` m) (B.shiftR e 1) m `mod` m
-lift :: Integer -> Affine
+lift :: Integer -> Maybe Affine
lift x
- | not (fe x) = error "ppad-secp256k1 (lift): not field element"
+ | not (fe x) = Nothing
| otherwise =
let c = modP (modexp x 3 _CURVE_P + 7)
y = modexp c ((_CURVE_P + 1) `div` 4) _CURVE_P
@@ -722,24 +726,23 @@ lift x
| otherwise = _CURVE_P - y
in if c /= modexp y 2 _CURVE_P
- then error "ppad-secp256k1 (lift): modular square predicate failed"
- else Affine x y_p
+ then Nothing
+ else Just $! (Affine x y_p)
verify_schnorr
:: BS.ByteString -- ^ message
-> Affine -- ^ public key
-> BS.ByteString -- ^ 64-byte schnorr signature
-> Bool
-verify_schnorr m (Affine x_p _) sig =
- let capP@(Affine x_P _) = lift x_p
- (roll -> r, roll -> s) = BS.splitAt 32 sig
- in if r >= _CURVE_P
- then False
- else if s >= _CURVE_Q
- then False
- else let e = modQ . roll $ hash_tagged "BIP0340/challenge"
- (unroll r <> unroll x_P <> m)
- Affine x_R y_R = affine $
- add (mul _CURVE_G s) (neg (mul (projective capP) e))
- in not (y_R `rem` 2 /= 0 || x_R /= r)
+verify_schnorr m (Affine x_p _) sig = case lift x_p of
+ Nothing -> False
+ Just capP@(Affine x_P _) ->
+ let (roll -> r, roll -> s) = BS.splitAt 32 sig
+ in if r >= _CURVE_P || s >= _CURVE_Q
+ then False
+ else let e = modQ . roll $ hash_tagged "BIP0340/challenge"
+ (unroll r <> unroll x_P <> m)
+ Affine x_R y_R = affine $
+ add (mul _CURVE_G s) (neg (mul (projective capP) e))
+ in not (y_R `rem` 2 /= 0 || x_R /= r)
diff --git a/test/BIP340.hs b/test/BIP340.hs
@@ -1,11 +1,30 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE ViewPatterns #-}
-module BIP340 where
+module BIP340 (
+ cases
+ , execute
+ ) where
import Control.Applicative
-import qualified Data.Attoparsec.ByteString.Char as AT
+import Crypto.Curve.Secp256k1
+import qualified Data.Attoparsec.ByteString.Char8 as AT
import qualified Data.ByteString as BS
+import qualified Data.ByteString.Base16 as B16
+import qualified GHC.Num.Integer as I
+import Test.Tasty
+import Test.Tasty.HUnit
+
+-- XX make a test prelude instead of copying/pasting these things everywhere
+
+fi :: (Integral a, Num b) => a -> b
+fi = fromIntegral
+{-# INLINE fi #-}
+
+roll :: BS.ByteString -> Integer
+roll = BS.foldl' unstep 0 where
+ unstep a (fi -> b) = (a `I.integerShiftL` 8) `I.integerOr` b
data Case = Case {
c_index :: !Int
@@ -18,6 +37,28 @@ data Case = Case {
, c_comment :: !BS.ByteString
} deriving Show
+execute :: Case -> TestTree
+execute Case {..} = testCase ("bip0340 " <> show c_index) $
+ case parse_point c_pk of
+ Nothing -> assertFailure "no parse"
+ Just p -> do
+ let pk = affine p
+ if c_sk == mempty
+ then do -- no signature; test verification
+ let ver = verify_schnorr c_msg pk c_sig
+ if c_res
+ then assertBool mempty ver
+ else assertBool mempty (not ver)
+ -- XX test pubkey derivation from sk
+ else do -- signature present; test sig too
+ let sk = roll c_sk
+ sig = sign_schnorr sk c_msg c_aux
+ ver = verify_schnorr c_msg pk sig
+ assertEqual mempty c_sig sig
+ if c_res
+ then assertBool mempty ver
+ else assertBool mempty (not ver)
+
header :: AT.Parser ()
header = do
_ <- AT.string "index,secret key,public key,aux_rand,message,signature,verification result,comment"
@@ -25,21 +66,22 @@ header = do
test_case :: AT.Parser Case
test_case = do
- c_index <- AT.decimal
+ c_index <- AT.decimal AT.<?> "index"
+ _ <- AT.char ','
+ c_sk <- fmap B16.decodeLenient (AT.takeWhile1 (/= ',') AT.<?> "sk")
+ _ <- AT.char ','
+ c_pk <- AT.takeWhile1 (/= ',') AT.<?> "pk"
+ _ <- AT.char ','
+ c_aux <- fmap B16.decodeLenient (AT.takeWhile1 (/= ',') AT.<?> "aux")
+ _ <- AT.char ','
+ c_msg <- fmap B16.decodeLenient (AT.takeWhile1 (/= ',') AT.<?> "msg")
_ <- AT.char ','
- c_sk <- AT.takeWhile1 (/= ',')
+ c_sig <- fmap B16.decodeLenient (AT.takeWhile1 (/= ',') AT.<?> "sig")
_ <- AT.char ','
- c_pk <- AT.takeWhile1 (/= ',')
- _ <- char ','
- c_aux <- AT.takeWhile1 (/= ',')
- _ <- char ','
- c_msg <- AT.takeWhile1 (/= ',')
- _ <- char ','
- c_sig <- AT.takeWhile1 (/= ',')
- _ <- char ','
c_res <- (AT.string "TRUE" *> pure True) <|> (AT.string "FALSE" *> pure False)
+ AT.<?> "res"
_ <- AT.char ','
- c_comment <- AT.takeWhile1 (/= '\n')
+ c_comment <- AT.takeWhile (/= '\n') AT.<?> "comment"
AT.endOfLine
pure Case {..}