commit ed2c4feab308b96659fefd982c68e95c121af91b
parent af6f564ef97aa00021355212aad67f96bb91f0ec
Author: Jared Tobin <jared@jtobin.io>
Date: Mon, 25 Mar 2024 10:33:41 +0400
lib: add scalar multiplication, s/mods/mod
Diffstat:
1 file changed, 133 insertions(+), 121 deletions(-)
diff --git a/lib/Crypto/Secp256k1.hs b/lib/Crypto/Secp256k1.hs
@@ -1,3 +1,4 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE MagicHash #-}
@@ -14,6 +15,7 @@ import Data.STRef
import GHC.Generics
import GHC.Natural
import qualified GHC.Num.Integer as I
+import Prelude hiding (mod)
_B256 :: Integer
_B256 = 2 ^ (256 :: Integer)
@@ -43,8 +45,8 @@ _CURVE_GY :: Integer
_CURVE_GY = 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8
-- modular division by secp256k1 group order
-mods :: Integer -> Integer
-mods a = I.integerMod a _CURVE_P
+mod :: Integer -> Integer
+mod a = I.integerMod a _CURVE_P
-- is field element (i.e., is invertible)
fe :: Integer -> Bool
@@ -57,7 +59,7 @@ ge n = 0 < n && n < _CURVE_N
-- for a, m return x such that ax = 1 mod m
modinv :: Integer -> Natural -> Maybe Integer
modinv a m = case I.integerRecipMod# a m of
- (# n | #) -> Just (fromIntegral n)
+ (# fromIntegral -> n | #) -> Just n
(# | _ #) -> Nothing
-- modular square root (shanks-tonelli)
@@ -69,7 +71,7 @@ modsqrt n = runST $ do
loop r num e
rr <- readSTRef r
pure $
- if mods (rr * rr) == n
+ if mod (rr * rr) == n
then Just rr
else Nothing
where
@@ -85,28 +87,28 @@ modsqrt n = runST $ do
-- prime order j-invariant 0 (i.e. a == 0)
weierstrass :: Integer -> Integer
-weierstrass x = mods (mods (x * x) * x + _CURVE_B)
+weierstrass x = mod (mod (x * x) * x + _CURVE_B)
data Affine = Affine Integer Integer
deriving stock (Show, Generic)
instance Eq Affine where
Affine x1 y1 == Affine x2 y2 =
- mods x1 == mods x2 && mods y1 == mods y2
+ mod x1 == mod x2 && mod y1 == mod y2
data Projective = Projective {
- px :: Integer
- , py :: Integer
- , pz :: Integer
+ px :: !Integer
+ , py :: !Integer
+ , pz :: !Integer
}
deriving stock (Show, Generic)
instance Eq Projective where
Projective ax ay az == Projective bx by bz =
- let x1z2 = mods (ax * bz)
- x2z1 = mods (bx * az)
- y1z2 = mods (ay * bz)
- y2z1 = mods (by * az)
+ let x1z2 = mod (ax * bz)
+ x2z1 = mod (bx * az)
+ y1z2 = mod (ay * bz)
+ y2z1 = mod (by * az)
in x1z2 == x2z1 && y1z2 == y2z1
_ZERO :: Projective
@@ -117,7 +119,7 @@ _BASE = Projective _CURVE_GX _CURVE_GY 1
-- negate point
neg :: Projective -> Projective
-neg (Projective x y z) = Projective x (mods (negate y)) z
+neg (Projective x y z) = Projective x (mod (negate y)) z
-- general ec addition
add :: Projective -> Projective -> Projective
@@ -133,125 +135,125 @@ add_proj (Projective x1 y1 z1) (Projective x2 y2 z2) = runST $ do
x3 <- newSTRef 0
y3 <- newSTRef 0
z3 <- newSTRef 0
- let b3 = mods (_CURVE_B * 3)
- t0 <- newSTRef (mods (x1 * x2)) -- 1
- t1 <- newSTRef (mods (y1 * y2))
- t2 <- newSTRef (mods (z1 * z2))
- t3 <- newSTRef (mods (x1 + y1)) -- 4
- t4 <- newSTRef (mods (x2 + y2))
+ let b3 = mod (_CURVE_B * 3)
+ t0 <- newSTRef (mod (x1 * x2)) -- 1
+ t1 <- newSTRef (mod (y1 * y2))
+ t2 <- newSTRef (mod (z1 * z2))
+ t3 <- newSTRef (mod (x1 + y1)) -- 4
+ t4 <- newSTRef (mod (x2 + y2))
readSTRef t4 >>= \r4 ->
- modifySTRef' t3 (\r3 -> mods (r3 * r4))
+ modifySTRef' t3 (\r3 -> mod (r3 * r4))
readSTRef t0 >>= \r0 ->
readSTRef t1 >>= \r1 ->
- writeSTRef t4 (mods (r0 + r1))
+ writeSTRef t4 (mod (r0 + r1))
readSTRef t4 >>= \r4 ->
- modifySTRef' t3 (\r3 -> mods (r3 - r4)) -- 8
- writeSTRef t4 (mods (y1 + z1))
- writeSTRef x3 (mods (y2 + z2))
+ modifySTRef' t3 (\r3 -> mod (r3 - r4)) -- 8
+ writeSTRef t4 (mod (y1 + z1))
+ writeSTRef x3 (mod (y2 + z2))
readSTRef x3 >>= \rx3 ->
- modifySTRef' t4 (\r4 -> mods (r4 * rx3))
+ modifySTRef' t4 (\r4 -> mod (r4 * rx3))
readSTRef t1 >>= \r1 ->
readSTRef t2 >>= \r2 ->
- writeSTRef x3 (mods (r1 + r2)) -- 12
+ writeSTRef x3 (mod (r1 + r2)) -- 12
readSTRef x3 >>= \rx3 ->
- modifySTRef' t4 (\r4 -> mods (r4 - rx3))
- writeSTRef x3 (mods (x1 + z1))
- writeSTRef y3 (mods (x2 + z2))
+ modifySTRef' t4 (\r4 -> mod (r4 - rx3))
+ writeSTRef x3 (mod (x1 + z1))
+ writeSTRef y3 (mod (x2 + z2))
readSTRef y3 >>= \ry3 ->
- modifySTRef' x3 (\rx3 -> mods (rx3 * ry3)) -- 16
+ modifySTRef' x3 (\rx3 -> mod (rx3 * ry3)) -- 16
readSTRef t0 >>= \r0 ->
readSTRef t2 >>= \r2 ->
- writeSTRef y3 (mods (r0 + r2))
+ writeSTRef y3 (mod (r0 + r2))
readSTRef x3 >>= \rx3 ->
- modifySTRef' y3 (\ry3 -> mods (rx3 - ry3))
+ modifySTRef' y3 (\ry3 -> mod (rx3 - ry3))
readSTRef t0 >>= \r0 ->
- writeSTRef x3 (mods (r0 + r0))
+ writeSTRef x3 (mod (r0 + r0))
readSTRef x3 >>= \rx3 ->
- modifySTRef t0 (\r0 -> mods (rx3 + r0)) -- 20
- modifySTRef' t2 (\r2 -> mods (b3 * r2))
+ modifySTRef t0 (\r0 -> mod (rx3 + r0)) -- 20
+ modifySTRef' t2 (\r2 -> mod (b3 * r2))
readSTRef t1 >>= \r1 ->
readSTRef t2 >>= \r2 ->
- writeSTRef z3 (mods (r1 + r2))
+ writeSTRef z3 (mod (r1 + r2))
readSTRef t2 >>= \r2 ->
- modifySTRef' t1 (\r1 -> mods (r1 - r2))
- modifySTRef' y3 (\ry3 -> mods (b3 * ry3)) -- 24
+ modifySTRef' t1 (\r1 -> mod (r1 - r2))
+ modifySTRef' y3 (\ry3 -> mod (b3 * ry3)) -- 24
readSTRef t4 >>= \r4 ->
readSTRef y3 >>= \ry3 ->
- writeSTRef x3 (mods (r4 * ry3))
+ writeSTRef x3 (mod (r4 * ry3))
readSTRef t3 >>= \r3 ->
readSTRef t1 >>= \r1 ->
- writeSTRef t2 (mods (r3 * r1))
+ writeSTRef t2 (mod (r3 * r1))
readSTRef t2 >>= \r2 ->
- modifySTRef' x3 (\rx3 -> mods (r2 - rx3))
+ modifySTRef' x3 (\rx3 -> mod (r2 - rx3))
readSTRef t0 >>= \r0 ->
- modifySTRef' y3 (\ry3 -> mods (ry3 * r0)) -- 28
+ modifySTRef' y3 (\ry3 -> mod (ry3 * r0)) -- 28
readSTRef z3 >>= \rz3 ->
- modifySTRef' t1 (\r1 -> mods (r1 * rz3))
+ modifySTRef' t1 (\r1 -> mod (r1 * rz3))
readSTRef t1 >>= \r1 ->
- modifySTRef' y3 (\ry3 -> mods (r1 + ry3))
+ modifySTRef' y3 (\ry3 -> mod (r1 + ry3))
readSTRef t3 >>= \r3 ->
- modifySTRef' t0 (\r0 -> mods (r0 * r3))
+ modifySTRef' t0 (\r0 -> mod (r0 * r3))
readSTRef t4 >>= \r4 ->
- modifySTRef' z3 (\rz3 -> mods (rz3 * r4)) -- 32
+ modifySTRef' z3 (\rz3 -> mod (rz3 * r4)) -- 32
readSTRef t0 >>= \r0 ->
- modifySTRef' z3 (\rz3 -> mods (rz3 + r0))
+ modifySTRef' z3 (\rz3 -> mod (rz3 + r0))
Projective <$> readSTRef x3 <*> readSTRef y3 <*> readSTRef z3
-- algo 8, renes et al, 2015
add_mixed :: Projective -> Projective -> Projective
add_mixed (Projective x1 y1 z1) (Projective x2 y2 z2)
- | z2 /= 1 = error "secp256k1: internal error"
+ | z2 /= 1 = error "ppad-secp256k1: internal error"
| otherwise = runST $ do
x3 <- newSTRef 0
y3 <- newSTRef 0
z3 <- newSTRef 0
- let b3 = mods (_CURVE_B * 3)
- t0 <- newSTRef (mods (x1 * x2)) -- 1
- t1 <- newSTRef (mods (y1 * y2))
- t3 <- newSTRef (mods (x2 + y2))
- t4 <- newSTRef (mods (x1 + y1)) -- 4
+ let b3 = mod (_CURVE_B * 3)
+ t0 <- newSTRef (mod (x1 * x2)) -- 1
+ t1 <- newSTRef (mod (y1 * y2))
+ t3 <- newSTRef (mod (x2 + y2))
+ t4 <- newSTRef (mod (x1 + y1)) -- 4
readSTRef t4 >>= \r4 ->
- modifySTRef' t3 (\r3 -> mods (r3 * r4))
+ modifySTRef' t3 (\r3 -> mod (r3 * r4))
readSTRef t0 >>= \r0 ->
readSTRef t1 >>= \r1 ->
- writeSTRef t4 (mods (r0 + r1))
+ writeSTRef t4 (mod (r0 + r1))
readSTRef t4 >>= \r4 ->
- modifySTRef' t3 (\r3 -> mods (r3 - r4)) -- 7
- writeSTRef t4 (mods (y2 * z1))
- modifySTRef' t4 (\r4 -> mods (r4 + y1))
- writeSTRef y3 (mods (x2 * z1)) -- 10
- modifySTRef' y3 (\ry3 -> mods (ry3 + x1))
+ modifySTRef' t3 (\r3 -> mod (r3 - r4)) -- 7
+ writeSTRef t4 (mod (y2 * z1))
+ modifySTRef' t4 (\r4 -> mod (r4 + y1))
+ writeSTRef y3 (mod (x2 * z1)) -- 10
+ modifySTRef' y3 (\ry3 -> mod (ry3 + x1))
readSTRef t0 >>= \r0 ->
- writeSTRef x3 (mods (r0 + r0))
+ writeSTRef x3 (mod (r0 + r0))
readSTRef x3 >>= \rx3 ->
- modifySTRef' t0 (\r0 -> mods (rx3 + r0)) -- 13
- t2 <- newSTRef (mods (b3 * z1))
+ modifySTRef' t0 (\r0 -> mod (rx3 + r0)) -- 13
+ t2 <- newSTRef (mod (b3 * z1))
readSTRef t1 >>= \r1 ->
readSTRef t2 >>= \r2 ->
- writeSTRef z3 (mods (r1 + r2))
+ writeSTRef z3 (mod (r1 + r2))
readSTRef t2 >>= \r2 ->
- modifySTRef' t1 (\r1 -> mods (r1 - r2)) -- 16
- modifySTRef' y3 (\ry3 -> mods (b3 * ry3))
+ modifySTRef' t1 (\r1 -> mod (r1 - r2)) -- 16
+ modifySTRef' y3 (\ry3 -> mod (b3 * ry3))
readSTRef t4 >>= \r4 ->
readSTRef y3 >>= \ry3 ->
- writeSTRef x3 (mods (r4 * ry3))
+ writeSTRef x3 (mod (r4 * ry3))
readSTRef t3 >>= \r3 ->
readSTRef t1 >>= \r1 ->
- writeSTRef t2 (mods (r3 * r1)) -- 19
+ writeSTRef t2 (mod (r3 * r1)) -- 19
readSTRef t2 >>= \r2 ->
- modifySTRef' x3 (\rx3 -> mods (r2 - rx3))
+ modifySTRef' x3 (\rx3 -> mod (r2 - rx3))
readSTRef t0 >>= \r0 ->
- modifySTRef' y3 (\ry3 -> mods (ry3 * r0))
+ modifySTRef' y3 (\ry3 -> mod (ry3 * r0))
readSTRef z3 >>= \rz3 ->
- modifySTRef' t1 (\r1 -> mods (r1 * rz3)) -- 22
+ modifySTRef' t1 (\r1 -> mod (r1 * rz3)) -- 22
readSTRef t1 >>= \r1 ->
- modifySTRef' y3 (\ry3 -> mods (r1 + ry3))
+ modifySTRef' y3 (\ry3 -> mod (r1 + ry3))
readSTRef t3 >>= \r3 ->
- modifySTRef' t0 (\r0 -> mods (r0 * r3))
+ modifySTRef' t0 (\r0 -> mod (r0 * r3))
readSTRef t4 >>= \r4 ->
- modifySTRef' z3 (\rz3 -> mods (rz3 * r4)) -- 25
+ modifySTRef' z3 (\rz3 -> mod (rz3 * r4)) -- 25
readSTRef t0 >>= \r0 ->
- modifySTRef' z3 (\rz3 -> mods (rz3 + r0))
+ modifySTRef' z3 (\rz3 -> mod (rz3 + r0))
Projective <$> readSTRef x3 <*> readSTRef y3 <*> readSTRef z3
-- algo 9, renes et al, 2015
@@ -260,59 +262,69 @@ double (Projective x y z) = runST $ do
x3 <- newSTRef 0
y3 <- newSTRef 0
z3 <- newSTRef 0
- let b3 = mods (_CURVE_B * 3)
- t0 <- newSTRef (mods (y * y)) -- 1
+ let b3 = mod (_CURVE_B * 3)
+ t0 <- newSTRef (mod (y * y)) -- 1
readSTRef t0 >>= \r0 ->
- writeSTRef z3 (mods (r0 + r0))
- modifySTRef' z3 (\rz3 -> mods (rz3 + rz3))
- modifySTRef' z3 (\rz3 -> mods (rz3 + rz3)) -- 4
- t1 <- newSTRef (mods (y * z))
- t2 <- newSTRef (mods (z * z))
- modifySTRef t2 (\r2 -> mods (b3 * r2)) -- 7
+ writeSTRef z3 (mod (r0 + r0))
+ modifySTRef' z3 (\rz3 -> mod (rz3 + rz3))
+ modifySTRef' z3 (\rz3 -> mod (rz3 + rz3)) -- 4
+ t1 <- newSTRef (mod (y * z))
+ t2 <- newSTRef (mod (z * z))
+ modifySTRef t2 (\r2 -> mod (b3 * r2)) -- 7
readSTRef z3 >>= \rz3 ->
readSTRef t2 >>= \r2 ->
- writeSTRef x3 (mods (r2 * rz3))
+ writeSTRef x3 (mod (r2 * rz3))
readSTRef t0 >>= \r0 ->
readSTRef t2 >>= \r2 ->
- writeSTRef y3 (mods (r0 + r2))
+ writeSTRef y3 (mod (r0 + r2))
readSTRef t1 >>= \r1 ->
- modifySTRef' z3 (\rz3 -> mods (r1 * rz3)) -- 10
+ modifySTRef' z3 (\rz3 -> mod (r1 * rz3)) -- 10
readSTRef t2 >>= \r2 ->
- writeSTRef t1 (mods (r2 + r2))
+ writeSTRef t1 (mod (r2 + r2))
readSTRef t1 >>= \r1 ->
- modifySTRef' t2 (\r2 -> mods (r1 + r2))
+ modifySTRef' t2 (\r2 -> mod (r1 + r2))
readSTRef t2 >>= \r2 ->
- modifySTRef' t0 (\r0 -> mods (r0 - r2)) -- 13
+ modifySTRef' t0 (\r0 -> mod (r0 - r2)) -- 13
readSTRef t0 >>= \r0 ->
- modifySTRef' y3 (\ry3 -> mods (r0 * ry3))
+ modifySTRef' y3 (\ry3 -> mod (r0 * ry3))
readSTRef x3 >>= \rx3 ->
- modifySTRef' y3 (\ry3 -> mods (rx3 + ry3))
- writeSTRef t1 (mods (x * y)) -- 16
+ modifySTRef' y3 (\ry3 -> mod (rx3 + ry3))
+ writeSTRef t1 (mod (x * y)) -- 16
readSTRef t0 >>= \r0 ->
readSTRef t1 >>= \r1 ->
- writeSTRef x3 (mods (r0 * r1))
- modifySTRef' x3 (\rx3 -> mods (rx3 + rx3))
+ writeSTRef x3 (mod (r0 * r1))
+ modifySTRef' x3 (\rx3 -> mod (rx3 + rx3))
Projective <$> readSTRef x3 <*> readSTRef y3 <*> readSTRef z3
-
-
--- mul(n, safe = true) {
--- if (!safe && n === 0n)
--- return I; // in unsafe mode, allow zero
--- if (!ge(n))
--- err('invalid scalar'); // must be 0 < n < CURVE.n
--- if (this.equals(G))
--- return wNAF(n).p; // use precomputes for base point
--- let p = I, f = G; // init result point & fake point
--- for (let d = this; n > 0n; d = d.double(), n >>= 1n) { // double-and-add ladder
--- if (n & 1n)
--- p = p.add(d); // if bit is present, add to point
--- else if (safe)
--- f = f.add(d); // if not, add to fake for timing safety
--- }
--- return p;
--- }
-
+mul :: Projective -> Integer -> Projective
+mul p n
+ | n == 0 = _ZERO
+ | not (ge n) = error "ppad-secp256k1 (mul): scalar not in group"
+ | otherwise = loop _ZERO p n
+ where
+ loop !r !d m
+ | m <= 0 = r
+ | otherwise =
+ let nd = double d
+ nm = I.integerShiftR m 1
+ nr = if I.integerTestBit m 0 then add r d else r
+ in loop nr nd nm
+
+-- XX confirm nf evaluation
+-- timing safety
+mul_safe :: Projective -> Integer -> Projective
+mul_safe p n
+ | not (ge n) = error "ppad-secp256k1 (mul_safe): scalar not in group"
+ | otherwise = loop _ZERO _BASE p n
+ where
+ loop !r !f !d m
+ | m <= 0 = r
+ | otherwise =
+ let nd = double d
+ nm = I.integerShiftR m 1
+ in if I.integerTestBit m 0
+ then loop (add r d) f nd nm
+ else loop r (add f d) nd nm
-- to affine coordinates
affine :: Projective -> Maybe Affine
@@ -321,9 +333,9 @@ affine p@(Projective x y z)
| z == 1 = pure (Affine x y)
| otherwise = do
iz <- modinv z (fromIntegral _CURVE_P)
- if mods (z * iz) /= 1
+ if mod (z * iz) /= 1
then Nothing
- else pure (Affine (mods (x * iz)) (mods (y * iz)))
+ else pure (Affine (mod (x * iz)) (mod (y * iz)))
-- to projective coordinates
projective :: Affine -> Projective
@@ -337,7 +349,7 @@ valid p = case affine p of
Nothing -> False
Just (Affine x y)
| not (fe x) || not (fe y) -> False
- | mods (y * y) /= weierstrass x -> False
+ | mod (y * y) /= weierstrass x -> False
| otherwise -> True
-- parse hex-encoded point
@@ -346,7 +358,7 @@ parse (B16.decode -> ebs) = case ebs of
Left _ -> Nothing
Right bs -> case BS.uncons bs of
Nothing -> Nothing
- Just (h, t) ->
+ Just (fromIntegral -> h, t) ->
let (roll -> x, etc) = BS.splitAt _GROUP_BYTELENGTH t
len = BS.length bs
in if len == 33 && (h == 0x02 || h == 0x03) -- compressed
@@ -355,10 +367,10 @@ parse (B16.decode -> ebs) = case ebs of
else do
y <- modsqrt (weierstrass x)
let yodd = I.integerTestBit y 0
- hodd = I.integerTestBit (fromIntegral h) 0
+ hodd = I.integerTestBit h 0
pure $
if hodd /= yodd
- then Projective x (mods (negate y)) 1
+ then Projective x (mod (negate y)) 1
else Projective x y 1
else if len == 65 && h == 0x04 -- uncompressed
then let (roll -> y, _) = BS.splitAt _GROUP_BYTELENGTH etc
@@ -374,5 +386,5 @@ parse (B16.decode -> ebs) = case ebs of
-- big-endian bytestring decoding
roll :: BS.ByteString -> Integer
roll = BS.foldl' unstep 0 where
- unstep a b = (a `I.integerShiftL` 8) `I.integerOr` fromIntegral b
+ unstep a (fromIntegral -> b) = (a `I.integerShiftL` 8) `I.integerOr` b