commit d46e5dd868f4e35b819c6817312d5b670fe82238
parent babdb8a23f4f859c7da14c4e457e091784e4d54d
Author: Jared Tobin <jared@jtobin.io>
Date: Tue, 23 Dec 2025 12:28:50 -0330
lib: misc constant-time hardening
Diffstat:
8 files changed, 141 insertions(+), 99 deletions(-)
diff --git a/bench/Weight.hs b/bench/Weight.hs
@@ -38,11 +38,11 @@ cmp =
let !a = 1
!b = 2
!c = 2 ^ 255 - 19
- in wgroup "cmp" $ do
- func "cmp: 1 < 2" (W.cmp a) b
- func "cmp: 2 < 1" (W.cmp b) a
- func "cmp: 2 < 2 ^ 255 - 19" (W.cmp b) c
- func "cmp: 2 ^ 255 - 19 < 2" (W.cmp c) b
+ in wgroup "cmp_vartime" $ do
+ func "cmp_vartime: 1 < 2" (W.cmp_vartime a) b
+ func "cmp_vartime: 2 < 1" (W.cmp_vartime b) a
+ func "cmp_vartime: 2 < 2 ^ 255 - 19" (W.cmp_vartime b) c
+ func "cmp_vartime: 2 ^ 255 - 19 < 2" (W.cmp_vartime c) b
add :: Weigh ()
add =
diff --git a/lib/Data/Word/Wide.hs b/lib/Data/Word/Wide.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE UnboxedSums #-}
{-# LANGUAGE UnboxedTuples #-}
@@ -20,8 +21,8 @@ module Data.Word.Wide (
-- * Construction, Conversion
, wide
- , to
- , from
+ , to_vartime
+ , from_vartime
-- * Bit Manipulation
, or
@@ -55,6 +56,7 @@ module Data.Word.Wide (
import Control.DeepSeq
import Data.Bits ((.|.), (.&.), (.<<.), (.>>.))
import qualified Data.Bits as B
+import qualified Data.Choice as C
import Data.Word.Limb (Limb(..))
import qualified Data.Word.Limb as L
import GHC.Exts
@@ -68,22 +70,33 @@ fi = fromIntegral
-- wide words -----------------------------------------------------------------
+pattern Limb2
+ :: Word# -> Word#
+ -> (# Limb, Limb #)
+pattern Limb2 w0 w1 = (# Limb w0, Limb w1 #)
+{-# COMPLETE Limb2 #-}
+
-- | Little-endian wide words.
data Wide = Wide !(# Limb, Limb #)
instance Show Wide where
- show = show . from
+ show = show . from_vartime
+-- | Note that 'fromInteger' necessarily runs in variable time due
+-- to conversion from the variable-size, potentially heap-allocated
+-- 'Integer' type.
instance Num Wide where
(+) = add
(-) = sub
(*) = mul
abs = id
- fromInteger = to
+ fromInteger = to_vartime
negate = neg
- signum a = case a of
- Wide (# Limb 0##, Limb 0## #) -> 0
- _ -> 1
+ signum (Wide (# l0, l1 #)) =
+ let !(Limb l) = l0 `L.or#` l1
+ !n = C.from_word_nonzero# l
+ !b = C.to_word# n
+ in Wide (Limb2 b 0##)
instance NFData Wide where
rnf (Wide a) = case a of (# _, _ #) -> ()
@@ -95,8 +108,11 @@ wide :: Word -> Word -> Wide
wide (W# l) (W# h) = Wide (# Limb l, Limb h #)
-- | Convert an 'Integer' to a 'Wide' word.
-to :: Integer -> Wide
-to n =
+--
+-- >>> to_vartime 1
+-- 1
+to_vartime :: Integer -> Wide
+to_vartime n =
let !size = B.finiteBitSize (0 :: Word)
!mask = fi (maxBound :: Word) :: Integer
!(W# w0) = fi (n .&. mask)
@@ -104,8 +120,11 @@ to n =
in Wide (# Limb w0, Limb w1 #)
-- | Convert a 'Wide' word to an 'Integer'.
-from :: Wide -> Integer
-from (Wide (# Limb a, Limb b #)) =
+--
+-- >>> from_vartime 1
+-- 1
+from_vartime :: Wide -> Integer
+from_vartime (Wide (# Limb a, Limb b #)) =
fi (W# b) .<<. (B.finiteBitSize (0 :: Word))
.|. fi (W# a)
diff --git a/lib/Data/Word/Wider.hs b/lib/Data/Word/Wider.hs
@@ -1,6 +1,7 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE UnboxedSums #-}
{-# LANGUAGE UnboxedTuples #-}
@@ -18,12 +19,12 @@ module Data.Word.Wider (
-- * Four-limb words
Wider(..)
, wider
- , to
- , from
+ , to_vartime
+ , from_vartime
-- * Comparison
, eq_vartime
- , cmp
+ , cmp_vartime
, cmp#
, eq#
, lt
@@ -83,7 +84,7 @@ import qualified Data.Bits as B
import qualified Data.Choice as C
import Data.Word.Limb (Limb(..))
import qualified Data.Word.Limb as L
-import GHC.Exts (Word(..), Int(..), Int#)
+import GHC.Exts (Word(..), Int(..), Word#, Int#)
import qualified GHC.Exts as Exts
import Prelude hiding (div, mod, or, and, not, quot, rem, recip, odd)
@@ -95,6 +96,12 @@ fi = fromIntegral
-- wider words ----------------------------------------------------------------
+pattern Limb4
+ :: Word# -> Word# -> Word# -> Word#
+ -> (# Limb, Limb, Limb, Limb #)
+pattern Limb4 w0 w1 w2 w3 = (# Limb w0, Limb w1, Limb w2, Limb w3 #)
+{-# COMPLETE Limb4 #-}
+
-- | Little-endian wider words, consisting of four 'Limbs'.
--
-- >>> 1 :: Wider
@@ -102,24 +109,23 @@ fi = fromIntegral
data Wider = Wider !(# Limb, Limb, Limb, Limb #)
instance Show Wider where
- show = show . from
-
-instance Eq Wider where
- Wider a == Wider b = C.decide (eq# a b)
-
-instance Ord Wider where
- compare = cmp
+ show = show . from_vartime
+-- | Note that 'fromInteger' necessarily runs in variable time due
+-- to conversion from the variable-size, potentially heap-allocated
+-- 'Integer' type.
instance Num Wider where
(+) = add
(-) = sub
(*) = mul
abs = id
- fromInteger = to
- negate w = add (not w) (Wider (# Limb 1##, Limb 0##, Limb 0##, Limb 0## #))
- signum a = case a of
- Wider (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) -> 0
- _ -> 1
+ fromInteger = to_vartime
+ negate w = add (not w) (Wider (Limb4 1## 0## 0## 0##))
+ signum (Wider (# l0, l1, l2, l3 #)) =
+ let !(Limb l) = l0 `L.or#` l1 `L.or#` l2 `L.or#` l3
+ !n = C.from_word_nonzero# l
+ !b = C.to_word# n
+ in Wider (Limb4 b 0## 0## 0##)
instance NFData Wider where
rnf (Wider a) = case a of
@@ -132,8 +138,8 @@ eq#
-> (# Limb, Limb, Limb, Limb #)
-> C.Choice
eq# a b =
- let !(# Limb a0, Limb a1, Limb a2, Limb a3 #) = a
- !(# Limb b0, Limb b1, Limb b2, Limb b3 #) = b
+ let !(Limb4 a0 a1 a2 a3) = a
+ !(Limb4 b0 b1 b2 b3) = b
in C.eq_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #)
{-# INLINE eq# #-}
@@ -161,6 +167,13 @@ lt# a b =
in C.from_word_mask# bor
{-# INLINE lt# #-}
+-- | Constant-time less-than comparison between 'Wider' values.
+--
+-- >>> import qualified Data.Choice as CT
+-- >>> CT.decide (lt 1 2)
+-- True
+-- >>> CT.decide (lt 1 1)
+-- False
lt :: Wider -> Wider -> C.Choice
lt (Wider a) (Wider b) = lt# a b
@@ -173,6 +186,13 @@ gt# a b =
in C.from_word_mask# bor
{-# INLINE gt# #-}
+-- | Constant-time greater-than comparison between 'Wider' values.
+--
+-- >>> import qualified Data.Choice as CT
+-- >>> CT.decide (gt 1 2)
+-- False
+-- >>> CT.decide (gt 2 1)
+-- True
gt :: Wider -> Wider -> C.Choice
gt (Wider a) (Wider b) = gt# a b
@@ -194,20 +214,23 @@ cmp# (# l0, l1, l2, l3 #) (# r0, r1, r2, r3 #) =
in (Exts.word2Int# (C.to_word# (L.nonzero# d3))) Exts.*# s
{-# INLINE cmp# #-}
--- | Constant-time comparison between 'Wider' words.
+-- | Variable-time comparison between 'Wider' words.
+--
+-- The actual comparison here is performed in constant time, but we must
+-- branch to return an 'Ordering'.
--
--- >>> cmp 1 2
+-- >>> cmp_vartime 1 2
-- LT
--- >>> cmp 2 1
+-- >>> cmp_vartime 2 1
-- GT
--- >>> cmp 2 2
+-- >>> cmp_vartime 2 2
-- EQ
-cmp :: Wider -> Wider -> Ordering
-cmp (Wider a) (Wider b) = case cmp# a b of
+cmp_vartime :: Wider -> Wider -> Ordering
+cmp_vartime (Wider a) (Wider b) = case cmp# a b of
1# -> GT
0# -> EQ
_ -> LT
-{-# INLINABLE cmp #-}
+{-# INLINABLE cmp_vartime #-}
-- construction / conversion --------------------------------------------------
@@ -222,10 +245,10 @@ wider (W# w0) (W# w1) (W# w2) (W# w3) = Wider
-- | Convert an 'Integer' to a 'Wider' word.
--
--- >>> to 1
+-- >>> to_vartime 1
-- 1
-to :: Integer -> Wider
-to n =
+to_vartime :: Integer -> Wider
+to_vartime n =
let !size = B.finiteBitSize (0 :: Word)
!mask = fi (maxBound :: Word) :: Integer
!(W# w0) = fi (n .&. mask)
@@ -236,10 +259,10 @@ to n =
-- | Convert a 'Wider' word to an 'Integer'.
--
--- >>> from 1
+-- >>> from_vartime 1
-- 1
-from :: Wider -> Integer
-from (Wider (# Limb w0, Limb w1, Limb w2, Limb w3 #)) =
+from_vartime :: Wider -> Integer
+from_vartime (Wider (# Limb w0, Limb w1, Limb w2, Limb w3 #)) =
fi (W# w3) .<<. (3 * size)
.|. fi (W# w2) .<<. (2 * size)
.|. fi (W# w1) .<<. size
diff --git a/lib/Numeric/Montgomery/Secp256k1/Curve.hs b/lib/Numeric/Montgomery/Secp256k1/Curve.hs
@@ -100,7 +100,7 @@ instance Num Montgomery where
a * b = mul a b
negate a = neg a
abs = id
- fromInteger = to . WW.to
+ fromInteger = to . WW.to_vartime
signum a = case a of
Montgomery (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) -> 0
_ -> 1
diff --git a/lib/Numeric/Montgomery/Secp256k1/Scalar.hs b/lib/Numeric/Montgomery/Secp256k1/Scalar.hs
@@ -98,7 +98,7 @@ instance Num Montgomery where
a * b = mul a b
negate a = neg a
abs = id
- fromInteger = to . WW.to
+ fromInteger = to . WW.to_vartime
signum a = case a of
Montgomery (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) -> 0
_ -> 1
diff --git a/test/Montgomery/Curve.hs b/test/Montgomery/Curve.hs
@@ -40,7 +40,7 @@ repr = H.assertBool mempty (W.eq_vartime 0 (C.from mm))
add_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion
add_case t a b s = do
- H.assertEqual "sanity" ((W.from a + W.from b) `mod` W.from m) (W.from s)
+ H.assertEqual "sanity" ((W.from_vartime a + W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime s)
H.assertBool t (W.eq_vartime s (C.from (C.to a + C.to b)))
add :: H.Assertion
@@ -61,7 +61,7 @@ add = do
sub_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion
sub_case t b a d = do
- H.assertEqual "sanity" ((W.from b - W.from a) `mod` W.from m) (W.from d)
+ H.assertEqual "sanity" ((W.from_vartime b - W.from_vartime a) `mod` W.from_vartime m) (W.from_vartime d)
H.assertBool t (W.eq_vartime d (C.from (C.to b - C.to a)))
sub :: H.Assertion
@@ -81,7 +81,7 @@ sub = do
mul_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion
mul_case t a b p = do
- H.assertEqual "sanity" ((W.from a * W.from b) `mod` W.from m) (W.from p)
+ H.assertEqual "sanity" ((W.from_vartime a * W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime p)
H.assertBool t (W.eq_vartime p (C.from (C.to a * C.to b)))
mul :: H.Assertion
@@ -105,7 +105,7 @@ mul = do
0x000000000000000000000000000000000000000000000001000007A2000E90A1
instance Q.Arbitrary W.Wider where
- arbitrary = fmap W.to Q.arbitrary
+ arbitrary = fmap W.to_vartime Q.arbitrary
instance Q.Arbitrary C.Montgomery where
arbitrary = fmap C.to Q.arbitrary
@@ -114,39 +114,39 @@ add_matches :: W.Wider -> W.Wider -> Bool
add_matches a b =
let ma = C.to a
mb = C.to b
- ia = W.from a
- ib = W.from b
- im = W.from m
- in W.eq_vartime (W.to ((ia + ib) `mod` im)) (C.from (ma + mb))
+ ia = W.from_vartime a
+ ib = W.from_vartime b
+ im = W.from_vartime m
+ in W.eq_vartime (W.to_vartime ((ia + ib) `mod` im)) (C.from (ma + mb))
mul_matches :: W.Wider -> W.Wider -> Bool
mul_matches a b =
let ma = C.to a
mb = C.to b
- ia = W.from a
- ib = W.from b
- im = W.from m
- in W.eq_vartime (W.to ((ia * ib) `mod` im)) (C.from (ma * mb))
+ ia = W.from_vartime a
+ ib = W.from_vartime b
+ im = W.from_vartime m
+ in W.eq_vartime (W.to_vartime ((ia * ib) `mod` im)) (C.from (ma * mb))
sqr_matches :: W.Wider -> Bool
sqr_matches a =
let ma = C.to a
- ia = W.from a
- im = W.from m
- in W.eq_vartime (W.to ((ia * ia) `mod` im)) (C.from (C.sqr ma))
+ ia = W.from_vartime a
+ im = W.from_vartime m
+ in W.eq_vartime (W.to_vartime ((ia * ia) `mod` im)) (C.from (C.sqr ma))
exp_matches :: C.Montgomery -> W.Wider -> Bool
exp_matches a b =
- let ia = W.from (C.from a)
- nb = fromIntegral (W.from b)
- nm = fromIntegral (W.from m)
- in W.eq_vartime (W.to (modexp ia nb nm)) (C.from (C.exp a b))
+ let ia = W.from_vartime (C.from a)
+ nb = fromIntegral (W.from_vartime b)
+ nm = fromIntegral (W.from_vartime m)
+ in W.eq_vartime (W.to_vartime (modexp ia nb nm)) (C.from (C.exp a b))
inv_valid :: Q.NonZero C.Montgomery -> Bool
inv_valid (Q.NonZero s) = C.eq_vartime (C.inv s * s) 1
odd_correct :: C.Montgomery -> Bool
-odd_correct w = C.odd w == I.integerTestBit (W.from (C.from w)) 0
+odd_correct w = C.odd w == I.integerTestBit (W.from_vartime (C.from w)) 0
tests :: TestTree
tests = testGroup "montgomery tests (curve)" [
diff --git a/test/Montgomery/Scalar.hs b/test/Montgomery/Scalar.hs
@@ -40,7 +40,7 @@ repr = H.assertBool mempty (W.eq_vartime 0 (S.from mm))
add_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion
add_case t a b s = do
- H.assertEqual "sanity" ((W.from a + W.from b) `mod` W.from m) (W.from s)
+ H.assertEqual "sanity" ((W.from_vartime a + W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime s)
H.assertBool t (W.eq_vartime s (S.from (S.to a + S.to b)))
add :: H.Assertion
@@ -61,7 +61,7 @@ add = do
sub_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion
sub_case t b a d = do
- H.assertEqual "sanity" ((W.from b - W.from a) `mod` W.from m) (W.from d)
+ H.assertEqual "sanity" ((W.from_vartime b - W.from_vartime a) `mod` W.from_vartime m) (W.from_vartime d)
H.assertBool t (W.eq_vartime d (S.from (S.to b - S.to a)))
sub :: H.Assertion
@@ -81,7 +81,7 @@ sub = do
mul_case :: String -> W.Wider -> W.Wider -> W.Wider -> H.Assertion
mul_case t a b p = do
- H.assertEqual "sanity" ((W.from a * W.from b) `mod` W.from m) (W.from p)
+ H.assertEqual "sanity" ((W.from_vartime a * W.from_vartime b) `mod` W.from_vartime m) (W.from_vartime p)
H.assertBool t (W.eq_vartime p (S.from (S.to a * S.to b)))
mul :: H.Assertion
@@ -105,7 +105,7 @@ mul = do
0x9D671CD581C69BC5E697F5E45BCD07C6741496C20E7CF878896CF21467D7D140
instance Q.Arbitrary W.Wider where
- arbitrary = fmap W.to Q.arbitrary
+ arbitrary = fmap W.to_vartime Q.arbitrary
instance Q.Arbitrary S.Montgomery where
arbitrary = fmap S.to Q.arbitrary
@@ -114,33 +114,33 @@ add_matches :: W.Wider -> W.Wider -> Bool
add_matches a b =
let ma = S.to a
mb = S.to b
- ia = W.from a
- ib = W.from b
- im = W.from m
- in W.eq_vartime (W.to ((ia + ib) `mod` im)) (S.from (ma + mb))
+ ia = W.from_vartime a
+ ib = W.from_vartime b
+ im = W.from_vartime m
+ in W.eq_vartime (W.to_vartime ((ia + ib) `mod` im)) (S.from (ma + mb))
mul_matches :: W.Wider -> W.Wider -> Bool
mul_matches a b =
let ma = S.to a
mb = S.to b
- ia = W.from a
- ib = W.from b
- im = W.from m
- in W.eq_vartime (W.to ((ia * ib) `mod` im)) (S.from (ma * mb))
+ ia = W.from_vartime a
+ ib = W.from_vartime b
+ im = W.from_vartime m
+ in W.eq_vartime (W.to_vartime ((ia * ib) `mod` im)) (S.from (ma * mb))
sqr_matches :: W.Wider -> Bool
sqr_matches a =
let ma = S.to a
- ia = W.from a
- im = W.from m
- in W.eq_vartime (W.to ((ia * ia) `mod` im)) (S.from (S.sqr ma))
+ ia = W.from_vartime a
+ im = W.from_vartime m
+ in W.eq_vartime (W.to_vartime ((ia * ia) `mod` im)) (S.from (S.sqr ma))
exp_matches :: S.Montgomery -> W.Wider -> Bool
exp_matches a b =
- let ia = W.from (S.from a)
- nb = fromIntegral (W.from b)
- nm = fromIntegral (W.from m)
- in W.eq_vartime (W.to (modexp ia nb nm)) (S.from (S.exp a b))
+ let ia = W.from_vartime (S.from a)
+ nb = fromIntegral (W.from_vartime b)
+ nm = fromIntegral (W.from_vartime m)
+ in W.eq_vartime (W.to_vartime (modexp ia nb nm)) (S.from (S.exp a b))
inv_valid :: Q.NonZero S.Montgomery -> Bool
inv_valid (Q.NonZero s) = S.eq_vartime (S.inv s * s) 1
diff --git a/test/Wider.hs b/test/Wider.hs
@@ -107,17 +107,17 @@ cmp = do
let !a = 0
!b = 1
!c = 2 ^ (256 :: Word) - 1
- H.assertEqual mempty (W.cmp a b) LT
- H.assertEqual mempty (W.cmp a c) LT
- H.assertEqual mempty (W.cmp b c) LT
+ H.assertEqual mempty (W.cmp_vartime a b) LT
+ H.assertEqual mempty (W.cmp_vartime a c) LT
+ H.assertEqual mempty (W.cmp_vartime b c) LT
- H.assertEqual mempty (W.cmp a a) EQ
- H.assertEqual mempty (W.cmp b b) EQ
- H.assertEqual mempty (W.cmp c c) EQ
+ H.assertEqual mempty (W.cmp_vartime a a) EQ
+ H.assertEqual mempty (W.cmp_vartime b b) EQ
+ H.assertEqual mempty (W.cmp_vartime c c) EQ
- H.assertEqual mempty (W.cmp b a) GT
- H.assertEqual mempty (W.cmp c a) GT
- H.assertEqual mempty (W.cmp c b) GT
+ H.assertEqual mempty (W.cmp_vartime b a) GT
+ H.assertEqual mempty (W.cmp_vartime c a) GT
+ H.assertEqual mempty (W.cmp_vartime c b) GT
sqr :: H.Assertion
sqr = do
@@ -144,10 +144,10 @@ sub_mod = do
H.assertBool mempty (W.eq_vartime o e)
instance Q.Arbitrary W.Wider where
- arbitrary = fmap W.to Q.arbitrary
+ arbitrary = fmap W.to_vartime Q.arbitrary
odd_correct :: W.Wider -> Bool
-odd_correct w = C.decide (W.odd w) == I.integerTestBit (W.from w) 0
+odd_correct w = C.decide (W.odd w) == I.integerTestBit (W.from_vartime w) 0
tests :: TestTree
tests = testGroup "wider tests" [