commit fc2c1fcf42bb4bc5e91213f2d035ae09e5386d25
parent 0514ab10d8df8dba93885db6a47b0d659e296238
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 24 Mar 2024 15:16:40 +0400
lib: monomorphise, use ghc-bignum utils
Monomorphising on 'Integer' and using e.g. integerMod and
integerRecipMod# from ghc-bignum's GHC.Num.Integer module yields
something like a 13x speedup across-the-board for every elliptic curve
addition algo.
Diffstat:
4 files changed, 262 insertions(+), 147 deletions(-)
diff --git a/bench/Main.hs b/bench/Main.hs
@@ -3,40 +3,46 @@
module Main where
+import qualified Data.ByteString as BS
import Control.DeepSeq
import Criterion.Main
import qualified Crypto.Secp256k1 as S
-instance NFData a => NFData (S.Projective a)
-instance NFData a => NFData (S.Affine a)
-instance NFData a => NFData (S.Triple a)
-instance NFData a => NFData (S.Curve a)
+instance NFData S.Projective
+instance NFData S.Affine
+instance NFData S.Curve
add :: Benchmark
add = bgroup "secp256k1" [
- bgroup "add" [
- bench "foo bar" $ nf (S.add foo) bar
- , bench "foo baz" $ nf (S.add foo) baz
- , bench "foo qux" $ nf (S.add foo) qux
- , bench "bar baz" $ nf (S.add bar) baz
- , bench "bar qux" $ nf (S.add bar) qux
- , bench "baz qux" $ nf (S.add baz) qux
+ bgroup "parse" [
+ bench "foo" $ nf bparse p
+ , bench "bar" $ nf bparse q
+ , bench "baz" $ nf bparse r
+ , bench "qux" $ nf bparse s
+ ]
+ , bgroup "add" [
+ bench "foo bar" $ nf (S.add foo) bar
+ , bench "foo baz" $ nf (S.add foo) baz
+ , bench "foo qux" $ nf (S.add foo) qux
+ , bench "bar baz" $ nf (S.add bar) baz
+ , bench "bar qux" $ nf (S.add bar) qux
+ , bench "baz qux" $ nf (S.add baz) qux
]
, bgroup "add'" [
- bench "foo bar" $ nf (S.add' foo) bar
- , bench "foo baz" $ nf (S.add' foo) baz
- , bench "foo qux" $ nf (S.add' foo) qux
- , bench "bar baz" $ nf (S.add' bar) baz
- , bench "bar qux" $ nf (S.add' bar) qux
- , bench "baz qux" $ nf (S.add' baz) qux
+ bench "foo bar" $ nf (S.add' foo) bar
+ , bench "foo baz" $ nf (S.add' foo) baz
+ , bench "foo qux" $ nf (S.add' foo) qux
+ , bench "bar baz" $ nf (S.add' bar) baz
+ , bench "bar qux" $ nf (S.add' bar) qux
+ , bench "baz qux" $ nf (S.add' baz) qux
]
, bgroup "add_pure" [
- bench "foo bar" $ nf (S.add_pure foo) bar
- , bench "foo baz" $ nf (S.add_pure foo) baz
- , bench "foo qux" $ nf (S.add_pure foo) qux
- , bench "bar baz" $ nf (S.add_pure bar) baz
- , bench "bar qux" $ nf (S.add_pure bar) qux
- , bench "baz qux" $ nf (S.add_pure baz) qux
+ bench "foo bar" $ nf (S.add_pure foo) bar
+ , bench "foo baz" $ nf (S.add_pure foo) baz
+ , bench "foo qux" $ nf (S.add_pure foo) qux
+ , bench "bar baz" $ nf (S.add_pure bar) baz
+ , bench "bar qux" $ nf (S.add_pure bar) qux
+ , bench "baz qux" $ nf (S.add_pure baz) qux
]
, bgroup "add_affine" [
bench "foo bar" $ nf (S.add_affine afoo) abar
@@ -53,23 +59,28 @@ add = bgroup "secp256k1" [
r = "03a2113cf152585d96791a42cdd78782757fbfb5c6b2c11b59857eb4f7fda0b0e8"
s = "0306413898a49c93cccf3db6e9078c1b6a8e62568e4a4770e0d7d96792d1c580ad"
- foo :: S.Projective Integer
- foo = case S.parse_point p of
+ bparse :: BS.ByteString -> S.Projective
+ bparse bs = case S.parse bs of
+ Nothing -> error "bang"
+ Just v -> v
+
+ foo :: S.Projective
+ foo = case S.parse p of
Nothing -> error "boom"
Just !pa -> pa
- bar :: S.Projective Integer
- bar = case S.parse_point q of
+ bar :: S.Projective
+ bar = case S.parse q of
Nothing -> error "bang"
Just !pa -> pa
- baz :: S.Projective Integer
- baz = case S.parse_point r of
+ baz :: S.Projective
+ baz = case S.parse r of
Nothing -> error "bang"
Just !pa -> pa
- qux :: S.Projective Integer
- qux = case S.parse_point s of
+ qux :: S.Projective
+ qux = case S.parse s of
Nothing -> error "bang"
Just !pa -> pa
diff --git a/lib/Crypto/Secp256k1.hs b/lib/Crypto/Secp256k1.hs
@@ -1,113 +1,64 @@
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE UnboxedSums #-}
{-# LANGUAGE ViewPatterns #-}
module Crypto.Secp256k1 where
import Control.Monad (when)
import Control.Monad.ST
-import Data.Bits (Bits, (.&.), (.|.))
-import qualified Data.Bits as B (shiftL, shiftR)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as B16
import Data.STRef
import GHC.Generics
+import GHC.Natural
+import qualified GHC.Num.Integer as I
+
+data Curve = Curve {
+ curve_p :: Natural -- ^ field prime
+ , curve_n :: Natural -- ^ group order
+ , curve_a :: Integer -- ^ /a/ coefficient, weierstrass form
+ , curve_b :: Integer -- ^ /b/ coefficient, weierstrass form
+ , curve_gx :: Integer -- ^ base point x
+ , curve_gy :: Integer -- ^ base point y
+ }
+ deriving stock (Show, Generic)
--- XX could stand some reorganization; lots of stuff 'baked in', e.g. in mods
---
--- XX seems Point should have a reference to the curve it's on; probably a lazy
--- one. otherwise hard to implement Eq.
---
--- then again, we're exclusively concerned with a single curve here, so
--- who cares. bake everything in.
---
--- only counterargument is that we may want to reuse the same skeleton for
--- other libraries after the fact. so we have a single library implementing
--- modular arithmetic, curves, etc., and then implement other stuff on top
--- of that
---
--- i think i like the idea of abstracting quickly and baking types and utils
--- and such into an internal library, e.g. secp256k1-sys; later extract that
-
--- modular arithmetic utilities
-
--- XX be aware of non-constant-timeness in these; i should understand this
--- issue *precisely*
-
--- modular division
-moddiv :: Integral a => a -> a -> a
-moddiv a b
- | r < 0 = b + r
- | otherwise = r
- where
- r = a `mod` b
+secp256k1 :: Curve
+secp256k1 = Curve p n 0 7 gx gy where
+ b256 = 2 ^ (256 :: Integer)
+ p = b256 - 0x1000003d1
+ n = b256 - 0x14551231950b75fc4402da1732fc9bebf
+ gx = 0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798
+ gy = 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8
-- modular division by secp256k1 group order
-mods :: Integral a => a -> a
-mods a = moddiv a (curve_p secp256k1)
+mods :: Integer -> Integer
+mods a = I.integerMod a (fromIntegral (curve_p secp256k1))
-- is field element (i.e., is invertible)
-fe :: Integral a => a -> Bool
-fe n = 0 < n && n < (curve_p secp256k1)
+fe :: Integer -> Bool
+fe n = 0 < n && n < fromIntegral (curve_p secp256k1)
-- is group element
-ge :: Integral a => a -> Bool
+ge :: Natural -> Bool
ge n = 0 < n && n < (curve_n secp256k1)
-data Triple a = Triple !a !a !a
- deriving stock Generic
-
--- for a, b, return x, y, g such that ax + by = g for g = gcd(a, b)
-egcd :: Integral a => a -> a -> Triple a
-egcd a 0 = Triple 1 0 a
-egcd a b =
- let (q, r) = a `quotRem` b
- Triple s t g = egcd b r
- in Triple t (s - q * t) g
-
-- for a, m return x such that ax = 1 mod m
-modinv :: Integral a => a -> a -> Maybe a
-modinv a m
- | g == 1 = Just (pos i)
- | otherwise = Nothing
- where
- Triple i _ g = egcd a m
- pos x
- | x < 0 = x + m
- | otherwise = x
-
--- partial modinv
-modinv' :: Integral a => a -> a -> a
-modinv' a m = case modinv a m of
- Just x -> x
- Nothing -> error "modinv': no modular inverse"
+modinv :: Integer -> Natural -> Maybe Natural
+modinv a m = case I.integerRecipMod# a m of
+ (# n | #) -> Just n
+ (# | _ #) -> Nothing
-- elliptic curve
-data Curve a = Curve {
- curve_p :: a -- ^ field prime
- , curve_n :: a -- ^ group order
- , curve_a :: a -- ^ /a/ coefficient, weierstrass form
- , curve_b :: a -- ^ /b/ coefficient, weierstrass form
- , curve_gx :: a -- ^ base point x
- , curve_gy :: a -- ^ base point y
- }
- deriving stock (Show, Generic)
-
-secp256k1 :: Integral a => Curve a
-secp256k1 = Curve p n 0 7 gx gy where
- b256 = 2 ^ (256 :: Integer)
- p = b256 - 0x1000003d1
- n = b256 - 0x14551231950b75fc4402da1732fc9bebf
- gx = 0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798
- gy = 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8
-
-- XX not general weierstrass; only for j-invariant 0 (i.e. a == 0)
-weierstrass :: Integral a => a -> a
+weierstrass :: Integer -> Integer
weierstrass x = mods (mods (x * x) * x + curve_b secp256k1)
-- modular square root
-modsqrt :: (Integral a, Bits a) => a -> Maybe a
+modsqrt :: Integer -> Maybe Integer
modsqrt n = runST $ do
r <- newSTRef 1
num <- newSTRef n
@@ -119,38 +70,38 @@ modsqrt n = runST $ do
then Just rr
else Nothing
where
- p = curve_p secp256k1
+ p = fromIntegral (curve_p secp256k1)
loop sr snum se = do
e <- readSTRef se
when (e > 0) $ do
- when (e .&. 1 == 1) $ do
+ when (I.integerTestBit e 0) $ do
num <- readSTRef snum
modifySTRef' sr (\lr -> (lr * num) `rem` p)
modifySTRef' snum (\ln -> (ln * ln) `rem` p)
- modifySTRef' se (`B.shiftR` 1)
+ modifySTRef' se (`I.integerShiftR` 1)
loop sr snum se
-- group bytelength
-_GROUP_BYTELENGTH :: Integral a => a
+_GROUP_BYTELENGTH :: Int
_GROUP_BYTELENGTH = 32
-- curve point
-data Affine a = Affine !a !a
+data Affine = Affine Integer Integer
deriving stock (Show, Generic)
-instance Integral a => Eq (Affine a) where
+instance Eq Affine where
Affine x1 y1 == Affine x2 y2 =
mods x1 == mods x2 && mods y1 == mods y2
-data Projective a = Projective {
- px :: !a
- , py :: !a
- , pz :: !a
+data Projective = Projective {
+ px :: Integer
+ , py :: Integer
+ , pz :: Integer
}
deriving stock (Show, Generic)
-instance Integral a => Eq (Projective a) where
+instance Eq Projective where
Projective ax ay az == Projective bx by bz =
let x1z2 = mods (ax * bz)
x2z1 = mods (bx * az)
@@ -158,22 +109,22 @@ instance Integral a => Eq (Projective a) where
y2z1 = mods (by * az)
in x1z2 == x2z1 && y1z2 == y2z1
-_ZERO :: Integral a => Projective a
+_ZERO :: Projective
_ZERO = Projective 0 1 0
-_BASE :: Integral a => Projective a
+_BASE :: Projective
_BASE = Projective (curve_gx secp256k1) (curve_gy secp256k1) 1
-- negate point
-neg :: (Integral a, Num a) => Projective a -> Projective a
+neg :: Projective -> Projective
neg (Projective x y z) = Projective x (mods (negate y)) z
--- XX correct?
-add_affine :: (Integral a, Num a) => Affine a -> Affine a -> Maybe (Affine a)
+-- -- XX correct?
+add_affine :: Affine -> Affine -> Maybe Affine
add_affine p@(Affine x1 y1) q@(Affine x2 y2)
| p == q && (p == azero || q == azero) = pure azero
| p == q = do
- i <- modinv (mods (2 * y1)) (curve_p secp256k1)
+ i <- fmap fromIntegral (modinv (mods (2 * y1)) (curve_p secp256k1))
let s = mods (mods (3 * mods (x1 * x1)) * i)
x = mods (mods (s * s) - mods (2 * x1))
y = mods (mods (s * mods (x1 - x)) - y1)
@@ -182,7 +133,7 @@ add_affine p@(Affine x1 y1) q@(Affine x2 y2)
| x2 == 0 && y2 == 0 = pure p
| x1 == x2 = pure azero
| otherwise = do
- i <- modinv (mods (x1 - x2)) (curve_p secp256k1)
+ i <- fmap fromIntegral (modinv (mods (x1 - x2)) (curve_p secp256k1))
let s = mods (mods (y1 - y2) * i)
x3 = mods (mods (s * s) - x1 - x2)
y3 = mods (mods (s * mods (x1 - x3)) - y1)
@@ -190,9 +141,10 @@ add_affine p@(Affine x1 y1) q@(Affine x2 y2)
where
azero = Affine 0 0
+
-- algo 1, "complete addition formulas for prime order elliptic curves,"
-- renes et al, 2015
-add :: (Integral a, Num a) => Projective a -> Projective a -> Projective a
+add :: Projective -> Projective -> Projective
add (Projective x1 y1 z1) (Projective x2 y2 z2) = runST $ do
let a = curve_a secp256k1
b = curve_b secp256k1
@@ -278,7 +230,7 @@ add (Projective x1 y1 z1) (Projective x2 y2 z2) = runST $ do
modifySTRef' z3 (\rz3 -> mods (rz3 + r0))
Projective <$> readSTRef x3 <*> readSTRef y3 <*> readSTRef z3
-add_pure :: (Integral a, Num a) => Projective a -> Projective a -> Projective a
+add_pure :: Projective -> Projective -> Projective
add_pure (Projective x1 y1 z1) (Projective x2 y2 z2) =
let a = curve_a secp256k1
b = curve_b secp256k1
@@ -327,7 +279,7 @@ add_pure (Projective x1 y1 z1) (Projective x2 y2 z2) =
-- algo 7, "complete addition formulas for prime order elliptic curves,"
-- renes et al, 2015
-add' :: (Integral a, Num a) => Projective a -> Projective a -> Projective a
+add' :: Projective -> Projective -> Projective
add' (Projective x1 y1 z1) (Projective x2 y2 z2) = runST $ do
let b = curve_b secp256k1
x3 <- newSTRef 0
@@ -398,34 +350,34 @@ add' (Projective x1 y1 z1) (Projective x2 y2 z2) = runST $ do
Projective <$> readSTRef x3 <*> readSTRef y3 <*> readSTRef z3
-- double a point
-double :: (Integral a, Num a) => Projective a -> Projective a
+double :: Projective -> Projective
double p = add p p
-- to affine coordinates
-affine :: Integral a => Projective a -> Maybe (Affine a)
+affine :: Projective -> Maybe Affine
affine p@(Projective x y z)
| p == _ZERO = pure (Affine 0 0)
| z == 1 = pure (Affine x y)
| otherwise = do
- iz <- modinv z (curve_p secp256k1)
+ iz <- fmap fromIntegral (modinv z (curve_p secp256k1))
if mods (z * iz) /= 1
then Nothing
else pure (Affine (mods (x * iz)) (mods (y * iz)))
-- partial affine
-affine' :: Integral a => Projective a -> Affine a
+affine' :: Projective -> Affine
affine' p = case affine p of
Nothing -> error "bang"
Just x -> x
-- to projective coordinates
-projective :: Integral a => Affine a -> Projective a
+projective :: Affine -> Projective
projective (Affine x y)
| x == 0 && y == 0 = _ZERO
| otherwise = Projective x y 1
-- point is valid
-valid :: Integral a => Projective a -> Bool
+valid :: Projective -> Bool
valid p = case affine p of
Nothing -> False
Just (Affine x y)
@@ -434,8 +386,8 @@ valid p = case affine p of
| otherwise -> True
-- parse hex-encoded
-parse_point :: (Bits a, Integral a) => BS.ByteString -> Maybe (Projective a)
-parse_point (B16.decode -> ebs) = case ebs of
+parse :: BS.ByteString -> Maybe Projective
+parse (B16.decode -> ebs) = case ebs of
Left _ -> Nothing
Right bs -> case BS.uncons bs of
Nothing -> Nothing
@@ -447,8 +399,8 @@ parse_point (B16.decode -> ebs) = case ebs of
then Nothing
else do
y <- modsqrt (weierstrass x)
- let yodd = y .&. 1 == 1
- hodd = h .&. 1 == 1
+ let yodd = I.integerTestBit y 0
+ hodd = I.integerTestBit (fromIntegral h) 0
pure $
if hodd /= yodd
then Projective x (mods (negate y)) 1
@@ -462,7 +414,7 @@ parse_point (B16.decode -> ebs) = case ebs of
else Nothing
-- big-endian bytestring decoding
-roll :: (Bits a, Integral a) => BS.ByteString -> a
+roll :: BS.ByteString -> Integer
roll = BS.foldl' unstep 0 where
- unstep a b = a `B.shiftL` 8 .|. fromIntegral b
+ unstep a b = (a `I.integerShiftL` 8) `I.integerOr` fromIntegral b
diff --git a/ppad-secp256k1.cabal b/ppad-secp256k1.cabal
@@ -41,6 +41,7 @@ test-suite secp256k1-tests
build-depends:
base
+ , base16-bytestring
, bytestring
, ppad-secp256k1
, tasty
diff --git a/test/Main.hs b/test/Main.hs
@@ -1,4 +1,155 @@
+{-# LANGUAGE OverloadedStrings #-}
+
module Main where
-main :: IO
-main = pure ()
+import qualified Data.ByteString as BS
+import Crypto.Secp256k1
+import Test.Tasty
+import Test.Tasty.HUnit
+
+main :: IO ()
+main = defaultMain units
+
+units :: TestTree
+units = testGroup "unit tests" [
+ parse_tests
+ , add_tests
+ , add_pure_tests
+ , add'_tests
+ ]
+
+parse_tests :: TestTree
+parse_tests = testGroup "parse tests" [
+ parse_test_p
+ , parse_test_q
+ , parse_test_r
+ ]
+
+render :: Show a => a -> String
+render = filter (`notElem` ("\"" :: String)) . show
+
+-- XX replace these with something non-stupid
+parse_test_p :: TestTree
+parse_test_p = testCase (render p_hex) $ case parse p_hex of
+ Nothing -> assertFailure "bad parse"
+ Just p -> assertEqual mempty p_pro p
+
+parse_test_q :: TestTree
+parse_test_q = testCase (render q_hex) $ case parse q_hex of
+ Nothing -> assertFailure "bad parse"
+ Just q -> assertEqual mempty q_pro q
+
+parse_test_r :: TestTree
+parse_test_r = testCase (render r_hex) $ case parse r_hex of
+ Nothing -> assertFailure "bad parse"
+ Just r -> assertEqual mempty r_pro r
+
+-- XX also make less dumb
+add_tests :: TestTree
+add_tests = testGroup "ec addition, algo 1" [
+ add_test_pq
+ , add_test_pr
+ , add_test_qr
+ ]
+
+add'_tests :: TestTree
+add'_tests = testGroup "ec addition, algo 7" [
+ add'_test_pq
+ , add'_test_pr
+ , add'_test_qr
+ ]
+
+add_pure_tests :: TestTree
+add_pure_tests = testGroup "ec addition, algo 1, pure" [
+ add_pure_test_pq
+ , add_pure_test_pr
+ , add_pure_test_qr
+ ]
+
+add_test_pq :: TestTree
+add_test_pq = testCase "p + q" $
+ assertEqual mempty pq_pro (p_pro `add` q_pro)
+
+add_test_pr :: TestTree
+add_test_pr = testCase "p + r" $
+ assertEqual mempty pr_pro (p_pro `add` r_pro)
+
+add_test_qr :: TestTree
+add_test_qr = testCase "q + r" $
+ assertEqual mempty qr_pro (q_pro `add` r_pro)
+
+add'_test_pq :: TestTree
+add'_test_pq = testCase "p + q" $
+ assertEqual mempty pq_pro (p_pro `add'` q_pro)
+
+add'_test_pr :: TestTree
+add'_test_pr = testCase "p + r" $
+ assertEqual mempty pr_pro (p_pro `add'` r_pro)
+
+add'_test_qr :: TestTree
+add'_test_qr = testCase "q + r" $
+ assertEqual mempty qr_pro (q_pro `add'` r_pro)
+
+add_pure_test_pq :: TestTree
+add_pure_test_pq = testCase "p + q" $
+ assertEqual mempty pq_pro (p_pro `add_pure` q_pro)
+
+add_pure_test_pr :: TestTree
+add_pure_test_pr = testCase "p + r" $
+ assertEqual mempty pr_pro (p_pro `add_pure` r_pro)
+
+add_pure_test_qr :: TestTree
+add_pure_test_qr = testCase "q + r" $
+ assertEqual mempty qr_pro (q_pro `add_pure` r_pro)
+
+p_hex :: BS.ByteString
+p_hex = "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"
+
+p_pro :: Projective
+p_pro = Projective {
+ px = 55066263022277343669578718895168534326250603453777594175500187360389116729240
+ , py = 32670510020758816978083085130507043184471273380659243275938904335757337482424
+ , pz = 1
+ }
+
+q_hex :: BS.ByteString
+q_hex = "02f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9"
+
+q_pro :: Projective
+q_pro = Projective {
+ px = 112711660439710606056748659173929673102114977341539408544630613555209775888121
+ , py = 25583027980570883691656905877401976406448868254816295069919888960541586679410
+ , pz = 1
+ }
+
+r_hex :: BS.ByteString
+r_hex = "03a2113cf152585d96791a42cdd78782757fbfb5c6b2c11b59857eb4f7fda0b0e8"
+
+r_pro :: Projective
+r_pro = Projective {
+ px = 73305138481390301074068425511419969342201196102229546346478796034582161436904
+ , py = 77311080844824646227678701997218206005272179480834599837053144390237051080427
+ , pz = 1
+ }
+
+pq_pro :: Projective
+pq_pro = Projective {
+ px = 52396973184413144605737087313078368553350360735730295164507742012595395307648
+ , py = 81222895265056120475581324527268307707868393868711445371362592923687074369515
+ , pz = 57410578768022213246260942140297839801661445014943088692963835122150180187279
+ }
+
+pr_pro :: Projective
+pr_pro = Projective {
+ px = 1348700846815225554023000535566992225745844759459188830982575724903956130228
+ , py = 36170035245379023681754688218456726199360176620640420471087552839246039945572
+ , pz = 92262311556350124501370727779827867637071338628440636251794554773617634796873
+ }
+
+qr_pro :: Projective
+qr_pro = Projective {
+ px = 98601662106226486891738184090788320295235665172235527697419658886981126285906
+ , py = 18578813777775793862159229516827464252856752093683109113431170463916250542461
+ , pz = 56555634785712334774735413904899958905472439323190450522613637299635410127585
+ }
+