commit b27a5d079113939782e288954d9de30f037b3d8e
parent f0c1e3348b3192c1f62003caa8b9eb6ecbaf1f5f
Author: Jared Tobin <jared@jtobin.io>
Date: Tue, 23 Dec 2025 12:57:07 -0330
lib: init hardening in montgomery modules
Diffstat:
4 files changed, 114 insertions(+), 65 deletions(-)
diff --git a/lib/Numeric/Montgomery/Secp256k1/Curve.hs b/lib/Numeric/Montgomery/Secp256k1/Curve.hs
@@ -19,8 +19,8 @@ module Numeric.Montgomery.Secp256k1.Curve (
-- * Montgomery form, secp256k1 field prime modulus
Montgomery(..)
, render
- , to
- , from
+ , to_vartime
+ , from_vartime
, zero
, one
@@ -92,21 +92,23 @@ render (Montgomery (# Limb a, Limb b, Limb c, Limb d #)) =
<> show (W# c) <> ", " <> show (W# d) <> ")"
instance Show Montgomery where
- show = show . from
+ show = show . from_vartime
+-- | Note that 'fromInteger' necessarily runs in variable time due
+-- to conversion from the variable-size, potentially heap-allocated
+-- 'Integer' type.
instance Num Montgomery where
a + b = add a b
a - b = sub a b
a * b = mul a b
negate a = neg a
abs = id
- fromInteger = to . WW.to_vartime
- signum a = case a of
- Montgomery (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) -> 0
- _ -> 1
-
-instance Eq Montgomery where
- a == b = C.decide (eq a b)
+ fromInteger = to_vartime . WW.to_vartime
+ signum (Montgomery (# l0, l1, l2, l3 #)) =
+ let !(Limb l) = l0 `L.or#` l1 `L.or#` l2 `L.or#` l3
+ !n = C.from_word_nonzero# l
+ !b = C.to_word# n
+ in Montgomery (# Limb b, Limb 0##, Limb 0##, Limb 0## #)
instance NFData Montgomery where
rnf (Montgomery a) = case a of (# _, _, _, _ #) -> ()
@@ -366,14 +368,14 @@ to# x =
{-# INLINE to# #-}
-- | Convert a 'Wider' word to the Montgomery domain.
-to :: Wider -> Montgomery
-to (Wider x) = Montgomery (to# x)
+to_vartime :: Wider -> Montgomery
+to_vartime (Wider x) = Montgomery (to# x)
-- | Retrieve a 'Montgomery' word from the Montgomery domain.
--
-- This function is a synonym for 'retr'.
-from :: Montgomery -> Wider
-from = retr
+from_vartime :: Montgomery -> Wider
+from_vartime = retr
add#
:: (# Limb, Limb, Limb, Limb #) -- ^ augend
diff --git a/lib/Numeric/Montgomery/Secp256k1/Scalar.hs b/lib/Numeric/Montgomery/Secp256k1/Scalar.hs
@@ -19,8 +19,8 @@ module Numeric.Montgomery.Secp256k1.Scalar (
-- * Montgomery form, secp256k1 scalar group order modulus
Montgomery(..)
, render
- , to
- , from
+ , to_vartime
+ , from_vartime
, zero
, one
@@ -80,7 +80,7 @@ import Prelude hiding (or, and, not, exp, odd)
data Montgomery = Montgomery !(# Limb, Limb, Limb, Limb #)
instance Show Montgomery where
- show = show . from
+ show = show . from_vartime
-- | Render a 'Montgomery' value as a 'String', showing its individual
-- 'Limb's.
@@ -92,19 +92,21 @@ render (Montgomery (# Limb a, Limb b, Limb c, Limb d #)) =
"(" <> show (W# a) <> ", " <> show (W# b) <> ", "
<> show (W# c) <> ", " <> show (W# d) <> ")"
+-- | Note that 'fromInteger' necessarily runs in variable time due
+-- to conversion from the variable-size, potentially heap-allocated
+-- 'Integer' type.
instance Num Montgomery where
a + b = add a b
a - b = sub a b
a * b = mul a b
negate a = neg a
abs = id
- fromInteger = to . WW.to_vartime
- signum a = case a of
- Montgomery (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) -> 0
- _ -> 1
-
-instance Eq Montgomery where
- a == b = C.decide (eq a b)
+ fromInteger = to_vartime . WW.to_vartime
+ signum (Montgomery (# l0, l1, l2, l3 #)) =
+ let !(Limb l) = l0 `L.or#` l1 `L.or#` l2 `L.or#` l3
+ !n = C.from_word_nonzero# l
+ !b = C.to_word# n
+ in Montgomery (# Limb b, Limb 0##, Limb 0##, Limb 0## #)
instance NFData Montgomery where
rnf (Montgomery a) = case a of (# _, _, _, _ #) -> ()
@@ -364,14 +366,14 @@ to# x =
{-# INLINE to# #-}
-- | Convert a 'Wider' word to the Montgomery domain.
-to :: Wider -> Montgomery
-to (Wider x) = Montgomery (to# x)
+to_vartime :: Wider -> Montgomery
+to_vartime (Wider x) = Montgomery (to# x)
-- | Retrieve a 'Montgomery' word from the Montgomery domain.
--
-- This function is a synonym for 'retr'.
-from :: Montgomery -> Wider
-from = retr
+from_vartime :: Montgomery -> Wider
+from_vartime = retr
add#
:: (# Limb, Limb, Limb, Limb #) -- ^ augend
diff --git a/test/Montgomery/Curve.hs b/test/Montgomery/Curve.hs
@@ -11,6 +11,7 @@ module Montgomery.Curve (
tests
) where
+import qualified Data.Choice as CT
import qualified Data.Word.Wider as W
import qualified GHC.Num.Integer as I
import GHC.Natural
@@ -19,6 +20,10 @@ import Test.Tasty
import qualified Test.Tasty.HUnit as H
import qualified Test.Tasty.QuickCheck as Q
+-- orphan Eq instance for testing
+instance Eq C.Montgomery where
+ a == b = CT.decide (C.eq a b)
+
-- generic modular exponentiation
-- b ^ e mod m
modexp :: Integer -> Natural -> Natural -> Integer
@@ -36,12 +41,15 @@ mm :: C.Montgomery
mm = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
repr :: H.Assertion
-repr = H.assertBool mempty (W.eq_vartime 0 (C.from mm))
+repr = H.assertBool mempty (W.eq_vartime 0 (C.from_vartime mm))
add_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion
add_case t a b s = do
- H.assertEqual "sanity" ((W.from_vartime a + W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime s)
- H.assertBool t (W.eq_vartime s (C.from (C.to a + C.to b)))
+ H.assertEqual "sanity"
+ ((W.from_vartime a + W.from_vartime b) `mod` W.from_vartime m)
+ (W.from_vartime s)
+ H.assertBool t
+ (W.eq_vartime s (C.from_vartime (C.to_vartime a + C.to_vartime b)))
add :: H.Assertion
add = do
@@ -61,8 +69,11 @@ add = do
sub_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion
sub_case t b a d = do
- H.assertEqual "sanity" ((W.from_vartime b - W.from_vartime a) `mod` W.from_vartime m) (W.from_vartime d)
- H.assertBool t (W.eq_vartime d (C.from (C.to b - C.to a)))
+ H.assertEqual "sanity"
+ ((W.from_vartime b - W.from_vartime a) `mod` W.from_vartime m)
+ (W.from_vartime d)
+ H.assertBool t
+ (W.eq_vartime d (C.from_vartime (C.to_vartime b - C.to_vartime a)))
sub :: H.Assertion
sub = do
@@ -81,8 +92,11 @@ sub = do
mul_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion
mul_case t a b p = do
- H.assertEqual "sanity" ((W.from_vartime a * W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime p)
- H.assertBool t (W.eq_vartime p (C.from (C.to a * C.to b)))
+ H.assertEqual "sanity"
+ ((W.from_vartime a * W.from_vartime b) `mod` W.from_vartime m)
+ (W.from_vartime p)
+ H.assertBool t
+ (W.eq_vartime p (C.from_vartime (C.to_vartime a * C.to_vartime b)))
mul :: H.Assertion
mul = do
@@ -108,45 +122,54 @@ instance Q.Arbitrary W.Wider where
arbitrary = fmap W.to_vartime Q.arbitrary
instance Q.Arbitrary C.Montgomery where
- arbitrary = fmap C.to Q.arbitrary
+ arbitrary = fmap C.to_vartime Q.arbitrary
add_matches :: W.Wider -> W.Wider -> Bool
add_matches a b =
- let ma = C.to a
- mb = C.to b
+ let ma = C.to_vartime a
+ mb = C.to_vartime b
ia = W.from_vartime a
ib = W.from_vartime b
im = W.from_vartime m
- in W.eq_vartime (W.to_vartime ((ia + ib) `mod` im)) (C.from (ma + mb))
+ in W.eq_vartime
+ (W.to_vartime ((ia + ib) `mod` im))
+ (C.from_vartime (ma + mb))
mul_matches :: W.Wider -> W.Wider -> Bool
mul_matches a b =
- let ma = C.to a
- mb = C.to b
+ let ma = C.to_vartime a
+ mb = C.to_vartime b
ia = W.from_vartime a
ib = W.from_vartime b
im = W.from_vartime m
- in W.eq_vartime (W.to_vartime ((ia * ib) `mod` im)) (C.from (ma * mb))
+ in W.eq_vartime
+ (W.to_vartime ((ia * ib) `mod` im))
+ (C.from_vartime (ma * mb))
sqr_matches :: W.Wider -> Bool
sqr_matches a =
- let ma = C.to a
+ let ma = C.to_vartime a
ia = W.from_vartime a
im = W.from_vartime m
- in W.eq_vartime (W.to_vartime ((ia * ia) `mod` im)) (C.from (C.sqr ma))
+ in W.eq_vartime
+ (W.to_vartime ((ia * ia) `mod` im))
+ (C.from_vartime (C.sqr ma))
exp_matches :: C.Montgomery -> W.Wider -> Bool
exp_matches a b =
- let ia = W.from_vartime (C.from a)
+ let ia = W.from_vartime (C.from_vartime a)
nb = fromIntegral (W.from_vartime b)
nm = fromIntegral (W.from_vartime m)
- in W.eq_vartime (W.to_vartime (modexp ia nb nm)) (C.from (C.exp a b))
+ in W.eq_vartime
+ (W.to_vartime (modexp ia nb nm))
+ (C.from_vartime (C.exp a b))
inv_valid :: Q.NonZero C.Montgomery -> Bool
inv_valid (Q.NonZero s) = C.eq_vartime (C.inv s * s) 1
odd_correct :: C.Montgomery -> Bool
-odd_correct w = C.odd w == I.integerTestBit (W.from_vartime (C.from w)) 0
+odd_correct w =
+ C.odd w == I.integerTestBit (W.from_vartime (C.from_vartime w)) 0
tests :: TestTree
tests = testGroup "montgomery tests (curve)" [
diff --git a/test/Montgomery/Scalar.hs b/test/Montgomery/Scalar.hs
@@ -11,6 +11,7 @@ module Montgomery.Scalar (
tests
) where
+import qualified Data.Choice as CT
import qualified Data.Word.Wider as W
import qualified GHC.Num.Integer as I
import GHC.Natural
@@ -19,6 +20,10 @@ import Test.Tasty
import qualified Test.Tasty.HUnit as H
import qualified Test.Tasty.QuickCheck as Q
+-- orphan Eq instance for testing
+instance Eq S.Montgomery where
+ a == b = CT.decide (S.eq a b)
+
-- generic modular exponentiation
-- b ^ e mod m
modexp :: Integer -> Natural -> Natural -> Integer
@@ -36,12 +41,15 @@ mm :: S.Montgomery
mm = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
repr :: H.Assertion
-repr = H.assertBool mempty (W.eq_vartime 0 (S.from mm))
+repr = H.assertBool mempty (W.eq_vartime 0 (S.from_vartime mm))
add_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion
add_case t a b s = do
- H.assertEqual "sanity" ((W.from_vartime a + W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime s)
- H.assertBool t (W.eq_vartime s (S.from (S.to a + S.to b)))
+ H.assertEqual "sanity"
+ ((W.from_vartime a + W.from_vartime b) `mod` W.from_vartime m)
+ (W.from_vartime s)
+ H.assertBool t
+ (W.eq_vartime s (S.from_vartime (S.to_vartime a + S.to_vartime b)))
add :: H.Assertion
add = do
@@ -61,8 +69,11 @@ add = do
sub_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion
sub_case t b a d = do
- H.assertEqual "sanity" ((W.from_vartime b - W.from_vartime a) `mod` W.from_vartime m) (W.from_vartime d)
- H.assertBool t (W.eq_vartime d (S.from (S.to b - S.to a)))
+ H.assertEqual "sanity"
+ ((W.from_vartime b - W.from_vartime a) `mod` W.from_vartime m)
+ (W.from_vartime d)
+ H.assertBool t
+ (W.eq_vartime d (S.from_vartime (S.to_vartime b - S.to_vartime a)))
sub :: H.Assertion
sub = do
@@ -81,8 +92,11 @@ sub = do
mul_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion
mul_case t a b p = do
- H.assertEqual "sanity" ((W.from_vartime a * W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime p)
- H.assertBool t (W.eq_vartime p (S.from (S.to a * S.to b)))
+ H.assertEqual "sanity"
+ ((W.from_vartime a * W.from_vartime b) `mod` W.from_vartime m)
+ (W.from_vartime p)
+ H.assertBool t
+ (W.eq_vartime p (S.from_vartime (S.to_vartime a * S.to_vartime b)))
mul :: H.Assertion
mul = do
@@ -108,39 +122,47 @@ instance Q.Arbitrary W.Wider where
arbitrary = fmap W.to_vartime Q.arbitrary
instance Q.Arbitrary S.Montgomery where
- arbitrary = fmap S.to Q.arbitrary
+ arbitrary = fmap S.to_vartime Q.arbitrary
add_matches :: W.Wider -> W.Wider -> Bool
add_matches a b =
- let ma = S.to a
- mb = S.to b
+ let ma = S.to_vartime a
+ mb = S.to_vartime b
ia = W.from_vartime a
ib = W.from_vartime b
im = W.from_vartime m
- in W.eq_vartime (W.to_vartime ((ia + ib) `mod` im)) (S.from (ma + mb))
+ in W.eq_vartime
+ (W.to_vartime ((ia + ib) `mod` im))
+ (S.from_vartime (ma + mb))
mul_matches :: W.Wider -> W.Wider -> Bool
mul_matches a b =
- let ma = S.to a
- mb = S.to b
+ let ma = S.to_vartime a
+ mb = S.to_vartime b
ia = W.from_vartime a
ib = W.from_vartime b
im = W.from_vartime m
- in W.eq_vartime (W.to_vartime ((ia * ib) `mod` im)) (S.from (ma * mb))
+ in W.eq_vartime
+ (W.to_vartime ((ia * ib) `mod` im))
+ (S.from_vartime (ma * mb))
sqr_matches :: W.Wider -> Bool
sqr_matches a =
- let ma = S.to a
+ let ma = S.to_vartime a
ia = W.from_vartime a
im = W.from_vartime m
- in W.eq_vartime (W.to_vartime ((ia * ia) `mod` im)) (S.from (S.sqr ma))
+ in W.eq_vartime
+ (W.to_vartime ((ia * ia) `mod` im))
+ (S.from_vartime (S.sqr ma))
exp_matches :: S.Montgomery -> W.Wider -> Bool
exp_matches a b =
- let ia = W.from_vartime (S.from a)
+ let ia = W.from_vartime (S.from_vartime a)
nb = fromIntegral (W.from_vartime b)
nm = fromIntegral (W.from_vartime m)
- in W.eq_vartime (W.to_vartime (modexp ia nb nm)) (S.from (S.exp a b))
+ in W.eq_vartime
+ (W.to_vartime (modexp ia nb nm))
+ (S.from_vartime (S.exp a b))
inv_valid :: Q.NonZero S.Montgomery -> Bool
inv_valid (Q.NonZero s) = S.eq_vartime (S.inv s * s) 1