fixed

Pure Haskell large fixed-width integers and Montgomery arithmetic.
git clone git://git.ppad.tech/fixed.git
Log | Files | Refs | README | LICENSE

commit b27a5d079113939782e288954d9de30f037b3d8e
parent f0c1e3348b3192c1f62003caa8b9eb6ecbaf1f5f
Author: Jared Tobin <jared@jtobin.io>
Date:   Tue, 23 Dec 2025 12:57:07 -0330

lib: init hardening in montgomery modules

Diffstat:
Mlib/Numeric/Montgomery/Secp256k1/Curve.hs | 30++++++++++++++++--------------
Mlib/Numeric/Montgomery/Secp256k1/Scalar.hs | 30++++++++++++++++--------------
Mtest/Montgomery/Curve.hs | 61++++++++++++++++++++++++++++++++++++++++++-------------------
Mtest/Montgomery/Scalar.hs | 58++++++++++++++++++++++++++++++++++++++++------------------
4 files changed, 114 insertions(+), 65 deletions(-)

diff --git a/lib/Numeric/Montgomery/Secp256k1/Curve.hs b/lib/Numeric/Montgomery/Secp256k1/Curve.hs @@ -19,8 +19,8 @@ module Numeric.Montgomery.Secp256k1.Curve ( -- * Montgomery form, secp256k1 field prime modulus Montgomery(..) , render - , to - , from + , to_vartime + , from_vartime , zero , one @@ -92,21 +92,23 @@ render (Montgomery (# Limb a, Limb b, Limb c, Limb d #)) = <> show (W# c) <> ", " <> show (W# d) <> ")" instance Show Montgomery where - show = show . from + show = show . from_vartime +-- | Note that 'fromInteger' necessarily runs in variable time due +-- to conversion from the variable-size, potentially heap-allocated +-- 'Integer' type. instance Num Montgomery where a + b = add a b a - b = sub a b a * b = mul a b negate a = neg a abs = id - fromInteger = to . WW.to_vartime - signum a = case a of - Montgomery (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) -> 0 - _ -> 1 - -instance Eq Montgomery where - a == b = C.decide (eq a b) + fromInteger = to_vartime . WW.to_vartime + signum (Montgomery (# l0, l1, l2, l3 #)) = + let !(Limb l) = l0 `L.or#` l1 `L.or#` l2 `L.or#` l3 + !n = C.from_word_nonzero# l + !b = C.to_word# n + in Montgomery (# Limb b, Limb 0##, Limb 0##, Limb 0## #) instance NFData Montgomery where rnf (Montgomery a) = case a of (# _, _, _, _ #) -> () @@ -366,14 +368,14 @@ to# x = {-# INLINE to# #-} -- | Convert a 'Wider' word to the Montgomery domain. -to :: Wider -> Montgomery -to (Wider x) = Montgomery (to# x) +to_vartime :: Wider -> Montgomery +to_vartime (Wider x) = Montgomery (to# x) -- | Retrieve a 'Montgomery' word from the Montgomery domain. -- -- This function is a synonym for 'retr'. -from :: Montgomery -> Wider -from = retr +from_vartime :: Montgomery -> Wider +from_vartime = retr add# :: (# Limb, Limb, Limb, Limb #) -- ^ augend diff --git a/lib/Numeric/Montgomery/Secp256k1/Scalar.hs b/lib/Numeric/Montgomery/Secp256k1/Scalar.hs @@ -19,8 +19,8 @@ module Numeric.Montgomery.Secp256k1.Scalar ( -- * Montgomery form, secp256k1 scalar group order modulus Montgomery(..) , render - , to - , from + , to_vartime + , from_vartime , zero , one @@ -80,7 +80,7 @@ import Prelude hiding (or, and, not, exp, odd) data Montgomery = Montgomery !(# Limb, Limb, Limb, Limb #) instance Show Montgomery where - show = show . from + show = show . from_vartime -- | Render a 'Montgomery' value as a 'String', showing its individual -- 'Limb's. @@ -92,19 +92,21 @@ render (Montgomery (# Limb a, Limb b, Limb c, Limb d #)) = "(" <> show (W# a) <> ", " <> show (W# b) <> ", " <> show (W# c) <> ", " <> show (W# d) <> ")" +-- | Note that 'fromInteger' necessarily runs in variable time due +-- to conversion from the variable-size, potentially heap-allocated +-- 'Integer' type. instance Num Montgomery where a + b = add a b a - b = sub a b a * b = mul a b negate a = neg a abs = id - fromInteger = to . WW.to_vartime - signum a = case a of - Montgomery (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) -> 0 - _ -> 1 - -instance Eq Montgomery where - a == b = C.decide (eq a b) + fromInteger = to_vartime . WW.to_vartime + signum (Montgomery (# l0, l1, l2, l3 #)) = + let !(Limb l) = l0 `L.or#` l1 `L.or#` l2 `L.or#` l3 + !n = C.from_word_nonzero# l + !b = C.to_word# n + in Montgomery (# Limb b, Limb 0##, Limb 0##, Limb 0## #) instance NFData Montgomery where rnf (Montgomery a) = case a of (# _, _, _, _ #) -> () @@ -364,14 +366,14 @@ to# x = {-# INLINE to# #-} -- | Convert a 'Wider' word to the Montgomery domain. -to :: Wider -> Montgomery -to (Wider x) = Montgomery (to# x) +to_vartime :: Wider -> Montgomery +to_vartime (Wider x) = Montgomery (to# x) -- | Retrieve a 'Montgomery' word from the Montgomery domain. -- -- This function is a synonym for 'retr'. -from :: Montgomery -> Wider -from = retr +from_vartime :: Montgomery -> Wider +from_vartime = retr add# :: (# Limb, Limb, Limb, Limb #) -- ^ augend diff --git a/test/Montgomery/Curve.hs b/test/Montgomery/Curve.hs @@ -11,6 +11,7 @@ module Montgomery.Curve ( tests ) where +import qualified Data.Choice as CT import qualified Data.Word.Wider as W import qualified GHC.Num.Integer as I import GHC.Natural @@ -19,6 +20,10 @@ import Test.Tasty import qualified Test.Tasty.HUnit as H import qualified Test.Tasty.QuickCheck as Q +-- orphan Eq instance for testing +instance Eq C.Montgomery where + a == b = CT.decide (C.eq a b) + -- generic modular exponentiation -- b ^ e mod m modexp :: Integer -> Natural -> Natural -> Integer @@ -36,12 +41,15 @@ mm :: C.Montgomery mm = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F repr :: H.Assertion -repr = H.assertBool mempty (W.eq_vartime 0 (C.from mm)) +repr = H.assertBool mempty (W.eq_vartime 0 (C.from_vartime mm)) add_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion add_case t a b s = do - H.assertEqual "sanity" ((W.from_vartime a + W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime s) - H.assertBool t (W.eq_vartime s (C.from (C.to a + C.to b))) + H.assertEqual "sanity" + ((W.from_vartime a + W.from_vartime b) `mod` W.from_vartime m) + (W.from_vartime s) + H.assertBool t + (W.eq_vartime s (C.from_vartime (C.to_vartime a + C.to_vartime b))) add :: H.Assertion add = do @@ -61,8 +69,11 @@ add = do sub_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion sub_case t b a d = do - H.assertEqual "sanity" ((W.from_vartime b - W.from_vartime a) `mod` W.from_vartime m) (W.from_vartime d) - H.assertBool t (W.eq_vartime d (C.from (C.to b - C.to a))) + H.assertEqual "sanity" + ((W.from_vartime b - W.from_vartime a) `mod` W.from_vartime m) + (W.from_vartime d) + H.assertBool t + (W.eq_vartime d (C.from_vartime (C.to_vartime b - C.to_vartime a))) sub :: H.Assertion sub = do @@ -81,8 +92,11 @@ sub = do mul_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion mul_case t a b p = do - H.assertEqual "sanity" ((W.from_vartime a * W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime p) - H.assertBool t (W.eq_vartime p (C.from (C.to a * C.to b))) + H.assertEqual "sanity" + ((W.from_vartime a * W.from_vartime b) `mod` W.from_vartime m) + (W.from_vartime p) + H.assertBool t + (W.eq_vartime p (C.from_vartime (C.to_vartime a * C.to_vartime b))) mul :: H.Assertion mul = do @@ -108,45 +122,54 @@ instance Q.Arbitrary W.Wider where arbitrary = fmap W.to_vartime Q.arbitrary instance Q.Arbitrary C.Montgomery where - arbitrary = fmap C.to Q.arbitrary + arbitrary = fmap C.to_vartime Q.arbitrary add_matches :: W.Wider -> W.Wider -> Bool add_matches a b = - let ma = C.to a - mb = C.to b + let ma = C.to_vartime a + mb = C.to_vartime b ia = W.from_vartime a ib = W.from_vartime b im = W.from_vartime m - in W.eq_vartime (W.to_vartime ((ia + ib) `mod` im)) (C.from (ma + mb)) + in W.eq_vartime + (W.to_vartime ((ia + ib) `mod` im)) + (C.from_vartime (ma + mb)) mul_matches :: W.Wider -> W.Wider -> Bool mul_matches a b = - let ma = C.to a - mb = C.to b + let ma = C.to_vartime a + mb = C.to_vartime b ia = W.from_vartime a ib = W.from_vartime b im = W.from_vartime m - in W.eq_vartime (W.to_vartime ((ia * ib) `mod` im)) (C.from (ma * mb)) + in W.eq_vartime + (W.to_vartime ((ia * ib) `mod` im)) + (C.from_vartime (ma * mb)) sqr_matches :: W.Wider -> Bool sqr_matches a = - let ma = C.to a + let ma = C.to_vartime a ia = W.from_vartime a im = W.from_vartime m - in W.eq_vartime (W.to_vartime ((ia * ia) `mod` im)) (C.from (C.sqr ma)) + in W.eq_vartime + (W.to_vartime ((ia * ia) `mod` im)) + (C.from_vartime (C.sqr ma)) exp_matches :: C.Montgomery -> W.Wider -> Bool exp_matches a b = - let ia = W.from_vartime (C.from a) + let ia = W.from_vartime (C.from_vartime a) nb = fromIntegral (W.from_vartime b) nm = fromIntegral (W.from_vartime m) - in W.eq_vartime (W.to_vartime (modexp ia nb nm)) (C.from (C.exp a b)) + in W.eq_vartime + (W.to_vartime (modexp ia nb nm)) + (C.from_vartime (C.exp a b)) inv_valid :: Q.NonZero C.Montgomery -> Bool inv_valid (Q.NonZero s) = C.eq_vartime (C.inv s * s) 1 odd_correct :: C.Montgomery -> Bool -odd_correct w = C.odd w == I.integerTestBit (W.from_vartime (C.from w)) 0 +odd_correct w = + C.odd w == I.integerTestBit (W.from_vartime (C.from_vartime w)) 0 tests :: TestTree tests = testGroup "montgomery tests (curve)" [ diff --git a/test/Montgomery/Scalar.hs b/test/Montgomery/Scalar.hs @@ -11,6 +11,7 @@ module Montgomery.Scalar ( tests ) where +import qualified Data.Choice as CT import qualified Data.Word.Wider as W import qualified GHC.Num.Integer as I import GHC.Natural @@ -19,6 +20,10 @@ import Test.Tasty import qualified Test.Tasty.HUnit as H import qualified Test.Tasty.QuickCheck as Q +-- orphan Eq instance for testing +instance Eq S.Montgomery where + a == b = CT.decide (S.eq a b) + -- generic modular exponentiation -- b ^ e mod m modexp :: Integer -> Natural -> Natural -> Integer @@ -36,12 +41,15 @@ mm :: S.Montgomery mm = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 repr :: H.Assertion -repr = H.assertBool mempty (W.eq_vartime 0 (S.from mm)) +repr = H.assertBool mempty (W.eq_vartime 0 (S.from_vartime mm)) add_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion add_case t a b s = do - H.assertEqual "sanity" ((W.from_vartime a + W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime s) - H.assertBool t (W.eq_vartime s (S.from (S.to a + S.to b))) + H.assertEqual "sanity" + ((W.from_vartime a + W.from_vartime b) `mod` W.from_vartime m) + (W.from_vartime s) + H.assertBool t + (W.eq_vartime s (S.from_vartime (S.to_vartime a + S.to_vartime b))) add :: H.Assertion add = do @@ -61,8 +69,11 @@ add = do sub_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion sub_case t b a d = do - H.assertEqual "sanity" ((W.from_vartime b - W.from_vartime a) `mod` W.from_vartime m) (W.from_vartime d) - H.assertBool t (W.eq_vartime d (S.from (S.to b - S.to a))) + H.assertEqual "sanity" + ((W.from_vartime b - W.from_vartime a) `mod` W.from_vartime m) + (W.from_vartime d) + H.assertBool t + (W.eq_vartime d (S.from_vartime (S.to_vartime b - S.to_vartime a))) sub :: H.Assertion sub = do @@ -81,8 +92,11 @@ sub = do mul_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion mul_case t a b p = do - H.assertEqual "sanity" ((W.from_vartime a * W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime p) - H.assertBool t (W.eq_vartime p (S.from (S.to a * S.to b))) + H.assertEqual "sanity" + ((W.from_vartime a * W.from_vartime b) `mod` W.from_vartime m) + (W.from_vartime p) + H.assertBool t + (W.eq_vartime p (S.from_vartime (S.to_vartime a * S.to_vartime b))) mul :: H.Assertion mul = do @@ -108,39 +122,47 @@ instance Q.Arbitrary W.Wider where arbitrary = fmap W.to_vartime Q.arbitrary instance Q.Arbitrary S.Montgomery where - arbitrary = fmap S.to Q.arbitrary + arbitrary = fmap S.to_vartime Q.arbitrary add_matches :: W.Wider -> W.Wider -> Bool add_matches a b = - let ma = S.to a - mb = S.to b + let ma = S.to_vartime a + mb = S.to_vartime b ia = W.from_vartime a ib = W.from_vartime b im = W.from_vartime m - in W.eq_vartime (W.to_vartime ((ia + ib) `mod` im)) (S.from (ma + mb)) + in W.eq_vartime + (W.to_vartime ((ia + ib) `mod` im)) + (S.from_vartime (ma + mb)) mul_matches :: W.Wider -> W.Wider -> Bool mul_matches a b = - let ma = S.to a - mb = S.to b + let ma = S.to_vartime a + mb = S.to_vartime b ia = W.from_vartime a ib = W.from_vartime b im = W.from_vartime m - in W.eq_vartime (W.to_vartime ((ia * ib) `mod` im)) (S.from (ma * mb)) + in W.eq_vartime + (W.to_vartime ((ia * ib) `mod` im)) + (S.from_vartime (ma * mb)) sqr_matches :: W.Wider -> Bool sqr_matches a = - let ma = S.to a + let ma = S.to_vartime a ia = W.from_vartime a im = W.from_vartime m - in W.eq_vartime (W.to_vartime ((ia * ia) `mod` im)) (S.from (S.sqr ma)) + in W.eq_vartime + (W.to_vartime ((ia * ia) `mod` im)) + (S.from_vartime (S.sqr ma)) exp_matches :: S.Montgomery -> W.Wider -> Bool exp_matches a b = - let ia = W.from_vartime (S.from a) + let ia = W.from_vartime (S.from_vartime a) nb = fromIntegral (W.from_vartime b) nm = fromIntegral (W.from_vartime m) - in W.eq_vartime (W.to_vartime (modexp ia nb nm)) (S.from (S.exp a b)) + in W.eq_vartime + (W.to_vartime (modexp ia nb nm)) + (S.from_vartime (S.exp a b)) inv_valid :: Q.NonZero S.Montgomery -> Bool inv_valid (Q.NonZero s) = S.eq_vartime (S.inv s * s) 1