secp256k1

Pure Haskell cryptographic primitives on the secp256k1 elliptic curve.
git clone git://git.ppad.tech/secp256k1.git
Log | Files | Refs | LICENSE

commit a162174c0e9bcf74acda20130e549f19e123b13e
parent 20ce103a3b8d41c1049732b4d16ca799def35f03
Author: Jared Tobin <jared@jtobin.io>
Date:   Mon, 14 Oct 2024 19:35:09 +0400

test: basic schnorr stuff

Diffstat:
Mlib/Crypto/Curve/Secp256k1.hs | 89+++++++++++++++++++++++++++++++++++++++++--------------------------------------
Mtest/BIP340.hs | 68+++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------
2 files changed, 101 insertions(+), 56 deletions(-)

diff --git a/lib/Crypto/Curve/Secp256k1.hs b/lib/Crypto/Curve/Secp256k1.hs @@ -431,34 +431,38 @@ mul_safe p n -- parsing -------------------------------------------------------------------- --- | Parse hex-encoded compressed or uncompressed point. +-- | Parse hex-encoded compressed or uncompressed point, or BIP0340 +-- public key. parse_point :: BS.ByteString -> Maybe Projective parse_point (B16.decode -> ebs) = case ebs of Left _ -> Nothing - Right bs -> case BS.uncons bs of - Nothing -> Nothing - Just (fi -> h, t) -> - let (roll -> x, etc) = BS.splitAt (fi _CURVE_Q_BYTES) t - len = BS.length bs - in if len == 33 && (h == 0x02 || h == 0x03) -- compressed - then if not (fe x) - then Nothing - else do - y <- modsqrt (weierstrass x) - let yodd = I.integerTestBit y 0 - hodd = I.integerTestBit h 0 - pure $ - if hodd /= yodd - then Projective x (modP (negate y)) 1 - else Projective x y 1 - else - if len == 65 && h == 0x04 -- uncompressed - then let (roll -> y, _) = BS.splitAt (fi _CURVE_Q_BYTES) etc - p = Projective x y 1 - in if valid p - then Just p - else Nothing - else Nothing + Right bs + | BS.length bs == 32 -> -- bip0340 public key + fmap projective (lift (roll bs)) + | otherwise -> case BS.uncons bs of + Nothing -> Nothing + Just (fi -> h, t) -> + let (roll -> x, etc) = BS.splitAt (fi _CURVE_Q_BYTES) t + len = BS.length bs + in if len == 33 && (h == 0x02 || h == 0x03) -- compressed + then if not (fe x) + then Nothing + else do + y <- modsqrt (weierstrass x) + let yodd = I.integerTestBit y 0 + hodd = I.integerTestBit h 0 + pure $ + if hodd /= yodd + then Projective x (modP (negate y)) 1 + else Projective x y 1 + else + if len == 65 && h == 0x04 -- uncompressed + then let (roll -> y, _) = BS.splitAt (fi _CURVE_Q_BYTES) etc + p = Projective x y 1 + in if valid p + then Just p + else Nothing + else Nothing -- big-endian bytestring decoding roll :: BS.ByteString -> Integer @@ -682,8 +686,8 @@ sign_schnorr d' m a k' = modQ (roll rand) - in if k' == 0 - then error "ppad-secp256k1 (sign_schnorr): invalid k" -- negligible + in if k' == 0 -- negligible probability + then error "ppad-secp256k1 (sign_schnorr): invalid k" else let Affine x_r y_r = affine (mul _CURVE_G k') k | y_r `rem` 2 == 0 = k' @@ -711,9 +715,9 @@ modexp b e m let t = if B.testBit e 0 then b `mod` m else 1 in t * modexp ((b * b) `mod` m) (B.shiftR e 1) m `mod` m -lift :: Integer -> Affine +lift :: Integer -> Maybe Affine lift x - | not (fe x) = error "ppad-secp256k1 (lift): not field element" + | not (fe x) = Nothing | otherwise = let c = modP (modexp x 3 _CURVE_P + 7) y = modexp c ((_CURVE_P + 1) `div` 4) _CURVE_P @@ -722,24 +726,23 @@ lift x | otherwise = _CURVE_P - y in if c /= modexp y 2 _CURVE_P - then error "ppad-secp256k1 (lift): modular square predicate failed" - else Affine x y_p + then Nothing + else Just $! (Affine x y_p) verify_schnorr :: BS.ByteString -- ^ message -> Affine -- ^ public key -> BS.ByteString -- ^ 64-byte schnorr signature -> Bool -verify_schnorr m (Affine x_p _) sig = - let capP@(Affine x_P _) = lift x_p - (roll -> r, roll -> s) = BS.splitAt 32 sig - in if r >= _CURVE_P - then False - else if s >= _CURVE_Q - then False - else let e = modQ . roll $ hash_tagged "BIP0340/challenge" - (unroll r <> unroll x_P <> m) - Affine x_R y_R = affine $ - add (mul _CURVE_G s) (neg (mul (projective capP) e)) - in not (y_R `rem` 2 /= 0 || x_R /= r) +verify_schnorr m (Affine x_p _) sig = case lift x_p of + Nothing -> False + Just capP@(Affine x_P _) -> + let (roll -> r, roll -> s) = BS.splitAt 32 sig + in if r >= _CURVE_P || s >= _CURVE_Q + then False + else let e = modQ . roll $ hash_tagged "BIP0340/challenge" + (unroll r <> unroll x_P <> m) + Affine x_R y_R = affine $ + add (mul _CURVE_G s) (neg (mul (projective capP) e)) + in not (y_R `rem` 2 /= 0 || x_R /= r) diff --git a/test/BIP340.hs b/test/BIP340.hs @@ -1,11 +1,30 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ViewPatterns #-} -module BIP340 where +module BIP340 ( + cases + , execute + ) where import Control.Applicative -import qualified Data.Attoparsec.ByteString.Char as AT +import Crypto.Curve.Secp256k1 +import qualified Data.Attoparsec.ByteString.Char8 as AT import qualified Data.ByteString as BS +import qualified Data.ByteString.Base16 as B16 +import qualified GHC.Num.Integer as I +import Test.Tasty +import Test.Tasty.HUnit + +-- XX make a test prelude instead of copying/pasting these things everywhere + +fi :: (Integral a, Num b) => a -> b +fi = fromIntegral +{-# INLINE fi #-} + +roll :: BS.ByteString -> Integer +roll = BS.foldl' unstep 0 where + unstep a (fi -> b) = (a `I.integerShiftL` 8) `I.integerOr` b data Case = Case { c_index :: !Int @@ -18,6 +37,28 @@ data Case = Case { , c_comment :: !BS.ByteString } deriving Show +execute :: Case -> TestTree +execute Case {..} = testCase ("bip0340 " <> show c_index) $ + case parse_point c_pk of + Nothing -> assertFailure "no parse" + Just p -> do + let pk = affine p + if c_sk == mempty + then do -- no signature; test verification + let ver = verify_schnorr c_msg pk c_sig + if c_res + then assertBool mempty ver + else assertBool mempty (not ver) + -- XX test pubkey derivation from sk + else do -- signature present; test sig too + let sk = roll c_sk + sig = sign_schnorr sk c_msg c_aux + ver = verify_schnorr c_msg pk sig + assertEqual mempty c_sig sig + if c_res + then assertBool mempty ver + else assertBool mempty (not ver) + header :: AT.Parser () header = do _ <- AT.string "index,secret key,public key,aux_rand,message,signature,verification result,comment" @@ -25,21 +66,22 @@ header = do test_case :: AT.Parser Case test_case = do - c_index <- AT.decimal + c_index <- AT.decimal AT.<?> "index" + _ <- AT.char ',' + c_sk <- fmap B16.decodeLenient (AT.takeWhile1 (/= ',') AT.<?> "sk") + _ <- AT.char ',' + c_pk <- AT.takeWhile1 (/= ',') AT.<?> "pk" + _ <- AT.char ',' + c_aux <- fmap B16.decodeLenient (AT.takeWhile1 (/= ',') AT.<?> "aux") + _ <- AT.char ',' + c_msg <- fmap B16.decodeLenient (AT.takeWhile1 (/= ',') AT.<?> "msg") _ <- AT.char ',' - c_sk <- AT.takeWhile1 (/= ',') + c_sig <- fmap B16.decodeLenient (AT.takeWhile1 (/= ',') AT.<?> "sig") _ <- AT.char ',' - c_pk <- AT.takeWhile1 (/= ',') - _ <- char ',' - c_aux <- AT.takeWhile1 (/= ',') - _ <- char ',' - c_msg <- AT.takeWhile1 (/= ',') - _ <- char ',' - c_sig <- AT.takeWhile1 (/= ',') - _ <- char ',' c_res <- (AT.string "TRUE" *> pure True) <|> (AT.string "FALSE" *> pure False) + AT.<?> "res" _ <- AT.char ',' - c_comment <- AT.takeWhile1 (/= '\n') + c_comment <- AT.takeWhile (/= '\n') AT.<?> "comment" AT.endOfLine pure Case {..}