commit 028eb2f53a3449c63372914d1e771c60e146a823
parent 6e811629750af47d55134b69e2823fdcf974d1fd
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 21 Dec 2025 13:03:34 -0330
lib: wnaf refactoring
Diffstat:
3 files changed, 174 insertions(+), 99 deletions(-)
diff --git a/bench/Main.hs b/bench/Main.hs
@@ -29,8 +29,9 @@ main = defaultMain [
, add
, double
, mul
- , precompute
+ , mul_vartime
, mul_wnaf
+ , precompute
, derive_pub
, schnorr
, ecdsa
@@ -96,6 +97,16 @@ mul = env setup $ \x ->
setup = pure . parse_int256 $ decodeLenient
"7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed"
+mul_vartime :: Benchmark
+mul_vartime = env setup $ \x ->
+ bgroup "mul_vartime" [
+ bench "2 G" $ nf (S.mul_vartime S._CURVE_G) 2
+ , bench "(2 ^ 255 - 19) G" $ nf (S.mul_vartime S._CURVE_G) x
+ ]
+ where
+ setup = pure . parse_int256 $ decodeLenient
+ "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed"
+
precompute :: Benchmark
precompute = bench "precompute" $ nfIO (pure S.precompute)
diff --git a/bench/Weight.hs b/bench/Weight.hs
@@ -35,7 +35,6 @@ main = W.mainWith $ do
add
double
mul
- mul_unsafe
mul_wnaf
derive_pub
schnorr
@@ -92,15 +91,6 @@ mul =
W.func' "2 G" (S.mul g) t
W.func' "(2 ^ 255 - 19) G" (S.mul g) b
-mul_unsafe :: W.Weigh ()
-mul_unsafe =
- let !g = S._CURVE_G
- !t = 2
- !b = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed
- in W.wgroup "mul_unsafe" $ do
- W.func' "2 G" (S.mul_unsafe g) t
- W.func' "(2 ^ 255 - 19) G" (S.mul_unsafe g) b
-
mul_wnaf :: W.Weigh ()
mul_wnaf =
let !t = 2
diff --git a/lib/Crypto/Curve/Secp256k1.hs b/lib/Crypto/Curve/Secp256k1.hs
@@ -27,7 +27,6 @@ module Crypto.Curve.Secp256k1 (
-- * Field and group parameters
_CURVE_Q
, _CURVE_P
- , modQ
-- * secp256k1 points
, Pub
@@ -78,7 +77,7 @@ module Crypto.Curve.Secp256k1 (
, add_proj
, double
, mul
- , mul_unsafe
+ , mul_vartime
, mul_wnaf
-- Coordinate systems and transformations
@@ -102,14 +101,14 @@ import Control.Monad (guard)
import Control.Monad.ST
import qualified Crypto.DRBG.HMAC as DRBG
import qualified Crypto.Hash.SHA256 as SHA256
-import Data.Bits ((.&.))
import qualified Data.Bits as B
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BU
import qualified Data.Choice as CT
import qualified Data.Maybe as M
-import qualified Data.Primitive.Array as A
+import Data.Primitive.ByteArray (ByteArray(..), MutableByteArray(..))
+import qualified Data.Primitive.ByteArray as BA
import Data.Word (Word8)
import Data.Word.Limb (Limb(..))
import qualified Data.Word.Limb as L
@@ -118,28 +117,19 @@ import qualified Data.Word.Wider as W
import qualified Foreign.Storable as Storable (pokeByteOff)
import qualified GHC.Exts as Exts
import GHC.Generics
-import qualified GHC.Int (Int(..))
import qualified GHC.Word (Word(..), Word8(..))
import qualified Numeric.Montgomery.Secp256k1.Curve as C
import qualified Numeric.Montgomery.Secp256k1.Scalar as S
import Prelude hiding (sqrt)
--- utilities ------------------------------------------------------------------
-
-fi :: (Integral a, Num b) => a -> b
-fi = fromIntegral
-{-# INLINE fi #-}
-
--- dumb strict pair
-data Pair a b = Pair !a !b
+-- convenience synonyms -------------------------------------------------------
--- Unboxed Montgomery synonym.
+-- Unboxed Wider/Montgomery synonym.
type Limb4 = (# Limb, Limb, Limb, Limb #)
-- Unboxed Projective synonym.
type Proj = (# Limb4, Limb4, Limb4 #)
--- convenience patterns
pattern Zero :: Wider
pattern Zero = Wider Z
@@ -151,6 +141,12 @@ pattern P x y z =
Projective (C.Montgomery x) (C.Montgomery y) (C.Montgomery z)
{-# COMPLETE P #-}
+-- utilities ------------------------------------------------------------------
+
+fi :: (Integral a, Num b) => a -> b
+fi = fromIntegral
+{-# INLINE fi #-}
+
-- convert a Word8 to a Limb
limb :: Word8 -> Limb
limb (GHC.Word.W8# (Exts.word8ToWord# -> w)) = Limb w
@@ -173,10 +169,6 @@ word8_to_wider :: Word8 -> Wider
word8_to_wider w = Wider (# limb w, Limb 0##, Limb 0##, Limb 0## #)
{-# INLINABLE word8_to_wider #-}
-wider_to_int :: Wider -> Int
-wider_to_int (Wider (# Limb l, _, _, _ #)) = GHC.Int.I# (Exts.word2Int# l)
-{-# INLINABLE wider_to_int #-}
-
-- unsafely extract the first 64-bit word from a big-endian-encoded bytestring
unsafe_word0 :: BS.ByteString -> Limb
unsafe_word0 bs =
@@ -299,6 +291,7 @@ modQ = S.from . S.to
-- bytewise xor
xor :: BS.ByteString -> BS.ByteString -> BS.ByteString
xor = BS.packZipWith B.xor
+{-# INLINABLE xor #-}
-- constants ------------------------------------------------------------------
@@ -444,9 +437,15 @@ lift_vartime x = do
even_y_vartime :: Projective -> Projective
even_y_vartime p = case affine p of
Affine _ (C.retr -> y)
- | CT.decide (W.odd y) -> neg p -- XX
+ | CT.decide (W.odd y) -> neg p
| otherwise -> p
+-- Constant-time selection of Projective points.
+select_proj :: Projective -> Projective -> CT.Choice -> Projective
+select_proj (P ax ay az) (P bx by bz) c =
+ P (W.select# ax bx c) (W.select# ay by c) (W.select# az bz c)
+{-# INLINE select_proj #-}
+
-- unboxed internals ----------------------------------------------------------
-- algo 7, renes et al, 2015
@@ -552,6 +551,10 @@ select_proj# (# ax, ay, az #) (# bx, by, bz #) c =
(# W.select# ax bx c, W.select# ay by c, W.select# az bz c #)
{-# INLINE select_proj# #-}
+neg# :: Proj -> Proj
+neg# (# x, y, z #) = (# x, C.neg# y, z #)
+{-# INLINE neg# #-}
+
mul# :: Proj -> Limb4 -> (# () | Proj #)
mul# (# px, py, pz #) s
| CT.decide (CT.not# (ge# s)) = (# () | #)
@@ -576,17 +579,77 @@ ge# n =
in CT.and# (W.gt# n Z) (W.lt# n q)
{-# INLINE ge# #-}
--- ec arithmetic --------------------------------------------------------------
+mul_wnaf# :: ByteArray -> Int -> Limb4 -> (# () | Proj #)
+mul_wnaf# ctxArray ctxW ls
+ | CT.decide (CT.not# (ge# ls)) = (# () | #)
+ | otherwise =
+ let !(P zx zy zz) = _CURVE_ZERO
+ !(P gx gy gz) = _CURVE_G
+ in (# | loop 0 (# zx, zy, zz #) (# gx, gy, gz #) ls #)
+ where
+ !one = (# Limb 1##, Limb 0##, Limb 0##, Limb 0## #)
+ !wins = fi (256 `quot` ctxW + 1)
+ !size@(GHC.Word.W# s) = 2 ^ (ctxW - 1)
+ !(GHC.Word.W# mask) = 2 ^ ctxW - 1
+ !(GHC.Word.W# texW) = fi ctxW
+ !(GHC.Word.W# mnum) = 2 ^ ctxW
+
+ loop !j@(GHC.Word.W# w) !acc !f !n@(# Limb lo, _, _, _ #)
+ | j == wins = acc
+ | otherwise =
+ let !(GHC.Word.W# off0) = j * size
+ !b0 = Exts.and# lo mask
+ !bor = CT.from_word_gt# b0 s
+
+ !(# n0, _ #) = W.shr_limb# n (Exts.word2Int# texW)
+ !n0_plus_1 = W.add_w# n0 one
+ !n1 = W.select# n0 n0_plus_1 bor
+
+ !abs_b = CT.select_word# b0 (Exts.minusWord# mnum b0) bor
+ !is_zero = CT.from_word_eq# b0 0##
+ !c0 = CT.from_word# (Exts.and# w 1##)
+ !off_nz = Exts.minusWord# (Exts.plusWord# off0 abs_b) 1##
+ !off = CT.select_word# off0 off_nz (CT.not# is_zero)
+
+ !pr = index_proj# ctxArray (Exts.word2Int# off)
+ !neg_pr = neg# pr
+ !pt_zero = select_proj# pr neg_pr c0
+ !pt_nonzero = select_proj# pr neg_pr bor
+
+ !f_added = add_proj# f pt_zero
+ !acc_added = add_proj# acc pt_nonzero
+ !nacc = select_proj# acc_added acc is_zero
+ !nf = select_proj# f f_added is_zero
+ in loop (succ j) nacc nf n1
+{-# INLINE mul_wnaf# #-}
+
+-- retrieve a point (as an unboxed tuple) from a context array
+index_proj# :: ByteArray -> Exts.Int# -> Proj
+index_proj# (ByteArray arr#) i# =
+ let !base# = i# Exts.*# 12#
+ !x = (# Limb (Exts.indexWordArray# arr# base#)
+ , Limb (Exts.indexWordArray# arr# (base# Exts.+# 01#))
+ , Limb (Exts.indexWordArray# arr# (base# Exts.+# 02#))
+ , Limb (Exts.indexWordArray# arr# (base# Exts.+# 03#)) #)
+ !y = (# Limb (Exts.indexWordArray# arr# (base# Exts.+# 04#))
+ , Limb (Exts.indexWordArray# arr# (base# Exts.+# 05#))
+ , Limb (Exts.indexWordArray# arr# (base# Exts.+# 06#))
+ , Limb (Exts.indexWordArray# arr# (base# Exts.+# 07#)) #)
+ !z = (# Limb (Exts.indexWordArray# arr# (base# Exts.+# 08#))
+ , Limb (Exts.indexWordArray# arr# (base# Exts.+# 09#))
+ , Limb (Exts.indexWordArray# arr# (base# Exts.+# 10#))
+ , Limb (Exts.indexWordArray# arr# (base# Exts.+# 11#)) #)
+ in (# x, y, z #)
+{-# INLINE index_proj# #-}
--- Constant-time selection of Projective points.
-select_proj :: Projective -> Projective -> CT.Choice -> Projective
-select_proj (P ax ay az) (P bx by bz) c =
- P (W.select# ax bx c) (W.select# ay by c) (W.select# az bz c)
-{-# INLINE select_proj #-}
+-- ec arithmetic --------------------------------------------------------------
-- Negate secp256k1 point.
neg :: Projective -> Projective
-neg (Projective x y z) = Projective x (negate y) z
+neg (P x y z) =
+ let !(# px, py, pz #) = neg# (# x, y, z #)
+ in P px py pz
+{-# INLINABLE neg #-}
-- Elliptic curve addition on secp256k1.
add :: Projective -> Projective -> Projective
@@ -624,13 +687,11 @@ mul (P x y z) (Wider s) = case mul# (# x, y, z #) s of
(# | (# px, py, pz #) #) -> Just $! P px py pz
{-# INLINABLE mul #-}
--- XX mul_vartime might be nicer
-
-- Timing-unsafe scalar multiplication of secp256k1 points.
--
-- Don't use this function if the scalar could potentially be a secret.
-mul_unsafe :: Projective -> Wider -> Maybe Projective
-mul_unsafe p = \case
+mul_vartime :: Projective -> Wider -> Maybe Projective
+mul_vartime p = \case
Zero -> pure _CURVE_ZERO
n | not (ge n) -> Nothing
| otherwise -> pure $! loop _CURVE_ZERO p n
@@ -646,7 +707,7 @@ mul_unsafe p = \case
-- | Precomputed multiples of the secp256k1 base or generator point.
data Context = Context {
ctxW :: {-# UNPACK #-} !Int
- , ctxArray :: !(A.Array Projective)
+ , ctxArray :: {-# UNPACK #-} !ByteArray
} deriving Generic
instance Show Context where
@@ -664,66 +725,79 @@ instance Show Context where
precompute :: Context
precompute = _precompute 8
--- translation of noble-secp256k1's 'precompute'
+-- This is a highly-optimized version of a function originally
+-- translated from noble-secp256k1's "precompute". Points are stored in
+-- a ByteArray by arranging each limb into slices of 12 consecutive
+-- slots (each Projective point consists of three Montgomery values,
+-- each of which consists of four limbs, summing to twelve limbs in
+-- total).
+--
+-- Each point takes 96 bytes to store in this fashion, so the total size of
+-- the ByteArray is (size * 96) bytes.
_precompute :: Int -> Context
_precompute ctxW = Context {..} where
- ctxArray = A.arrayFromListN size (loop_w mempty _CURVE_G 0)
capJ = (2 :: Int) ^ (ctxW - 1)
ws = 256 `quot` ctxW + 1
size = ws * capJ
- loop_w !acc !p !w
- | w == ws = reverse acc
- | otherwise =
- let b = p
- !(Pair nacc nb) = loop_j p (b : acc) b 1
- np = double nb
- in loop_w nacc np (succ w)
-
- loop_j !p !acc !b !j
- | j == capJ = Pair acc b
- | otherwise =
- let nb = add b p
- in loop_j p (nb : acc) nb (succ j)
+ -- construct the context array
+ ctxArray = runST $ do
+ marr <- BA.newByteArray (size * 96)
+ loop_w marr _CURVE_G 0
+ BA.unsafeFreezeByteArray marr
+
+ -- write a point into the i^th 12-slot slice in the array
+ write :: MutableByteArray s -> Int -> Projective -> ST s ()
+ write marr i
+ (P (# Limb x0, Limb x1, Limb x2, Limb x3 #)
+ (# Limb y0, Limb y1, Limb y2, Limb y3 #)
+ (# Limb z0, Limb z1, Limb z2, Limb z3 #)) = do
+ let !base = i * 12
+ BA.writeByteArray marr (base + 00) (GHC.Word.W# x0)
+ BA.writeByteArray marr (base + 01) (GHC.Word.W# x1)
+ BA.writeByteArray marr (base + 02) (GHC.Word.W# x2)
+ BA.writeByteArray marr (base + 03) (GHC.Word.W# x3)
+ BA.writeByteArray marr (base + 04) (GHC.Word.W# y0)
+ BA.writeByteArray marr (base + 05) (GHC.Word.W# y1)
+ BA.writeByteArray marr (base + 06) (GHC.Word.W# y2)
+ BA.writeByteArray marr (base + 07) (GHC.Word.W# y3)
+ BA.writeByteArray marr (base + 08) (GHC.Word.W# z0)
+ BA.writeByteArray marr (base + 09) (GHC.Word.W# z1)
+ BA.writeByteArray marr (base + 10) (GHC.Word.W# z2)
+ BA.writeByteArray marr (base + 11) (GHC.Word.W# z3)
+
+ -- loop over windows
+ loop_w :: MutableByteArray s -> Projective -> Int -> ST s ()
+ loop_w !marr !p !w
+ | w == ws = pure ()
+ | otherwise = do
+ nb <- loop_j marr p p (w * capJ) 0
+ let np = double nb
+ loop_w marr np (succ w)
+
+ -- loop within windows
+ loop_j
+ :: MutableByteArray s
+ -> Projective
+ -> Projective
+ -> Int
+ -> Int
+ -> ST s Projective
+ loop_j !marr !p !b !idx !j = do
+ write marr idx b
+ if j == capJ - 1
+ then pure b
+ else do
+ let !nb = add b p
+ loop_j marr p nb (succ idx) (succ j)
-- Timing-safe wNAF (w-ary non-adjacent form) scalar multiplication of
-- secp256k1 points.
mul_wnaf :: Context -> Wider -> Maybe Projective
-mul_wnaf Context {..} _SECRET = do
- guard (ge _SECRET)
- pure $! loop 0 _CURVE_ZERO _CURVE_G _SECRET
- where
- wins = 256 `quot` ctxW + 1
- wsize = 2 ^ (ctxW - 1)
- mask = 2 ^ ctxW - 1
- mnum = 2 ^ ctxW
-
- loop !w !acc !f !n
- | w == wins = acc
- | otherwise =
- let !off0 = w * wsize
-
- !b0 = wider_to_int n .&. mask
- !n0 = n `W.shr_limb` ctxW
-
- !(Pair b1 n1) | b0 > wsize = Pair (b0 - mnum) (n0 + 1)
- | otherwise = Pair b0 n0
-
- !c0 = B.testBit w 0
- !c1 = b1 < 0
-
- !off1 = off0 + fi (abs b1) - 1
-
- in if b1 == 0
- then let !pr = A.indexArray ctxArray off0
- !pt | c0 = neg pr
- | otherwise = pr
- in loop (w + 1) acc (add f pt) n1
- else let !pr = A.indexArray ctxArray off1
- !pt | c1 = neg pr
- | otherwise = pr
- in loop (w + 1) (add acc pt) f n1
-{-# INLINE mul_wnaf #-}
+mul_wnaf Context {..} (Wider s) = case mul_wnaf# ctxArray ctxW s of
+ (# () | #) -> Nothing
+ (# | (# px, py, pz #) #) -> Just $! P px py pz
+{-# INLINABLE mul_wnaf #-}
-- | Derive a public key (i.e., a secp256k1 point) from the provided
-- secret.
@@ -759,6 +833,7 @@ parse_int256 :: BS.ByteString -> Maybe Wider
parse_int256 bs = do
guard (BS.length bs == 32)
pure $! unsafe_roll32 bs
+{-# INLINABLE parse_int256 #-}
-- | Parse compressed secp256k1 point (33 bytes), uncompressed point (65
-- bytes), or BIP0340-style point (32 bytes).
@@ -794,8 +869,7 @@ _parse_compressed h (unsafe_roll32 -> x)
| otherwise = do
let !mx = C.to x
!my <- C.sqrt (weierstrass mx)
- let !(W.Wider (# Limb w, _, _, _ #)) = C.retr my
- !yodd = B.testBit (GHC.Word.W# w) 0
+ let !yodd = CT.decide (W.odd (C.retr my))
!hodd = B.testBit h 0
pure $!
if hodd /= yodd
@@ -954,7 +1028,7 @@ verify_schnorr
-> Pub -- ^ public key
-> BS.ByteString -- ^ 64-byte Schnorr signature
-> Bool
-verify_schnorr = _verify_schnorr (mul_unsafe _CURVE_G)
+verify_schnorr = _verify_schnorr (mul_vartime _CURVE_G)
-- | The same as 'verify_schnorr', except uses a 'Context' to optimise
-- internal calculations.
@@ -991,7 +1065,7 @@ _verify_schnorr _mul m p sig
e = modQ . unsafe_roll32 $
hash_challenge (unroll32 r <> unroll32 x_P <> m)
pt0 <- _mul s
- pt1 <- mul_unsafe capP e
+ pt1 <- mul_vartime capP e
let dif = add pt0 (neg pt1)
guard (dif /= _CURVE_ZERO)
let Affine (C.from -> x_R) (C.from -> y_R) = affine dif
@@ -1254,7 +1328,7 @@ verify_ecdsa_unrestricted
-> Pub -- ^ public key
-> ECDSA -- ^ signature
-> Bool
-verify_ecdsa_unrestricted = _verify_ecdsa_unrestricted (mul_unsafe _CURVE_G)
+verify_ecdsa_unrestricted = _verify_ecdsa_unrestricted (mul_vartime _CURVE_G)
-- | The same as 'verify_ecdsa_unrestricted', except uses a 'Context' to
-- optimise internal calculations.
@@ -1292,7 +1366,7 @@ _verify_ecdsa_unrestricted _mul m p (ECDSA r0 s0) = M.isJust $ do
u1 = S.retr (e * si)
u2 = S.retr (r * si)
pt0 <- _mul u1
- pt1 <- mul_unsafe p u2
+ pt1 <- mul_vartime p u2
let capR = add pt0 pt1
guard (capR /= _CURVE_ZERO)
let Affine (S.to . C.retr -> v) _ = affine capR