secp256k1

Pure Haskell cryptographic primitives on the secp256k1 elliptic curve.
git clone git://git.ppad.tech/secp256k1.git
Log | Files | Refs | LICENSE

commit af6f564ef97aa00021355212aad67f96bb91f0ec
parent 287a4e9e649a2fb5215ab10575239c563afc08b5
Author: Jared Tobin <jared@jtobin.io>
Date:   Sun, 24 Mar 2024 18:40:39 +0400

lib: addition improvements

Implements algorithms 8 and 9 in Renes et al, 2015, for mixed-point
addition and doubling respectively. The 'add' function now makes use
of one of these (or the existing projective points addition algorithm)
depending on the arguments passed to it.

Diffstat:
Mbench/Main.hs | 7++++++-
Mlib/Crypto/Secp256k1.hs | 270++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------------------
Mtest/Main.hs | 22+++++++++++++++++++++-
3 files changed, 220 insertions(+), 79 deletions(-)

diff --git a/bench/Main.hs b/bench/Main.hs @@ -10,7 +10,6 @@ import qualified Crypto.Secp256k1 as S instance NFData S.Projective instance NFData S.Affine -instance NFData S.Curve main :: IO () main = defaultMain [ @@ -33,6 +32,12 @@ secp256k1 = bgroup "secp256k1" [ , bench "bar qux" $ nf (S.add bar) qux , bench "baz qux" $ nf (S.add baz) qux ] + , bgroup "double" [ + bench "foo" $ nf S.double foo + , bench "bar" $ nf S.double bar + , bench "baz" $ nf S.double baz + , bench "qux" $ nf S.double qux + ] ] where p = "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798" diff --git a/lib/Crypto/Secp256k1.hs b/lib/Crypto/Secp256k1.hs @@ -15,54 +15,57 @@ import GHC.Generics import GHC.Natural import qualified GHC.Num.Integer as I -data Curve = Curve { - curve_p :: Natural -- ^ field prime - , curve_n :: Natural -- ^ group order - , curve_a :: Integer -- ^ /a/ coefficient, weierstrass form - , curve_b :: Integer -- ^ /b/ coefficient, weierstrass form - , curve_gx :: Integer -- ^ base point x - , curve_gy :: Integer -- ^ base point y - } - deriving stock (Show, Generic) +_B256 :: Integer +_B256 = 2 ^ (256 :: Integer) + +-- secp256k1 field prime +_CURVE_P :: Integer +_CURVE_P = _B256 - 0x1000003d1 + +-- secp256k1 group order +_CURVE_N :: Integer +_CURVE_N = _B256 - 0x14551231950b75fc4402da1732fc9bebf + +-- secp256k1 short weierstrass form, /a/ coefficient +_CURVE_A :: Integer +_CURVE_A = 0 -secp256k1 :: Curve -secp256k1 = Curve p n 0 7 gx gy where - b256 = 2 ^ (256 :: Integer) - p = b256 - 0x1000003d1 - n = b256 - 0x14551231950b75fc4402da1732fc9bebf - gx = 0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798 - gy = 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8 +-- secp256k1 weierstrass form, /b/ coefficient +_CURVE_B :: Integer +_CURVE_B = 7 + +-- secp256k1 base point, x coordinate +_CURVE_GX :: Integer +_CURVE_GX = 0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798 + +-- secp256k1 base point, y coordinate +_CURVE_GY :: Integer +_CURVE_GY = 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8 -- modular division by secp256k1 group order mods :: Integer -> Integer -mods a = I.integerMod a (fromIntegral (curve_p secp256k1)) +mods a = I.integerMod a _CURVE_P -- is field element (i.e., is invertible) fe :: Integer -> Bool -fe n = 0 < n && n < fromIntegral (curve_p secp256k1) +fe n = 0 < n && n < _CURVE_P -- is group element -ge :: Natural -> Bool -ge n = 0 < n && n < (curve_n secp256k1) +ge :: Integer -> Bool +ge n = 0 < n && n < _CURVE_N -- for a, m return x such that ax = 1 mod m -modinv :: Integer -> Natural -> Maybe Natural +modinv :: Integer -> Natural -> Maybe Integer modinv a m = case I.integerRecipMod# a m of - (# n | #) -> Just n + (# n | #) -> Just (fromIntegral n) (# | _ #) -> Nothing --- elliptic curve - --- XX not general weierstrass; only for j-invariant 0 (i.e. a == 0) -weierstrass :: Integer -> Integer -weierstrass x = mods (mods (x * x) * x + curve_b secp256k1) - --- modular square root +-- modular square root (shanks-tonelli) modsqrt :: Integer -> Maybe Integer modsqrt n = runST $ do r <- newSTRef 1 num <- newSTRef n - e <- newSTRef ((p + 1) `div` 4) + e <- newSTRef ((_CURVE_P + 1) `div` 4) loop r num e rr <- readSTRef r pure $ @@ -70,22 +73,19 @@ modsqrt n = runST $ do then Just rr else Nothing where - p = fromIntegral (curve_p secp256k1) loop sr snum se = do e <- readSTRef se when (e > 0) $ do when (I.integerTestBit e 0) $ do num <- readSTRef snum - modifySTRef' sr (\lr -> (lr * num) `rem` p) - modifySTRef' snum (\ln -> (ln * ln) `rem` p) + modifySTRef' sr (\lr -> (lr * num) `rem` _CURVE_P) + modifySTRef' snum (\ln -> (ln * ln) `rem` _CURVE_P) modifySTRef' se (`I.integerShiftR` 1) loop sr snum se --- group bytelength -_GROUP_BYTELENGTH :: Int -_GROUP_BYTELENGTH = 32 - --- curve point +-- prime order j-invariant 0 (i.e. a == 0) +weierstrass :: Integer -> Integer +weierstrass x = mods (mods (x * x) * x + _CURVE_B) data Affine = Affine Integer Integer deriving stock (Show, Generic) @@ -113,21 +113,27 @@ _ZERO :: Projective _ZERO = Projective 0 1 0 _BASE :: Projective -_BASE = Projective (curve_gx secp256k1) (curve_gy secp256k1) 1 +_BASE = Projective _CURVE_GX _CURVE_GY 1 -- negate point neg :: Projective -> Projective neg (Projective x y z) = Projective x (mods (negate y)) z +-- general ec addition +add :: Projective -> Projective -> Projective +add p q@(Projective _ _ z) + | p == q = double p -- algo 9 + | z == 1 = add_mixed p q -- algo 8 + | otherwise = add_proj p q -- algo 7 + -- algo 7, "complete addition formulas for prime order elliptic curves," -- renes et al, 2015 -add :: Projective -> Projective -> Projective -add (Projective x1 y1 z1) (Projective x2 y2 z2) = runST $ do - let b = curve_b secp256k1 +add_proj :: Projective -> Projective -> Projective +add_proj (Projective x1 y1 z1) (Projective x2 y2 z2) = runST $ do x3 <- newSTRef 0 y3 <- newSTRef 0 z3 <- newSTRef 0 - let b3 = mods (b * 3) + let b3 = mods (_CURVE_B * 3) t0 <- newSTRef (mods (x1 * x2)) -- 1 t1 <- newSTRef (mods (y1 * y2)) t2 <- newSTRef (mods (z1 * z2)) @@ -191,9 +197,122 @@ add (Projective x1 y1 z1) (Projective x2 y2 z2) = runST $ do modifySTRef' z3 (\rz3 -> mods (rz3 + r0)) Projective <$> readSTRef x3 <*> readSTRef y3 <*> readSTRef z3 --- double a point +-- algo 8, renes et al, 2015 +add_mixed :: Projective -> Projective -> Projective +add_mixed (Projective x1 y1 z1) (Projective x2 y2 z2) + | z2 /= 1 = error "secp256k1: internal error" + | otherwise = runST $ do + x3 <- newSTRef 0 + y3 <- newSTRef 0 + z3 <- newSTRef 0 + let b3 = mods (_CURVE_B * 3) + t0 <- newSTRef (mods (x1 * x2)) -- 1 + t1 <- newSTRef (mods (y1 * y2)) + t3 <- newSTRef (mods (x2 + y2)) + t4 <- newSTRef (mods (x1 + y1)) -- 4 + readSTRef t4 >>= \r4 -> + modifySTRef' t3 (\r3 -> mods (r3 * r4)) + readSTRef t0 >>= \r0 -> + readSTRef t1 >>= \r1 -> + writeSTRef t4 (mods (r0 + r1)) + readSTRef t4 >>= \r4 -> + modifySTRef' t3 (\r3 -> mods (r3 - r4)) -- 7 + writeSTRef t4 (mods (y2 * z1)) + modifySTRef' t4 (\r4 -> mods (r4 + y1)) + writeSTRef y3 (mods (x2 * z1)) -- 10 + modifySTRef' y3 (\ry3 -> mods (ry3 + x1)) + readSTRef t0 >>= \r0 -> + writeSTRef x3 (mods (r0 + r0)) + readSTRef x3 >>= \rx3 -> + modifySTRef' t0 (\r0 -> mods (rx3 + r0)) -- 13 + t2 <- newSTRef (mods (b3 * z1)) + readSTRef t1 >>= \r1 -> + readSTRef t2 >>= \r2 -> + writeSTRef z3 (mods (r1 + r2)) + readSTRef t2 >>= \r2 -> + modifySTRef' t1 (\r1 -> mods (r1 - r2)) -- 16 + modifySTRef' y3 (\ry3 -> mods (b3 * ry3)) + readSTRef t4 >>= \r4 -> + readSTRef y3 >>= \ry3 -> + writeSTRef x3 (mods (r4 * ry3)) + readSTRef t3 >>= \r3 -> + readSTRef t1 >>= \r1 -> + writeSTRef t2 (mods (r3 * r1)) -- 19 + readSTRef t2 >>= \r2 -> + modifySTRef' x3 (\rx3 -> mods (r2 - rx3)) + readSTRef t0 >>= \r0 -> + modifySTRef' y3 (\ry3 -> mods (ry3 * r0)) + readSTRef z3 >>= \rz3 -> + modifySTRef' t1 (\r1 -> mods (r1 * rz3)) -- 22 + readSTRef t1 >>= \r1 -> + modifySTRef' y3 (\ry3 -> mods (r1 + ry3)) + readSTRef t3 >>= \r3 -> + modifySTRef' t0 (\r0 -> mods (r0 * r3)) + readSTRef t4 >>= \r4 -> + modifySTRef' z3 (\rz3 -> mods (rz3 * r4)) -- 25 + readSTRef t0 >>= \r0 -> + modifySTRef' z3 (\rz3 -> mods (rz3 + r0)) + Projective <$> readSTRef x3 <*> readSTRef y3 <*> readSTRef z3 + +-- algo 9, renes et al, 2015 double :: Projective -> Projective -double p = add p p +double (Projective x y z) = runST $ do + x3 <- newSTRef 0 + y3 <- newSTRef 0 + z3 <- newSTRef 0 + let b3 = mods (_CURVE_B * 3) + t0 <- newSTRef (mods (y * y)) -- 1 + readSTRef t0 >>= \r0 -> + writeSTRef z3 (mods (r0 + r0)) + modifySTRef' z3 (\rz3 -> mods (rz3 + rz3)) + modifySTRef' z3 (\rz3 -> mods (rz3 + rz3)) -- 4 + t1 <- newSTRef (mods (y * z)) + t2 <- newSTRef (mods (z * z)) + modifySTRef t2 (\r2 -> mods (b3 * r2)) -- 7 + readSTRef z3 >>= \rz3 -> + readSTRef t2 >>= \r2 -> + writeSTRef x3 (mods (r2 * rz3)) + readSTRef t0 >>= \r0 -> + readSTRef t2 >>= \r2 -> + writeSTRef y3 (mods (r0 + r2)) + readSTRef t1 >>= \r1 -> + modifySTRef' z3 (\rz3 -> mods (r1 * rz3)) -- 10 + readSTRef t2 >>= \r2 -> + writeSTRef t1 (mods (r2 + r2)) + readSTRef t1 >>= \r1 -> + modifySTRef' t2 (\r2 -> mods (r1 + r2)) + readSTRef t2 >>= \r2 -> + modifySTRef' t0 (\r0 -> mods (r0 - r2)) -- 13 + readSTRef t0 >>= \r0 -> + modifySTRef' y3 (\ry3 -> mods (r0 * ry3)) + readSTRef x3 >>= \rx3 -> + modifySTRef' y3 (\ry3 -> mods (rx3 + ry3)) + writeSTRef t1 (mods (x * y)) -- 16 + readSTRef t0 >>= \r0 -> + readSTRef t1 >>= \r1 -> + writeSTRef x3 (mods (r0 * r1)) + modifySTRef' x3 (\rx3 -> mods (rx3 + rx3)) + Projective <$> readSTRef x3 <*> readSTRef y3 <*> readSTRef z3 + + + +-- mul(n, safe = true) { +-- if (!safe && n === 0n) +-- return I; // in unsafe mode, allow zero +-- if (!ge(n)) +-- err('invalid scalar'); // must be 0 < n < CURVE.n +-- if (this.equals(G)) +-- return wNAF(n).p; // use precomputes for base point +-- let p = I, f = G; // init result point & fake point +-- for (let d = this; n > 0n; d = d.double(), n >>= 1n) { // double-and-add ladder +-- if (n & 1n) +-- p = p.add(d); // if bit is present, add to point +-- else if (safe) +-- f = f.add(d); // if not, add to fake for timing safety +-- } +-- return p; +-- } + -- to affine coordinates affine :: Projective -> Maybe Affine @@ -201,17 +320,11 @@ affine p@(Projective x y z) | p == _ZERO = pure (Affine 0 0) | z == 1 = pure (Affine x y) | otherwise = do - iz <- fmap fromIntegral (modinv z (curve_p secp256k1)) + iz <- modinv z (fromIntegral _CURVE_P) if mods (z * iz) /= 1 then Nothing else pure (Affine (mods (x * iz)) (mods (y * iz))) --- partial affine -affine' :: Projective -> Affine -affine' p = case affine p of - Nothing -> error "bang" - Just x -> x - -- to projective coordinates projective :: Affine -> Projective projective (Affine x y) @@ -227,33 +340,36 @@ valid p = case affine p of | mods (y * y) /= weierstrass x -> False | otherwise -> True --- parse hex-encoded +-- parse hex-encoded point parse :: BS.ByteString -> Maybe Projective parse (B16.decode -> ebs) = case ebs of - Left _ -> Nothing - Right bs -> case BS.uncons bs of - Nothing -> Nothing - Just (h, t) -> - let (roll -> x, etc) = BS.splitAt _GROUP_BYTELENGTH t - len = BS.length bs - in if len == 33 && (h == 0x02 || h == 0x03) -- compressed - then if not (fe x) - then Nothing - else do - y <- modsqrt (weierstrass x) - let yodd = I.integerTestBit y 0 - hodd = I.integerTestBit (fromIntegral h) 0 - pure $ - if hodd /= yodd - then Projective x (mods (negate y)) 1 - else Projective x y 1 - else if len == 65 && h == 0x04 -- uncompressed - then let (roll -> y, _) = BS.splitAt _GROUP_BYTELENGTH etc - p = Projective x y 1 - in if valid p - then Just p - else Nothing - else Nothing + Left _ -> Nothing + Right bs -> case BS.uncons bs of + Nothing -> Nothing + Just (h, t) -> + let (roll -> x, etc) = BS.splitAt _GROUP_BYTELENGTH t + len = BS.length bs + in if len == 33 && (h == 0x02 || h == 0x03) -- compressed + then if not (fe x) + then Nothing + else do + y <- modsqrt (weierstrass x) + let yodd = I.integerTestBit y 0 + hodd = I.integerTestBit (fromIntegral h) 0 + pure $ + if hodd /= yodd + then Projective x (mods (negate y)) 1 + else Projective x y 1 + else if len == 65 && h == 0x04 -- uncompressed + then let (roll -> y, _) = BS.splitAt _GROUP_BYTELENGTH etc + p = Projective x y 1 + in if valid p + then Just p + else Nothing + else Nothing + where + _GROUP_BYTELENGTH :: Int + _GROUP_BYTELENGTH = 32 -- big-endian bytestring decoding roll :: BS.ByteString -> Integer diff --git a/test/Main.hs b/test/Main.hs @@ -14,6 +14,7 @@ units :: TestTree units = testGroup "unit tests" [ parse_tests , add_tests + , dub_tests ] parse_tests :: TestTree @@ -44,7 +45,7 @@ parse_test_r = testCase (render r_hex) $ case parse r_hex of -- XX also make less dumb add_tests :: TestTree -add_tests = testGroup "ec addition, algo 1" [ +add_tests = testGroup "ec addition" [ add_test_pq , add_test_pr , add_test_qr @@ -62,6 +63,25 @@ add_test_qr :: TestTree add_test_qr = testCase "q + r" $ assertEqual mempty qr_pro (q_pro `add` r_pro) +dub_tests :: TestTree +dub_tests = testGroup "ec doubling" [ + dub_test_p + , dub_test_q + , dub_test_r + ] + +dub_test_p :: TestTree +dub_test_p = testCase "2p" $ + assertEqual mempty (p_pro `add` p_pro) (double p_pro) + +dub_test_q :: TestTree +dub_test_q = testCase "2q" $ + assertEqual mempty (q_pro `add` q_pro) (double q_pro) + +dub_test_r :: TestTree +dub_test_r = testCase "2r" $ + assertEqual mempty (r_pro `add` r_pro) (double r_pro) + p_hex :: BS.ByteString p_hex = "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"