fixed

Pure Haskell large fixed-width integers and Montgomery arithmetic.
git clone git://git.ppad.tech/fixed.git
Log | Files | Refs | README | LICENSE

commit d46e5dd868f4e35b819c6817312d5b670fe82238
parent babdb8a23f4f859c7da14c4e457e091784e4d54d
Author: Jared Tobin <jared@jtobin.io>
Date:   Tue, 23 Dec 2025 12:28:50 -0330

lib: misc constant-time hardening

Diffstat:
Mbench/Weight.hs | 10+++++-----
Mlib/Data/Word/Wide.hs | 41++++++++++++++++++++++++++++++-----------
Mlib/Data/Word/Wider.hs | 85++++++++++++++++++++++++++++++++++++++++++++++++++-----------------------------
Mlib/Numeric/Montgomery/Secp256k1/Curve.hs | 2+-
Mlib/Numeric/Montgomery/Secp256k1/Scalar.hs | 2+-
Mtest/Montgomery/Curve.hs | 40++++++++++++++++++++--------------------
Mtest/Montgomery/Scalar.hs | 38+++++++++++++++++++-------------------
Mtest/Wider.hs | 22+++++++++++-----------
8 files changed, 141 insertions(+), 99 deletions(-)

diff --git a/bench/Weight.hs b/bench/Weight.hs @@ -38,11 +38,11 @@ cmp = let !a = 1 !b = 2 !c = 2 ^ 255 - 19 - in wgroup "cmp" $ do - func "cmp: 1 < 2" (W.cmp a) b - func "cmp: 2 < 1" (W.cmp b) a - func "cmp: 2 < 2 ^ 255 - 19" (W.cmp b) c - func "cmp: 2 ^ 255 - 19 < 2" (W.cmp c) b + in wgroup "cmp_vartime" $ do + func "cmp_vartime: 1 < 2" (W.cmp_vartime a) b + func "cmp_vartime: 2 < 1" (W.cmp_vartime b) a + func "cmp_vartime: 2 < 2 ^ 255 - 19" (W.cmp_vartime b) c + func "cmp_vartime: 2 ^ 255 - 19 < 2" (W.cmp_vartime c) b add :: Weigh () add = diff --git a/lib/Data/Word/Wide.hs b/lib/Data/Word/Wide.hs @@ -1,6 +1,7 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE UnboxedSums #-} {-# LANGUAGE UnboxedTuples #-} @@ -20,8 +21,8 @@ module Data.Word.Wide ( -- * Construction, Conversion , wide - , to - , from + , to_vartime + , from_vartime -- * Bit Manipulation , or @@ -55,6 +56,7 @@ module Data.Word.Wide ( import Control.DeepSeq import Data.Bits ((.|.), (.&.), (.<<.), (.>>.)) import qualified Data.Bits as B +import qualified Data.Choice as C import Data.Word.Limb (Limb(..)) import qualified Data.Word.Limb as L import GHC.Exts @@ -68,22 +70,33 @@ fi = fromIntegral -- wide words ----------------------------------------------------------------- +pattern Limb2 + :: Word# -> Word# + -> (# Limb, Limb #) +pattern Limb2 w0 w1 = (# Limb w0, Limb w1 #) +{-# COMPLETE Limb2 #-} + -- | Little-endian wide words. data Wide = Wide !(# Limb, Limb #) instance Show Wide where - show = show . from + show = show . from_vartime +-- | Note that 'fromInteger' necessarily runs in variable time due +-- to conversion from the variable-size, potentially heap-allocated +-- 'Integer' type. instance Num Wide where (+) = add (-) = sub (*) = mul abs = id - fromInteger = to + fromInteger = to_vartime negate = neg - signum a = case a of - Wide (# Limb 0##, Limb 0## #) -> 0 - _ -> 1 + signum (Wide (# l0, l1 #)) = + let !(Limb l) = l0 `L.or#` l1 + !n = C.from_word_nonzero# l + !b = C.to_word# n + in Wide (Limb2 b 0##) instance NFData Wide where rnf (Wide a) = case a of (# _, _ #) -> () @@ -95,8 +108,11 @@ wide :: Word -> Word -> Wide wide (W# l) (W# h) = Wide (# Limb l, Limb h #) -- | Convert an 'Integer' to a 'Wide' word. -to :: Integer -> Wide -to n = +-- +-- >>> to_vartime 1 +-- 1 +to_vartime :: Integer -> Wide +to_vartime n = let !size = B.finiteBitSize (0 :: Word) !mask = fi (maxBound :: Word) :: Integer !(W# w0) = fi (n .&. mask) @@ -104,8 +120,11 @@ to n = in Wide (# Limb w0, Limb w1 #) -- | Convert a 'Wide' word to an 'Integer'. -from :: Wide -> Integer -from (Wide (# Limb a, Limb b #)) = +-- +-- >>> from_vartime 1 +-- 1 +from_vartime :: Wide -> Integer +from_vartime (Wide (# Limb a, Limb b #)) = fi (W# b) .<<. (B.finiteBitSize (0 :: Word)) .|. fi (W# a) diff --git a/lib/Data/Word/Wider.hs b/lib/Data/Word/Wider.hs @@ -1,6 +1,7 @@ {-# LANGUAGE BangPatterns #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE NumericUnderscores #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE UnboxedSums #-} {-# LANGUAGE UnboxedTuples #-} @@ -18,12 +19,12 @@ module Data.Word.Wider ( -- * Four-limb words Wider(..) , wider - , to - , from + , to_vartime + , from_vartime -- * Comparison , eq_vartime - , cmp + , cmp_vartime , cmp# , eq# , lt @@ -83,7 +84,7 @@ import qualified Data.Bits as B import qualified Data.Choice as C import Data.Word.Limb (Limb(..)) import qualified Data.Word.Limb as L -import GHC.Exts (Word(..), Int(..), Int#) +import GHC.Exts (Word(..), Int(..), Word#, Int#) import qualified GHC.Exts as Exts import Prelude hiding (div, mod, or, and, not, quot, rem, recip, odd) @@ -95,6 +96,12 @@ fi = fromIntegral -- wider words ---------------------------------------------------------------- +pattern Limb4 + :: Word# -> Word# -> Word# -> Word# + -> (# Limb, Limb, Limb, Limb #) +pattern Limb4 w0 w1 w2 w3 = (# Limb w0, Limb w1, Limb w2, Limb w3 #) +{-# COMPLETE Limb4 #-} + -- | Little-endian wider words, consisting of four 'Limbs'. -- -- >>> 1 :: Wider @@ -102,24 +109,23 @@ fi = fromIntegral data Wider = Wider !(# Limb, Limb, Limb, Limb #) instance Show Wider where - show = show . from - -instance Eq Wider where - Wider a == Wider b = C.decide (eq# a b) - -instance Ord Wider where - compare = cmp + show = show . from_vartime +-- | Note that 'fromInteger' necessarily runs in variable time due +-- to conversion from the variable-size, potentially heap-allocated +-- 'Integer' type. instance Num Wider where (+) = add (-) = sub (*) = mul abs = id - fromInteger = to - negate w = add (not w) (Wider (# Limb 1##, Limb 0##, Limb 0##, Limb 0## #)) - signum a = case a of - Wider (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) -> 0 - _ -> 1 + fromInteger = to_vartime + negate w = add (not w) (Wider (Limb4 1## 0## 0## 0##)) + signum (Wider (# l0, l1, l2, l3 #)) = + let !(Limb l) = l0 `L.or#` l1 `L.or#` l2 `L.or#` l3 + !n = C.from_word_nonzero# l + !b = C.to_word# n + in Wider (Limb4 b 0## 0## 0##) instance NFData Wider where rnf (Wider a) = case a of @@ -132,8 +138,8 @@ eq# -> (# Limb, Limb, Limb, Limb #) -> C.Choice eq# a b = - let !(# Limb a0, Limb a1, Limb a2, Limb a3 #) = a - !(# Limb b0, Limb b1, Limb b2, Limb b3 #) = b + let !(Limb4 a0 a1 a2 a3) = a + !(Limb4 b0 b1 b2 b3) = b in C.eq_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) {-# INLINE eq# #-} @@ -161,6 +167,13 @@ lt# a b = in C.from_word_mask# bor {-# INLINE lt# #-} +-- | Constant-time less-than comparison between 'Wider' values. +-- +-- >>> import qualified Data.Choice as CT +-- >>> CT.decide (lt 1 2) +-- True +-- >>> CT.decide (lt 1 1) +-- False lt :: Wider -> Wider -> C.Choice lt (Wider a) (Wider b) = lt# a b @@ -173,6 +186,13 @@ gt# a b = in C.from_word_mask# bor {-# INLINE gt# #-} +-- | Constant-time greater-than comparison between 'Wider' values. +-- +-- >>> import qualified Data.Choice as CT +-- >>> CT.decide (gt 1 2) +-- False +-- >>> CT.decide (gt 2 1) +-- True gt :: Wider -> Wider -> C.Choice gt (Wider a) (Wider b) = gt# a b @@ -194,20 +214,23 @@ cmp# (# l0, l1, l2, l3 #) (# r0, r1, r2, r3 #) = in (Exts.word2Int# (C.to_word# (L.nonzero# d3))) Exts.*# s {-# INLINE cmp# #-} --- | Constant-time comparison between 'Wider' words. +-- | Variable-time comparison between 'Wider' words. +-- +-- The actual comparison here is performed in constant time, but we must +-- branch to return an 'Ordering'. -- --- >>> cmp 1 2 +-- >>> cmp_vartime 1 2 -- LT --- >>> cmp 2 1 +-- >>> cmp_vartime 2 1 -- GT --- >>> cmp 2 2 +-- >>> cmp_vartime 2 2 -- EQ -cmp :: Wider -> Wider -> Ordering -cmp (Wider a) (Wider b) = case cmp# a b of +cmp_vartime :: Wider -> Wider -> Ordering +cmp_vartime (Wider a) (Wider b) = case cmp# a b of 1# -> GT 0# -> EQ _ -> LT -{-# INLINABLE cmp #-} +{-# INLINABLE cmp_vartime #-} -- construction / conversion -------------------------------------------------- @@ -222,10 +245,10 @@ wider (W# w0) (W# w1) (W# w2) (W# w3) = Wider -- | Convert an 'Integer' to a 'Wider' word. -- --- >>> to 1 +-- >>> to_vartime 1 -- 1 -to :: Integer -> Wider -to n = +to_vartime :: Integer -> Wider +to_vartime n = let !size = B.finiteBitSize (0 :: Word) !mask = fi (maxBound :: Word) :: Integer !(W# w0) = fi (n .&. mask) @@ -236,10 +259,10 @@ to n = -- | Convert a 'Wider' word to an 'Integer'. -- --- >>> from 1 +-- >>> from_vartime 1 -- 1 -from :: Wider -> Integer -from (Wider (# Limb w0, Limb w1, Limb w2, Limb w3 #)) = +from_vartime :: Wider -> Integer +from_vartime (Wider (# Limb w0, Limb w1, Limb w2, Limb w3 #)) = fi (W# w3) .<<. (3 * size) .|. fi (W# w2) .<<. (2 * size) .|. fi (W# w1) .<<. size diff --git a/lib/Numeric/Montgomery/Secp256k1/Curve.hs b/lib/Numeric/Montgomery/Secp256k1/Curve.hs @@ -100,7 +100,7 @@ instance Num Montgomery where a * b = mul a b negate a = neg a abs = id - fromInteger = to . WW.to + fromInteger = to . WW.to_vartime signum a = case a of Montgomery (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) -> 0 _ -> 1 diff --git a/lib/Numeric/Montgomery/Secp256k1/Scalar.hs b/lib/Numeric/Montgomery/Secp256k1/Scalar.hs @@ -98,7 +98,7 @@ instance Num Montgomery where a * b = mul a b negate a = neg a abs = id - fromInteger = to . WW.to + fromInteger = to . WW.to_vartime signum a = case a of Montgomery (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) -> 0 _ -> 1 diff --git a/test/Montgomery/Curve.hs b/test/Montgomery/Curve.hs @@ -40,7 +40,7 @@ repr = H.assertBool mempty (W.eq_vartime 0 (C.from mm)) add_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion add_case t a b s = do - H.assertEqual "sanity" ((W.from a + W.from b) `mod` W.from m) (W.from s) + H.assertEqual "sanity" ((W.from_vartime a + W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime s) H.assertBool t (W.eq_vartime s (C.from (C.to a + C.to b))) add :: H.Assertion @@ -61,7 +61,7 @@ add = do sub_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion sub_case t b a d = do - H.assertEqual "sanity" ((W.from b - W.from a) `mod` W.from m) (W.from d) + H.assertEqual "sanity" ((W.from_vartime b - W.from_vartime a) `mod` W.from_vartime m) (W.from_vartime d) H.assertBool t (W.eq_vartime d (C.from (C.to b - C.to a))) sub :: H.Assertion @@ -81,7 +81,7 @@ sub = do mul_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion mul_case t a b p = do - H.assertEqual "sanity" ((W.from a * W.from b) `mod` W.from m) (W.from p) + H.assertEqual "sanity" ((W.from_vartime a * W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime p) H.assertBool t (W.eq_vartime p (C.from (C.to a * C.to b))) mul :: H.Assertion @@ -105,7 +105,7 @@ mul = do 0x000000000000000000000000000000000000000000000001000007A2000E90A1 instance Q.Arbitrary W.Wider where - arbitrary = fmap W.to Q.arbitrary + arbitrary = fmap W.to_vartime Q.arbitrary instance Q.Arbitrary C.Montgomery where arbitrary = fmap C.to Q.arbitrary @@ -114,39 +114,39 @@ add_matches :: W.Wider -> W.Wider -> Bool add_matches a b = let ma = C.to a mb = C.to b - ia = W.from a - ib = W.from b - im = W.from m - in W.eq_vartime (W.to ((ia + ib) `mod` im)) (C.from (ma + mb)) + ia = W.from_vartime a + ib = W.from_vartime b + im = W.from_vartime m + in W.eq_vartime (W.to_vartime ((ia + ib) `mod` im)) (C.from (ma + mb)) mul_matches :: W.Wider -> W.Wider -> Bool mul_matches a b = let ma = C.to a mb = C.to b - ia = W.from a - ib = W.from b - im = W.from m - in W.eq_vartime (W.to ((ia * ib) `mod` im)) (C.from (ma * mb)) + ia = W.from_vartime a + ib = W.from_vartime b + im = W.from_vartime m + in W.eq_vartime (W.to_vartime ((ia * ib) `mod` im)) (C.from (ma * mb)) sqr_matches :: W.Wider -> Bool sqr_matches a = let ma = C.to a - ia = W.from a - im = W.from m - in W.eq_vartime (W.to ((ia * ia) `mod` im)) (C.from (C.sqr ma)) + ia = W.from_vartime a + im = W.from_vartime m + in W.eq_vartime (W.to_vartime ((ia * ia) `mod` im)) (C.from (C.sqr ma)) exp_matches :: C.Montgomery -> W.Wider -> Bool exp_matches a b = - let ia = W.from (C.from a) - nb = fromIntegral (W.from b) - nm = fromIntegral (W.from m) - in W.eq_vartime (W.to (modexp ia nb nm)) (C.from (C.exp a b)) + let ia = W.from_vartime (C.from a) + nb = fromIntegral (W.from_vartime b) + nm = fromIntegral (W.from_vartime m) + in W.eq_vartime (W.to_vartime (modexp ia nb nm)) (C.from (C.exp a b)) inv_valid :: Q.NonZero C.Montgomery -> Bool inv_valid (Q.NonZero s) = C.eq_vartime (C.inv s * s) 1 odd_correct :: C.Montgomery -> Bool -odd_correct w = C.odd w == I.integerTestBit (W.from (C.from w)) 0 +odd_correct w = C.odd w == I.integerTestBit (W.from_vartime (C.from w)) 0 tests :: TestTree tests = testGroup "montgomery tests (curve)" [ diff --git a/test/Montgomery/Scalar.hs b/test/Montgomery/Scalar.hs @@ -40,7 +40,7 @@ repr = H.assertBool mempty (W.eq_vartime 0 (S.from mm)) add_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion add_case t a b s = do - H.assertEqual "sanity" ((W.from a + W.from b) `mod` W.from m) (W.from s) + H.assertEqual "sanity" ((W.from_vartime a + W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime s) H.assertBool t (W.eq_vartime s (S.from (S.to a + S.to b))) add :: H.Assertion @@ -61,7 +61,7 @@ add = do sub_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion sub_case t b a d = do - H.assertEqual "sanity" ((W.from b - W.from a) `mod` W.from m) (W.from d) + H.assertEqual "sanity" ((W.from_vartime b - W.from_vartime a) `mod` W.from_vartime m) (W.from_vartime d) H.assertBool t (W.eq_vartime d (S.from (S.to b - S.to a))) sub :: H.Assertion @@ -81,7 +81,7 @@ sub = do mul_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion mul_case t a b p = do - H.assertEqual "sanity" ((W.from a * W.from b) `mod` W.from m) (W.from p) + H.assertEqual "sanity" ((W.from_vartime a * W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime p) H.assertBool t (W.eq_vartime p (S.from (S.to a * S.to b))) mul :: H.Assertion @@ -105,7 +105,7 @@ mul = do 0x9D671CD581C69BC5E697F5E45BCD07C6741496C20E7CF878896CF21467D7D140 instance Q.Arbitrary W.Wider where - arbitrary = fmap W.to Q.arbitrary + arbitrary = fmap W.to_vartime Q.arbitrary instance Q.Arbitrary S.Montgomery where arbitrary = fmap S.to Q.arbitrary @@ -114,33 +114,33 @@ add_matches :: W.Wider -> W.Wider -> Bool add_matches a b = let ma = S.to a mb = S.to b - ia = W.from a - ib = W.from b - im = W.from m - in W.eq_vartime (W.to ((ia + ib) `mod` im)) (S.from (ma + mb)) + ia = W.from_vartime a + ib = W.from_vartime b + im = W.from_vartime m + in W.eq_vartime (W.to_vartime ((ia + ib) `mod` im)) (S.from (ma + mb)) mul_matches :: W.Wider -> W.Wider -> Bool mul_matches a b = let ma = S.to a mb = S.to b - ia = W.from a - ib = W.from b - im = W.from m - in W.eq_vartime (W.to ((ia * ib) `mod` im)) (S.from (ma * mb)) + ia = W.from_vartime a + ib = W.from_vartime b + im = W.from_vartime m + in W.eq_vartime (W.to_vartime ((ia * ib) `mod` im)) (S.from (ma * mb)) sqr_matches :: W.Wider -> Bool sqr_matches a = let ma = S.to a - ia = W.from a - im = W.from m - in W.eq_vartime (W.to ((ia * ia) `mod` im)) (S.from (S.sqr ma)) + ia = W.from_vartime a + im = W.from_vartime m + in W.eq_vartime (W.to_vartime ((ia * ia) `mod` im)) (S.from (S.sqr ma)) exp_matches :: S.Montgomery -> W.Wider -> Bool exp_matches a b = - let ia = W.from (S.from a) - nb = fromIntegral (W.from b) - nm = fromIntegral (W.from m) - in W.eq_vartime (W.to (modexp ia nb nm)) (S.from (S.exp a b)) + let ia = W.from_vartime (S.from a) + nb = fromIntegral (W.from_vartime b) + nm = fromIntegral (W.from_vartime m) + in W.eq_vartime (W.to_vartime (modexp ia nb nm)) (S.from (S.exp a b)) inv_valid :: Q.NonZero S.Montgomery -> Bool inv_valid (Q.NonZero s) = S.eq_vartime (S.inv s * s) 1 diff --git a/test/Wider.hs b/test/Wider.hs @@ -107,17 +107,17 @@ cmp = do let !a = 0 !b = 1 !c = 2 ^ (256 :: Word) - 1 - H.assertEqual mempty (W.cmp a b) LT - H.assertEqual mempty (W.cmp a c) LT - H.assertEqual mempty (W.cmp b c) LT + H.assertEqual mempty (W.cmp_vartime a b) LT + H.assertEqual mempty (W.cmp_vartime a c) LT + H.assertEqual mempty (W.cmp_vartime b c) LT - H.assertEqual mempty (W.cmp a a) EQ - H.assertEqual mempty (W.cmp b b) EQ - H.assertEqual mempty (W.cmp c c) EQ + H.assertEqual mempty (W.cmp_vartime a a) EQ + H.assertEqual mempty (W.cmp_vartime b b) EQ + H.assertEqual mempty (W.cmp_vartime c c) EQ - H.assertEqual mempty (W.cmp b a) GT - H.assertEqual mempty (W.cmp c a) GT - H.assertEqual mempty (W.cmp c b) GT + H.assertEqual mempty (W.cmp_vartime b a) GT + H.assertEqual mempty (W.cmp_vartime c a) GT + H.assertEqual mempty (W.cmp_vartime c b) GT sqr :: H.Assertion sqr = do @@ -144,10 +144,10 @@ sub_mod = do H.assertBool mempty (W.eq_vartime o e) instance Q.Arbitrary W.Wider where - arbitrary = fmap W.to Q.arbitrary + arbitrary = fmap W.to_vartime Q.arbitrary odd_correct :: W.Wider -> Bool -odd_correct w = C.decide (W.odd w) == I.integerTestBit (W.from w) 0 +odd_correct w = C.decide (W.odd w) == I.integerTestBit (W.from_vartime w) 0 tests :: TestTree tests = testGroup "wider tests" [