secp256k1

Pure Haskell cryptographic primitives on the secp256k1 elliptic curve.
git clone git://git.ppad.tech/secp256k1.git
Log | Files | Refs | LICENSE

commit ed2c4feab308b96659fefd982c68e95c121af91b
parent af6f564ef97aa00021355212aad67f96bb91f0ec
Author: Jared Tobin <jared@jtobin.io>
Date:   Mon, 25 Mar 2024 10:33:41 +0400

lib: add scalar multiplication, s/mods/mod

Diffstat:
Mlib/Crypto/Secp256k1.hs | 254+++++++++++++++++++++++++++++++++++++++++--------------------------------------
1 file changed, 133 insertions(+), 121 deletions(-)

diff --git a/lib/Crypto/Secp256k1.hs b/lib/Crypto/Secp256k1.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE MagicHash #-} @@ -14,6 +15,7 @@ import Data.STRef import GHC.Generics import GHC.Natural import qualified GHC.Num.Integer as I +import Prelude hiding (mod) _B256 :: Integer _B256 = 2 ^ (256 :: Integer) @@ -43,8 +45,8 @@ _CURVE_GY :: Integer _CURVE_GY = 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8 -- modular division by secp256k1 group order -mods :: Integer -> Integer -mods a = I.integerMod a _CURVE_P +mod :: Integer -> Integer +mod a = I.integerMod a _CURVE_P -- is field element (i.e., is invertible) fe :: Integer -> Bool @@ -57,7 +59,7 @@ ge n = 0 < n && n < _CURVE_N -- for a, m return x such that ax = 1 mod m modinv :: Integer -> Natural -> Maybe Integer modinv a m = case I.integerRecipMod# a m of - (# n | #) -> Just (fromIntegral n) + (# fromIntegral -> n | #) -> Just n (# | _ #) -> Nothing -- modular square root (shanks-tonelli) @@ -69,7 +71,7 @@ modsqrt n = runST $ do loop r num e rr <- readSTRef r pure $ - if mods (rr * rr) == n + if mod (rr * rr) == n then Just rr else Nothing where @@ -85,28 +87,28 @@ modsqrt n = runST $ do -- prime order j-invariant 0 (i.e. a == 0) weierstrass :: Integer -> Integer -weierstrass x = mods (mods (x * x) * x + _CURVE_B) +weierstrass x = mod (mod (x * x) * x + _CURVE_B) data Affine = Affine Integer Integer deriving stock (Show, Generic) instance Eq Affine where Affine x1 y1 == Affine x2 y2 = - mods x1 == mods x2 && mods y1 == mods y2 + mod x1 == mod x2 && mod y1 == mod y2 data Projective = Projective { - px :: Integer - , py :: Integer - , pz :: Integer + px :: !Integer + , py :: !Integer + , pz :: !Integer } deriving stock (Show, Generic) instance Eq Projective where Projective ax ay az == Projective bx by bz = - let x1z2 = mods (ax * bz) - x2z1 = mods (bx * az) - y1z2 = mods (ay * bz) - y2z1 = mods (by * az) + let x1z2 = mod (ax * bz) + x2z1 = mod (bx * az) + y1z2 = mod (ay * bz) + y2z1 = mod (by * az) in x1z2 == x2z1 && y1z2 == y2z1 _ZERO :: Projective @@ -117,7 +119,7 @@ _BASE = Projective _CURVE_GX _CURVE_GY 1 -- negate point neg :: Projective -> Projective -neg (Projective x y z) = Projective x (mods (negate y)) z +neg (Projective x y z) = Projective x (mod (negate y)) z -- general ec addition add :: Projective -> Projective -> Projective @@ -133,125 +135,125 @@ add_proj (Projective x1 y1 z1) (Projective x2 y2 z2) = runST $ do x3 <- newSTRef 0 y3 <- newSTRef 0 z3 <- newSTRef 0 - let b3 = mods (_CURVE_B * 3) - t0 <- newSTRef (mods (x1 * x2)) -- 1 - t1 <- newSTRef (mods (y1 * y2)) - t2 <- newSTRef (mods (z1 * z2)) - t3 <- newSTRef (mods (x1 + y1)) -- 4 - t4 <- newSTRef (mods (x2 + y2)) + let b3 = mod (_CURVE_B * 3) + t0 <- newSTRef (mod (x1 * x2)) -- 1 + t1 <- newSTRef (mod (y1 * y2)) + t2 <- newSTRef (mod (z1 * z2)) + t3 <- newSTRef (mod (x1 + y1)) -- 4 + t4 <- newSTRef (mod (x2 + y2)) readSTRef t4 >>= \r4 -> - modifySTRef' t3 (\r3 -> mods (r3 * r4)) + modifySTRef' t3 (\r3 -> mod (r3 * r4)) readSTRef t0 >>= \r0 -> readSTRef t1 >>= \r1 -> - writeSTRef t4 (mods (r0 + r1)) + writeSTRef t4 (mod (r0 + r1)) readSTRef t4 >>= \r4 -> - modifySTRef' t3 (\r3 -> mods (r3 - r4)) -- 8 - writeSTRef t4 (mods (y1 + z1)) - writeSTRef x3 (mods (y2 + z2)) + modifySTRef' t3 (\r3 -> mod (r3 - r4)) -- 8 + writeSTRef t4 (mod (y1 + z1)) + writeSTRef x3 (mod (y2 + z2)) readSTRef x3 >>= \rx3 -> - modifySTRef' t4 (\r4 -> mods (r4 * rx3)) + modifySTRef' t4 (\r4 -> mod (r4 * rx3)) readSTRef t1 >>= \r1 -> readSTRef t2 >>= \r2 -> - writeSTRef x3 (mods (r1 + r2)) -- 12 + writeSTRef x3 (mod (r1 + r2)) -- 12 readSTRef x3 >>= \rx3 -> - modifySTRef' t4 (\r4 -> mods (r4 - rx3)) - writeSTRef x3 (mods (x1 + z1)) - writeSTRef y3 (mods (x2 + z2)) + modifySTRef' t4 (\r4 -> mod (r4 - rx3)) + writeSTRef x3 (mod (x1 + z1)) + writeSTRef y3 (mod (x2 + z2)) readSTRef y3 >>= \ry3 -> - modifySTRef' x3 (\rx3 -> mods (rx3 * ry3)) -- 16 + modifySTRef' x3 (\rx3 -> mod (rx3 * ry3)) -- 16 readSTRef t0 >>= \r0 -> readSTRef t2 >>= \r2 -> - writeSTRef y3 (mods (r0 + r2)) + writeSTRef y3 (mod (r0 + r2)) readSTRef x3 >>= \rx3 -> - modifySTRef' y3 (\ry3 -> mods (rx3 - ry3)) + modifySTRef' y3 (\ry3 -> mod (rx3 - ry3)) readSTRef t0 >>= \r0 -> - writeSTRef x3 (mods (r0 + r0)) + writeSTRef x3 (mod (r0 + r0)) readSTRef x3 >>= \rx3 -> - modifySTRef t0 (\r0 -> mods (rx3 + r0)) -- 20 - modifySTRef' t2 (\r2 -> mods (b3 * r2)) + modifySTRef t0 (\r0 -> mod (rx3 + r0)) -- 20 + modifySTRef' t2 (\r2 -> mod (b3 * r2)) readSTRef t1 >>= \r1 -> readSTRef t2 >>= \r2 -> - writeSTRef z3 (mods (r1 + r2)) + writeSTRef z3 (mod (r1 + r2)) readSTRef t2 >>= \r2 -> - modifySTRef' t1 (\r1 -> mods (r1 - r2)) - modifySTRef' y3 (\ry3 -> mods (b3 * ry3)) -- 24 + modifySTRef' t1 (\r1 -> mod (r1 - r2)) + modifySTRef' y3 (\ry3 -> mod (b3 * ry3)) -- 24 readSTRef t4 >>= \r4 -> readSTRef y3 >>= \ry3 -> - writeSTRef x3 (mods (r4 * ry3)) + writeSTRef x3 (mod (r4 * ry3)) readSTRef t3 >>= \r3 -> readSTRef t1 >>= \r1 -> - writeSTRef t2 (mods (r3 * r1)) + writeSTRef t2 (mod (r3 * r1)) readSTRef t2 >>= \r2 -> - modifySTRef' x3 (\rx3 -> mods (r2 - rx3)) + modifySTRef' x3 (\rx3 -> mod (r2 - rx3)) readSTRef t0 >>= \r0 -> - modifySTRef' y3 (\ry3 -> mods (ry3 * r0)) -- 28 + modifySTRef' y3 (\ry3 -> mod (ry3 * r0)) -- 28 readSTRef z3 >>= \rz3 -> - modifySTRef' t1 (\r1 -> mods (r1 * rz3)) + modifySTRef' t1 (\r1 -> mod (r1 * rz3)) readSTRef t1 >>= \r1 -> - modifySTRef' y3 (\ry3 -> mods (r1 + ry3)) + modifySTRef' y3 (\ry3 -> mod (r1 + ry3)) readSTRef t3 >>= \r3 -> - modifySTRef' t0 (\r0 -> mods (r0 * r3)) + modifySTRef' t0 (\r0 -> mod (r0 * r3)) readSTRef t4 >>= \r4 -> - modifySTRef' z3 (\rz3 -> mods (rz3 * r4)) -- 32 + modifySTRef' z3 (\rz3 -> mod (rz3 * r4)) -- 32 readSTRef t0 >>= \r0 -> - modifySTRef' z3 (\rz3 -> mods (rz3 + r0)) + modifySTRef' z3 (\rz3 -> mod (rz3 + r0)) Projective <$> readSTRef x3 <*> readSTRef y3 <*> readSTRef z3 -- algo 8, renes et al, 2015 add_mixed :: Projective -> Projective -> Projective add_mixed (Projective x1 y1 z1) (Projective x2 y2 z2) - | z2 /= 1 = error "secp256k1: internal error" + | z2 /= 1 = error "ppad-secp256k1: internal error" | otherwise = runST $ do x3 <- newSTRef 0 y3 <- newSTRef 0 z3 <- newSTRef 0 - let b3 = mods (_CURVE_B * 3) - t0 <- newSTRef (mods (x1 * x2)) -- 1 - t1 <- newSTRef (mods (y1 * y2)) - t3 <- newSTRef (mods (x2 + y2)) - t4 <- newSTRef (mods (x1 + y1)) -- 4 + let b3 = mod (_CURVE_B * 3) + t0 <- newSTRef (mod (x1 * x2)) -- 1 + t1 <- newSTRef (mod (y1 * y2)) + t3 <- newSTRef (mod (x2 + y2)) + t4 <- newSTRef (mod (x1 + y1)) -- 4 readSTRef t4 >>= \r4 -> - modifySTRef' t3 (\r3 -> mods (r3 * r4)) + modifySTRef' t3 (\r3 -> mod (r3 * r4)) readSTRef t0 >>= \r0 -> readSTRef t1 >>= \r1 -> - writeSTRef t4 (mods (r0 + r1)) + writeSTRef t4 (mod (r0 + r1)) readSTRef t4 >>= \r4 -> - modifySTRef' t3 (\r3 -> mods (r3 - r4)) -- 7 - writeSTRef t4 (mods (y2 * z1)) - modifySTRef' t4 (\r4 -> mods (r4 + y1)) - writeSTRef y3 (mods (x2 * z1)) -- 10 - modifySTRef' y3 (\ry3 -> mods (ry3 + x1)) + modifySTRef' t3 (\r3 -> mod (r3 - r4)) -- 7 + writeSTRef t4 (mod (y2 * z1)) + modifySTRef' t4 (\r4 -> mod (r4 + y1)) + writeSTRef y3 (mod (x2 * z1)) -- 10 + modifySTRef' y3 (\ry3 -> mod (ry3 + x1)) readSTRef t0 >>= \r0 -> - writeSTRef x3 (mods (r0 + r0)) + writeSTRef x3 (mod (r0 + r0)) readSTRef x3 >>= \rx3 -> - modifySTRef' t0 (\r0 -> mods (rx3 + r0)) -- 13 - t2 <- newSTRef (mods (b3 * z1)) + modifySTRef' t0 (\r0 -> mod (rx3 + r0)) -- 13 + t2 <- newSTRef (mod (b3 * z1)) readSTRef t1 >>= \r1 -> readSTRef t2 >>= \r2 -> - writeSTRef z3 (mods (r1 + r2)) + writeSTRef z3 (mod (r1 + r2)) readSTRef t2 >>= \r2 -> - modifySTRef' t1 (\r1 -> mods (r1 - r2)) -- 16 - modifySTRef' y3 (\ry3 -> mods (b3 * ry3)) + modifySTRef' t1 (\r1 -> mod (r1 - r2)) -- 16 + modifySTRef' y3 (\ry3 -> mod (b3 * ry3)) readSTRef t4 >>= \r4 -> readSTRef y3 >>= \ry3 -> - writeSTRef x3 (mods (r4 * ry3)) + writeSTRef x3 (mod (r4 * ry3)) readSTRef t3 >>= \r3 -> readSTRef t1 >>= \r1 -> - writeSTRef t2 (mods (r3 * r1)) -- 19 + writeSTRef t2 (mod (r3 * r1)) -- 19 readSTRef t2 >>= \r2 -> - modifySTRef' x3 (\rx3 -> mods (r2 - rx3)) + modifySTRef' x3 (\rx3 -> mod (r2 - rx3)) readSTRef t0 >>= \r0 -> - modifySTRef' y3 (\ry3 -> mods (ry3 * r0)) + modifySTRef' y3 (\ry3 -> mod (ry3 * r0)) readSTRef z3 >>= \rz3 -> - modifySTRef' t1 (\r1 -> mods (r1 * rz3)) -- 22 + modifySTRef' t1 (\r1 -> mod (r1 * rz3)) -- 22 readSTRef t1 >>= \r1 -> - modifySTRef' y3 (\ry3 -> mods (r1 + ry3)) + modifySTRef' y3 (\ry3 -> mod (r1 + ry3)) readSTRef t3 >>= \r3 -> - modifySTRef' t0 (\r0 -> mods (r0 * r3)) + modifySTRef' t0 (\r0 -> mod (r0 * r3)) readSTRef t4 >>= \r4 -> - modifySTRef' z3 (\rz3 -> mods (rz3 * r4)) -- 25 + modifySTRef' z3 (\rz3 -> mod (rz3 * r4)) -- 25 readSTRef t0 >>= \r0 -> - modifySTRef' z3 (\rz3 -> mods (rz3 + r0)) + modifySTRef' z3 (\rz3 -> mod (rz3 + r0)) Projective <$> readSTRef x3 <*> readSTRef y3 <*> readSTRef z3 -- algo 9, renes et al, 2015 @@ -260,59 +262,69 @@ double (Projective x y z) = runST $ do x3 <- newSTRef 0 y3 <- newSTRef 0 z3 <- newSTRef 0 - let b3 = mods (_CURVE_B * 3) - t0 <- newSTRef (mods (y * y)) -- 1 + let b3 = mod (_CURVE_B * 3) + t0 <- newSTRef (mod (y * y)) -- 1 readSTRef t0 >>= \r0 -> - writeSTRef z3 (mods (r0 + r0)) - modifySTRef' z3 (\rz3 -> mods (rz3 + rz3)) - modifySTRef' z3 (\rz3 -> mods (rz3 + rz3)) -- 4 - t1 <- newSTRef (mods (y * z)) - t2 <- newSTRef (mods (z * z)) - modifySTRef t2 (\r2 -> mods (b3 * r2)) -- 7 + writeSTRef z3 (mod (r0 + r0)) + modifySTRef' z3 (\rz3 -> mod (rz3 + rz3)) + modifySTRef' z3 (\rz3 -> mod (rz3 + rz3)) -- 4 + t1 <- newSTRef (mod (y * z)) + t2 <- newSTRef (mod (z * z)) + modifySTRef t2 (\r2 -> mod (b3 * r2)) -- 7 readSTRef z3 >>= \rz3 -> readSTRef t2 >>= \r2 -> - writeSTRef x3 (mods (r2 * rz3)) + writeSTRef x3 (mod (r2 * rz3)) readSTRef t0 >>= \r0 -> readSTRef t2 >>= \r2 -> - writeSTRef y3 (mods (r0 + r2)) + writeSTRef y3 (mod (r0 + r2)) readSTRef t1 >>= \r1 -> - modifySTRef' z3 (\rz3 -> mods (r1 * rz3)) -- 10 + modifySTRef' z3 (\rz3 -> mod (r1 * rz3)) -- 10 readSTRef t2 >>= \r2 -> - writeSTRef t1 (mods (r2 + r2)) + writeSTRef t1 (mod (r2 + r2)) readSTRef t1 >>= \r1 -> - modifySTRef' t2 (\r2 -> mods (r1 + r2)) + modifySTRef' t2 (\r2 -> mod (r1 + r2)) readSTRef t2 >>= \r2 -> - modifySTRef' t0 (\r0 -> mods (r0 - r2)) -- 13 + modifySTRef' t0 (\r0 -> mod (r0 - r2)) -- 13 readSTRef t0 >>= \r0 -> - modifySTRef' y3 (\ry3 -> mods (r0 * ry3)) + modifySTRef' y3 (\ry3 -> mod (r0 * ry3)) readSTRef x3 >>= \rx3 -> - modifySTRef' y3 (\ry3 -> mods (rx3 + ry3)) - writeSTRef t1 (mods (x * y)) -- 16 + modifySTRef' y3 (\ry3 -> mod (rx3 + ry3)) + writeSTRef t1 (mod (x * y)) -- 16 readSTRef t0 >>= \r0 -> readSTRef t1 >>= \r1 -> - writeSTRef x3 (mods (r0 * r1)) - modifySTRef' x3 (\rx3 -> mods (rx3 + rx3)) + writeSTRef x3 (mod (r0 * r1)) + modifySTRef' x3 (\rx3 -> mod (rx3 + rx3)) Projective <$> readSTRef x3 <*> readSTRef y3 <*> readSTRef z3 - - --- mul(n, safe = true) { --- if (!safe && n === 0n) --- return I; // in unsafe mode, allow zero --- if (!ge(n)) --- err('invalid scalar'); // must be 0 < n < CURVE.n --- if (this.equals(G)) --- return wNAF(n).p; // use precomputes for base point --- let p = I, f = G; // init result point & fake point --- for (let d = this; n > 0n; d = d.double(), n >>= 1n) { // double-and-add ladder --- if (n & 1n) --- p = p.add(d); // if bit is present, add to point --- else if (safe) --- f = f.add(d); // if not, add to fake for timing safety --- } --- return p; --- } - +mul :: Projective -> Integer -> Projective +mul p n + | n == 0 = _ZERO + | not (ge n) = error "ppad-secp256k1 (mul): scalar not in group" + | otherwise = loop _ZERO p n + where + loop !r !d m + | m <= 0 = r + | otherwise = + let nd = double d + nm = I.integerShiftR m 1 + nr = if I.integerTestBit m 0 then add r d else r + in loop nr nd nm + +-- XX confirm nf evaluation +-- timing safety +mul_safe :: Projective -> Integer -> Projective +mul_safe p n + | not (ge n) = error "ppad-secp256k1 (mul_safe): scalar not in group" + | otherwise = loop _ZERO _BASE p n + where + loop !r !f !d m + | m <= 0 = r + | otherwise = + let nd = double d + nm = I.integerShiftR m 1 + in if I.integerTestBit m 0 + then loop (add r d) f nd nm + else loop r (add f d) nd nm -- to affine coordinates affine :: Projective -> Maybe Affine @@ -321,9 +333,9 @@ affine p@(Projective x y z) | z == 1 = pure (Affine x y) | otherwise = do iz <- modinv z (fromIntegral _CURVE_P) - if mods (z * iz) /= 1 + if mod (z * iz) /= 1 then Nothing - else pure (Affine (mods (x * iz)) (mods (y * iz))) + else pure (Affine (mod (x * iz)) (mod (y * iz))) -- to projective coordinates projective :: Affine -> Projective @@ -337,7 +349,7 @@ valid p = case affine p of Nothing -> False Just (Affine x y) | not (fe x) || not (fe y) -> False - | mods (y * y) /= weierstrass x -> False + | mod (y * y) /= weierstrass x -> False | otherwise -> True -- parse hex-encoded point @@ -346,7 +358,7 @@ parse (B16.decode -> ebs) = case ebs of Left _ -> Nothing Right bs -> case BS.uncons bs of Nothing -> Nothing - Just (h, t) -> + Just (fromIntegral -> h, t) -> let (roll -> x, etc) = BS.splitAt _GROUP_BYTELENGTH t len = BS.length bs in if len == 33 && (h == 0x02 || h == 0x03) -- compressed @@ -355,10 +367,10 @@ parse (B16.decode -> ebs) = case ebs of else do y <- modsqrt (weierstrass x) let yodd = I.integerTestBit y 0 - hodd = I.integerTestBit (fromIntegral h) 0 + hodd = I.integerTestBit h 0 pure $ if hodd /= yodd - then Projective x (mods (negate y)) 1 + then Projective x (mod (negate y)) 1 else Projective x y 1 else if len == 65 && h == 0x04 -- uncompressed then let (roll -> y, _) = BS.splitAt _GROUP_BYTELENGTH etc @@ -374,5 +386,5 @@ parse (B16.decode -> ebs) = case ebs of -- big-endian bytestring decoding roll :: BS.ByteString -> Integer roll = BS.foldl' unstep 0 where - unstep a b = (a `I.integerShiftL` 8) `I.integerOr` fromIntegral b + unstep a (fromIntegral -> b) = (a `I.integerShiftL` 8) `I.integerOr` b