commit 7d533266b4b52a38607b087a9b8bad3ef794c558
parent 0213f7350fcda1c8d28bb9dae686205cf5983f88
Author: Jared Tobin <jared@jtobin.io>
Date: Sat, 27 Dec 2025 17:18:20 -0330
lib: refactor and refine choice module
Diffstat:
4 files changed, 95 insertions(+), 154 deletions(-)
diff --git a/lib/Data/Choice.hs b/lib/Data/Choice.hs
@@ -11,7 +11,7 @@
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
--- Constant-time choice.
+-- Primitives for constant-time choice.
module Data.Choice (
-- * Choice
@@ -21,31 +21,15 @@ module Data.Choice (
, decide
, to_word#
- -- * MaybeWord#
- , MaybeWord#(..)
- , some_word#
- , none_word#
-
- -- * MaybeWide#
- , MaybeWide#(..)
- , some_wide#
- , just_wide#
- , none_wide#
- , expect_wide#
- , expect_wide_or#
-
-- * Construction
- , from_word_mask#
- , from_word#
+ , from_full_mask#
+ , from_bit#
, from_word_nonzero#
, from_word_eq#
, from_word_le#
, from_word_lt#
, from_word_gt#
- , from_wide#
- , from_wide_le#
-
-- * Manipulation
, or
, and
@@ -85,10 +69,6 @@ lo# :: Word# -> (# Word#, Word# #)
lo# w = (# w, 0## #)
{-# INLINE lo# #-}
-not_w# :: (# Word#, Word# #) -> (# Word#, Word# #)
-not_w# (# a0, a1 #) = (# Exts.not# a0, Exts.not# a1 #)
-{-# INLINE not_w# #-}
-
or_w# :: (# Word#, Word# #) -> (# Word#, Word# #) -> (# Word#, Word# #)
or_w# (# a0, a1 #) (# b0, b1 #) = (# Exts.or# a0 b0, Exts.or# a1 b1 #)
{-# INLINE or_w# #-}
@@ -101,36 +81,6 @@ xor_w# :: (# Word#, Word# #) -> (# Word#, Word# #) -> (# Word#, Word# #)
xor_w# (# a0, a1 #) (# b0, b1 #) = (# Exts.xor# a0 b0, Exts.xor# a1 b1 #)
{-# INLINE xor_w# #-}
--- subtract-with-borrow
-sub_b# :: Word# -> Word# -> Word# -> (# Word#, Word# #)
-sub_b# m n b =
- let !(# d0, b0 #) = Exts.subWordC# m n
- !(# d, b1 #) = Exts.subWordC# d0 b
- !c = Exts.int2Word# (Exts.orI# b0 b1)
- in (# d, c #)
-{-# INLINE sub_b# #-}
-
--- wide subtract-with-borrow
-sub_wb#
- :: (# Word#, Word# #)
- -> (# Word#, Word# #)
- -> (# Word#, Word#, Word# #)
-sub_wb# (# a0, a1 #) (# b0, b1 #) =
- let !(# s0, c0 #) = sub_b# a0 b0 0##
- !(# s1, c1 #) = sub_b# a1 b1 c0
- in (# s0, s1, c1 #)
-{-# INLINE sub_wb# #-}
-
--- wide subtraction (wrapping)
-sub_w#
- :: (# Word#, Word# #)
- -> (# Word#, Word# #)
- -> (# Word#, Word# #)
-sub_w# a b =
- let !(# c0, c1, _ #) = sub_wb# a b
- in (# c0, c1 #)
-{-# INLINE sub_w# #-}
-
-- choice ---------------------------------------------------------------------
-- | Constant-time choice, encoded as a mask.
@@ -139,15 +89,11 @@ sub_w# a b =
-- 'Choice' value cannot be bound at the top level. You should work
-- with it locally in the context of a computation.
--
--- It's safe to 'decide' a choice, reducing it to a 'Bool', at any
--- time, but the general encouraged pattern is to do that only at the
--- end of a computation.
---
-- >>> decide (or# (false# ()) (true# ()))
-- True
newtype Choice = Choice Word#
--- | Construct the falsy value.
+-- | Construct the falsy 'Choice'.
--
-- >>> decide (false# ())
-- False
@@ -155,7 +101,7 @@ false# :: () -> Choice
false# _ = Choice 0##
{-# INLINE false# #-}
--- | Construct the truthy value.
+-- | Construct the truthy 'Choice'.
--
-- >>> decide (true# ())
-- True
@@ -166,6 +112,13 @@ true# _ = case maxBound :: Word of
-- | Decide a 'Choice' by reducing it to a 'Bool'.
--
+-- The 'decide' function itself runs in constant time, but once
+-- it reduces a 'Choice' to a 'Bool', any subsequent branching on
+-- the result is liable to introduce variable-time behaviour.
+--
+-- You should 'decide' only at the /end/ of a computation, after all
+-- security-sensitive computations have been carried out.
+--
-- >>> decide (true# ())
-- True
decide :: Choice -> Bool
@@ -177,87 +130,34 @@ to_word# :: Choice -> Word#
to_word# (Choice c) = Exts.and# c 1##
{-# INLINE to_word# #-}
--- constant time 'Maybe Word#'
-newtype MaybeWord# = MaybeWord# (# Word#, Choice #)
-
-some_word# :: Word# -> MaybeWord#
-some_word# w = MaybeWord# (# w, true# () #)
-{-# INLINE some_word# #-}
-
-none_word# :: Word# -> MaybeWord#
-none_word# w = MaybeWord# (# w, false# () #)
-{-# INLINE none_word# #-}
-
--- constant time 'Maybe (# Word#, Word# #)'
-newtype MaybeWide# = MaybeWide# (# (# Word#, Word# #), Choice #)
-
-just_wide# :: (# Word#, Word# #) -> Choice -> MaybeWide#
-just_wide# w c = MaybeWide# (# w, c #)
-{-# INLINE just_wide# #-}
-
-some_wide# :: (# Word#, Word# #) -> MaybeWide#
-some_wide# w = MaybeWide# (# w, true# () #)
-{-# INLINE some_wide# #-}
-
-none_wide# :: (# Word#, Word# #) -> MaybeWide#
-none_wide# w = MaybeWide# (# w, false# () #)
-{-# INLINE none_wide# #-}
-
-expect_wide# :: MaybeWide# -> String -> (# Word#, Word# #)
-expect_wide# (MaybeWide# (# w, Choice c #)) msg
- | Exts.isTrue# (Exts.eqWord# c t#) = w
- | otherwise = error $ "ppad-fixed (expect_wide#): " <> msg
- where
- !(Choice t#) = true# ()
-{-# INLINE expect_wide# #-}
-
-expect_wide_or# :: MaybeWide# -> (# Word#, Word# #) -> (# Word#, Word# #)
-expect_wide_or# (MaybeWide# (# w, Choice c #)) alt
- | Exts.isTrue# (Exts.eqWord# c t#) = w
- | otherwise = alt
- where
- !(Choice t#) = true# ()
-{-# INLINE expect_wide_or# #-}
-
-- construction ---------------------------------------------------------------
--- | Construct a 'Choice' from an unboxed mask.
+-- | Construct a 'Choice' from an unboxed full-word mask.
--
--- The input is /not/ checked.
+-- The input is /not/ checked to be a full-word mask.
--
--- >>> decide (from_word_mask# 0##)
+-- >>> decide (from_full_mask# 0##)
-- False
--- >>> decide (from_word_mask# 0xFFFFFFFFF_FFFFFFFF##)
+-- >>> decide (from_full_mask# 0xFFFFFFFFF_FFFFFFFF##)
-- True
-from_word_mask# :: Word# -> Choice
-from_word_mask# w = Choice w
-{-# INLINE from_word_mask# #-}
+from_full_mask# :: Word# -> Choice
+from_full_mask# w = Choice w
+{-# INLINE from_full_mask# #-}
-- | Construct a 'Choice' from an unboxed word, which should be either
-- 0## or 1##.
--
--- The input is /not/ checked.
+-- The input is /not/ checked to be a bit.
--
--- >>> decide (from_word# 1##)
+-- >>> decide (from_bit# 1##)
-- True
-from_word# :: Word# -> Choice
-from_word# w = Choice (neg_w# w)
-{-# INLINE from_word# #-}
-
--- | Construct a 'Choice' from a two-limb word, constructing a mask from
--- the lower limb, which should be 0## or 1##.
---
--- The input is /not/ checked.
---
--- >>> decide (from_wide# (# 0##, 1## #))
--- False
-from_wide# :: (# Word#, Word# #) -> Choice
-from_wide# (# l, _ #) = from_word# l
-{-# INLINE from_wide# #-}
+from_bit# :: Word# -> Choice
+from_bit# w = Choice (neg_w# w)
+{-# INLINE from_bit# #-}
-- | Construct a 'Choice' from a /nonzero/ unboxed word.
--
--- The input is /not/ checked.
+-- The input is /not/ checked to be nonzero.
--
-- >>> decide (from_word_nonzero# 2##)
-- True
@@ -266,7 +166,7 @@ from_word_nonzero# w =
let !n = neg_w# w
!s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
!v = Exts.uncheckedShiftRL# (Exts.or# w n) s
- in from_word# v
+ in from_bit# v
{-# INLINE from_word_nonzero# #-}
-- | Construct a 'Choice' from an equality comparison.
@@ -280,7 +180,7 @@ from_word_eq# x y = case from_word_nonzero# (Exts.xor# x y) of
Choice w -> Choice (Exts.not# w)
{-# INLINE from_word_eq# #-}
--- | Construct a 'Choice from an at most comparison.
+-- | Construct a 'Choice from an at-most comparison.
--
-- >>> decide (from_word_le# 0## 1##)
-- True
@@ -295,28 +195,9 @@ from_word_le# x y =
(Exts.or# (Exts.not# x) y)
(Exts.or# (Exts.xor# x y) (Exts.not# (Exts.minusWord# y x))))
s
- in from_word# bit
+ in from_bit# bit
{-# INLINE from_word_le# #-}
--- | Construct a 'Choice' from an at most comparison on a two-limb
--- unboxed word.
---
--- >>> decide (from_wide_le# (# 0##, 0## #) (# 1##, 0## #))
--- True
--- >>> decide (from_wide_le# (# 1##, 0## #) (# 1##, 0## #))
--- True
-from_wide_le# :: (# Word#, Word# #) -> (# Word#, Word# #) -> Choice
-from_wide_le# x y =
- let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
- !mask =
- (and_w#
- (or_w# (not_w# x) y)
- (or_w# (xor_w# x y) (not_w# (sub_w# y x))))
- !bit = case mask of
- (# l, _ #) -> Exts.uncheckedShiftRL# l s
- in from_word# bit
-{-# INLINE from_wide_le# #-}
-
-- | Construct a 'Choice' from a less-than comparison.
--
-- >>> decide (from_word_lt# 0## 1##)
@@ -332,7 +213,7 @@ from_word_lt# x y =
(Exts.and# (Exts.not# x) y)
(Exts.and# (Exts.or# (Exts.not# x) y) (Exts.minusWord# x y)))
s
- in from_word# bit
+ in from_bit# bit
{-# INLINE from_word_lt# #-}
-- | Construct a 'Choice' from a greater-than comparison.
@@ -348,31 +229,51 @@ from_word_gt# x y = from_word_lt# y x
-- manipulation ---------------------------------------------------------------
-- | Logically negate a 'Choice'.
+--
+-- >>> C.decide (C.not (C.true# ()))
+-- False
+-- >>> C.decide (C.not (C.false# ()))
+-- True
not :: Choice -> Choice
not (Choice w) = Choice (Exts.not# w)
{-# INLINE not #-}
-- | Logical disjunction on 'Choice' values.
+--
+-- >>> C.decide (C.or (C.true# ()) (C.false# ()))
+-- True
or :: Choice -> Choice -> Choice
or (Choice w0) (Choice w1) = Choice (Exts.or# w0 w1)
{-# INLINE or #-}
-- | Logical conjunction on 'Choice' values.
+--
+-- >>> C.decide (C.and (C.true# ()) (C.false# ()))
+-- False
and :: Choice -> Choice -> Choice
and (Choice w0) (Choice w1) = Choice (Exts.and# w0 w1)
{-# INLINE and #-}
-- | Logical inequality on 'Choice' values.
+--
+-- >>> C.decide (C.xor (C.true# ()) (C.false# ()))
+-- True
xor :: Choice -> Choice -> Choice
xor (Choice w0) (Choice w1) = Choice (Exts.xor# w0 w1)
{-# INLINE xor #-}
-- | Logical inequality on 'Choice' values.
+--
+-- >>> C.decide (C.ne (C.true# ()) (C.false# ()))
+-- True
ne :: Choice -> Choice -> Choice
ne c0 c1 = xor c0 c1
{-# INLINE ne #-}
-- | Logical equality on 'Choice' values.
+--
+-- >>> C.decide (C.eq (C.true# ()) (C.false# ()))
+-- False
eq :: Choice -> Choice -> Choice
eq c0 c1 = not (ne c0 c1)
{-# INLINE eq #-}
@@ -380,6 +281,9 @@ eq c0 c1 = not (ne c0 c1)
-- constant-time selection ----------------------------------------------------
-- | Select an unboxed word, given a 'Choice'.
+--
+-- >>> let w = C.select_word# 0## 1## (C.true# ()) in GHC.Word.W# w
+-- 1
select_word# :: Word# -> Word# -> Choice -> Word#
select_word# a b (Choice c) = Exts.xor# a (Exts.and# c (Exts.xor# a b))
{-# INLINE select_word# #-}
diff --git a/lib/Data/Word/Limb.hs b/lib/Data/Word/Limb.hs
@@ -317,7 +317,7 @@ sub_s#
-> Limb -- ^ difference
sub_s# (Limb m) (Limb n) =
let !(# d, b #) = Exts.subWordC# m n
- !borrow = C.from_word# (Exts.int2Word# b)
+ !borrow = C.from_bit# (Exts.int2Word# b)
in Limb (C.select_word# d 0## borrow)
{-# INLINE sub_s# #-}
diff --git a/lib/Data/Word/Wide.hs b/lib/Data/Word/Wide.hs
@@ -25,6 +25,10 @@ module Data.Word.Wide (
, to_vartime
, from_vartime
+ -- * Constant-time selection
+ , select
+ , select#
+
-- * Bit Manipulation
, or
, or#
@@ -36,6 +40,7 @@ module Data.Word.Wide (
, not#
-- * Comparison
+ , eq
, eq_vartime
-- * Arithmetic
@@ -131,11 +136,43 @@ from_vartime (Wide (# Limb a, Limb b #)) =
-- comparison -----------------------------------------------------------------
+-- | Compare 'Wide' words for equality in constant time.
+eq :: Wide -> Wide -> C.Choice
+eq (Wide (# Limb a0, Limb a1 #)) (Wide (# Limb b0, Limb b1 #)) =
+ C.eq_wide# (# a0, a1 #) (# b0, b1 #)
+
-- | Compare 'Wide' words for equality in variable time.
eq_vartime :: Wide -> Wide -> Bool
eq_vartime (Wide (# Limb a0, Limb b0 #)) (Wide (# Limb a1, Limb b1 #)) =
isTrue# (andI# (eqWord# a0 a1) (eqWord# b0 b1))
+-- constant-time selection-----------------------------------------------------
+
+select#
+ :: (# Limb, Limb #) -- ^ a
+ -> (# Limb, Limb #) -- ^ b
+ -> C.Choice -- ^ c
+ -> (# Limb, Limb #) -- ^ result
+select# a b c =
+ let !(# Limb a0, Limb a1 #) = a
+ !(# Limb b0, Limb b1 #) = b
+ !(# w0, w1 #) =
+ C.select_wide# (# a0, a1 #) (# b0, b1 #) c
+ in (# Limb w0, Limb w1 #)
+{-# INLINE select# #-}
+
+-- | Return a if c is truthy, otherwise return b.
+--
+-- >>> import qualified Data.Choice as C
+-- >>> select 0 1 (C.true# ())
+-- 1
+select
+ :: Wide -- ^ a
+ -> Wide -- ^ b
+ -> C.Choice -- ^ c
+ -> Wide -- ^ result
+select (Wide a) (Wide b) c = Wide (select# a b c)
+
-- bits -----------------------------------------------------------------------
or_w# :: (# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #)
diff --git a/lib/Data/Word/Wider.hs b/lib/Data/Word/Wider.hs
@@ -165,7 +165,7 @@ lt#
-> C.Choice
lt# a b =
let !(# _, Limb bor #) = sub_b# a b
- in C.from_word_mask# bor
+ in C.from_full_mask# bor
{-# INLINE lt# #-}
-- | Constant-time less-than comparison between 'Wider' values.
@@ -184,7 +184,7 @@ gt#
-> C.Choice
gt# a b =
let !(# _, Limb bor #) = sub_b# b a
- in C.from_word_mask# bor
+ in C.from_full_mask# bor
{-# INLINE gt# #-}
-- | Constant-time greater-than comparison between 'Wider' values.
@@ -314,7 +314,7 @@ shr1_c# (# w0, w1, w2, w3 #) =
!(# s0, c0 #) = (# L.shr# w0 1#, L.shl# w0 s #)
!r0 = L.or# s0 c1
!(Limb w) = L.shr# c0 s
- in (# (# r0, r1, r2, r3 #), C.from_word# w #)
+ in (# (# r0, r1, r2, r3 #), C.from_bit# w #)
{-# INLINE shr1_c# #-}
-- | Constant-time 1-bit shift-right with carry, with a 'Choice'
@@ -349,7 +349,7 @@ shl1_c# (# w0, w1, w2, w3 #) =
!(# s3, c3 #) = (# L.shl# w3 1#, L.shr# w3 s #)
!r3 = L.or# s3 c2
!(Limb w) = L.shl# c3 s
- in (# (# r0, r1, r2, r3 #), C.from_word# w #)
+ in (# (# r0, r1, r2, r3 #), C.from_bit# w #)
{-# INLINE shl1_c# #-}
-- | Constant-time 1-bit shift-left with carry, with a 'Choice' indicating
@@ -762,7 +762,7 @@ sqr (Wider w) =
in (Wider l, Wider h)
odd# :: (# Limb, Limb, Limb, Limb #) -> C.Choice
-odd# (# Limb w, _, _, _ #) = C.from_word# (Exts.and# w 1##)
+odd# (# Limb w, _, _, _ #) = C.from_bit# (Exts.and# w 1##)
{-# INLINE odd# #-}
-- | Check if a 'Wider' is odd, returning a 'Choice'.