fixed

Pure Haskell large fixed-width integers and Montgomery arithmetic.
git clone git://git.ppad.tech/fixed.git
Log | Files | Refs | README | LICENSE

commit 7d533266b4b52a38607b087a9b8bad3ef794c558
parent 0213f7350fcda1c8d28bb9dae686205cf5983f88
Author: Jared Tobin <jared@jtobin.io>
Date:   Sat, 27 Dec 2025 17:18:20 -0330

lib: refactor and refine choice module

Diffstat:
Mlib/Data/Choice.hs | 200+++++++++++++++++++++----------------------------------------------------------
Mlib/Data/Word/Limb.hs | 2+-
Mlib/Data/Word/Wide.hs | 37+++++++++++++++++++++++++++++++++++++
Mlib/Data/Word/Wider.hs | 10+++++-----
4 files changed, 95 insertions(+), 154 deletions(-)

diff --git a/lib/Data/Choice.hs b/lib/Data/Choice.hs @@ -11,7 +11,7 @@ -- License: MIT -- Maintainer: Jared Tobin <jared@ppad.tech> -- --- Constant-time choice. +-- Primitives for constant-time choice. module Data.Choice ( -- * Choice @@ -21,31 +21,15 @@ module Data.Choice ( , decide , to_word# - -- * MaybeWord# - , MaybeWord#(..) - , some_word# - , none_word# - - -- * MaybeWide# - , MaybeWide#(..) - , some_wide# - , just_wide# - , none_wide# - , expect_wide# - , expect_wide_or# - -- * Construction - , from_word_mask# - , from_word# + , from_full_mask# + , from_bit# , from_word_nonzero# , from_word_eq# , from_word_le# , from_word_lt# , from_word_gt# - , from_wide# - , from_wide_le# - -- * Manipulation , or , and @@ -85,10 +69,6 @@ lo# :: Word# -> (# Word#, Word# #) lo# w = (# w, 0## #) {-# INLINE lo# #-} -not_w# :: (# Word#, Word# #) -> (# Word#, Word# #) -not_w# (# a0, a1 #) = (# Exts.not# a0, Exts.not# a1 #) -{-# INLINE not_w# #-} - or_w# :: (# Word#, Word# #) -> (# Word#, Word# #) -> (# Word#, Word# #) or_w# (# a0, a1 #) (# b0, b1 #) = (# Exts.or# a0 b0, Exts.or# a1 b1 #) {-# INLINE or_w# #-} @@ -101,36 +81,6 @@ xor_w# :: (# Word#, Word# #) -> (# Word#, Word# #) -> (# Word#, Word# #) xor_w# (# a0, a1 #) (# b0, b1 #) = (# Exts.xor# a0 b0, Exts.xor# a1 b1 #) {-# INLINE xor_w# #-} --- subtract-with-borrow -sub_b# :: Word# -> Word# -> Word# -> (# Word#, Word# #) -sub_b# m n b = - let !(# d0, b0 #) = Exts.subWordC# m n - !(# d, b1 #) = Exts.subWordC# d0 b - !c = Exts.int2Word# (Exts.orI# b0 b1) - in (# d, c #) -{-# INLINE sub_b# #-} - --- wide subtract-with-borrow -sub_wb# - :: (# Word#, Word# #) - -> (# Word#, Word# #) - -> (# Word#, Word#, Word# #) -sub_wb# (# a0, a1 #) (# b0, b1 #) = - let !(# s0, c0 #) = sub_b# a0 b0 0## - !(# s1, c1 #) = sub_b# a1 b1 c0 - in (# s0, s1, c1 #) -{-# INLINE sub_wb# #-} - --- wide subtraction (wrapping) -sub_w# - :: (# Word#, Word# #) - -> (# Word#, Word# #) - -> (# Word#, Word# #) -sub_w# a b = - let !(# c0, c1, _ #) = sub_wb# a b - in (# c0, c1 #) -{-# INLINE sub_w# #-} - -- choice --------------------------------------------------------------------- -- | Constant-time choice, encoded as a mask. @@ -139,15 +89,11 @@ sub_w# a b = -- 'Choice' value cannot be bound at the top level. You should work -- with it locally in the context of a computation. -- --- It's safe to 'decide' a choice, reducing it to a 'Bool', at any --- time, but the general encouraged pattern is to do that only at the --- end of a computation. --- -- >>> decide (or# (false# ()) (true# ())) -- True newtype Choice = Choice Word# --- | Construct the falsy value. +-- | Construct the falsy 'Choice'. -- -- >>> decide (false# ()) -- False @@ -155,7 +101,7 @@ false# :: () -> Choice false# _ = Choice 0## {-# INLINE false# #-} --- | Construct the truthy value. +-- | Construct the truthy 'Choice'. -- -- >>> decide (true# ()) -- True @@ -166,6 +112,13 @@ true# _ = case maxBound :: Word of -- | Decide a 'Choice' by reducing it to a 'Bool'. -- +-- The 'decide' function itself runs in constant time, but once +-- it reduces a 'Choice' to a 'Bool', any subsequent branching on +-- the result is liable to introduce variable-time behaviour. +-- +-- You should 'decide' only at the /end/ of a computation, after all +-- security-sensitive computations have been carried out. +-- -- >>> decide (true# ()) -- True decide :: Choice -> Bool @@ -177,87 +130,34 @@ to_word# :: Choice -> Word# to_word# (Choice c) = Exts.and# c 1## {-# INLINE to_word# #-} --- constant time 'Maybe Word#' -newtype MaybeWord# = MaybeWord# (# Word#, Choice #) - -some_word# :: Word# -> MaybeWord# -some_word# w = MaybeWord# (# w, true# () #) -{-# INLINE some_word# #-} - -none_word# :: Word# -> MaybeWord# -none_word# w = MaybeWord# (# w, false# () #) -{-# INLINE none_word# #-} - --- constant time 'Maybe (# Word#, Word# #)' -newtype MaybeWide# = MaybeWide# (# (# Word#, Word# #), Choice #) - -just_wide# :: (# Word#, Word# #) -> Choice -> MaybeWide# -just_wide# w c = MaybeWide# (# w, c #) -{-# INLINE just_wide# #-} - -some_wide# :: (# Word#, Word# #) -> MaybeWide# -some_wide# w = MaybeWide# (# w, true# () #) -{-# INLINE some_wide# #-} - -none_wide# :: (# Word#, Word# #) -> MaybeWide# -none_wide# w = MaybeWide# (# w, false# () #) -{-# INLINE none_wide# #-} - -expect_wide# :: MaybeWide# -> String -> (# Word#, Word# #) -expect_wide# (MaybeWide# (# w, Choice c #)) msg - | Exts.isTrue# (Exts.eqWord# c t#) = w - | otherwise = error $ "ppad-fixed (expect_wide#): " <> msg - where - !(Choice t#) = true# () -{-# INLINE expect_wide# #-} - -expect_wide_or# :: MaybeWide# -> (# Word#, Word# #) -> (# Word#, Word# #) -expect_wide_or# (MaybeWide# (# w, Choice c #)) alt - | Exts.isTrue# (Exts.eqWord# c t#) = w - | otherwise = alt - where - !(Choice t#) = true# () -{-# INLINE expect_wide_or# #-} - -- construction --------------------------------------------------------------- --- | Construct a 'Choice' from an unboxed mask. +-- | Construct a 'Choice' from an unboxed full-word mask. -- --- The input is /not/ checked. +-- The input is /not/ checked to be a full-word mask. -- --- >>> decide (from_word_mask# 0##) +-- >>> decide (from_full_mask# 0##) -- False --- >>> decide (from_word_mask# 0xFFFFFFFFF_FFFFFFFF##) +-- >>> decide (from_full_mask# 0xFFFFFFFFF_FFFFFFFF##) -- True -from_word_mask# :: Word# -> Choice -from_word_mask# w = Choice w -{-# INLINE from_word_mask# #-} +from_full_mask# :: Word# -> Choice +from_full_mask# w = Choice w +{-# INLINE from_full_mask# #-} -- | Construct a 'Choice' from an unboxed word, which should be either -- 0## or 1##. -- --- The input is /not/ checked. +-- The input is /not/ checked to be a bit. -- --- >>> decide (from_word# 1##) +-- >>> decide (from_bit# 1##) -- True -from_word# :: Word# -> Choice -from_word# w = Choice (neg_w# w) -{-# INLINE from_word# #-} - --- | Construct a 'Choice' from a two-limb word, constructing a mask from --- the lower limb, which should be 0## or 1##. --- --- The input is /not/ checked. --- --- >>> decide (from_wide# (# 0##, 1## #)) --- False -from_wide# :: (# Word#, Word# #) -> Choice -from_wide# (# l, _ #) = from_word# l -{-# INLINE from_wide# #-} +from_bit# :: Word# -> Choice +from_bit# w = Choice (neg_w# w) +{-# INLINE from_bit# #-} -- | Construct a 'Choice' from a /nonzero/ unboxed word. -- --- The input is /not/ checked. +-- The input is /not/ checked to be nonzero. -- -- >>> decide (from_word_nonzero# 2##) -- True @@ -266,7 +166,7 @@ from_word_nonzero# w = let !n = neg_w# w !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1# !v = Exts.uncheckedShiftRL# (Exts.or# w n) s - in from_word# v + in from_bit# v {-# INLINE from_word_nonzero# #-} -- | Construct a 'Choice' from an equality comparison. @@ -280,7 +180,7 @@ from_word_eq# x y = case from_word_nonzero# (Exts.xor# x y) of Choice w -> Choice (Exts.not# w) {-# INLINE from_word_eq# #-} --- | Construct a 'Choice from an at most comparison. +-- | Construct a 'Choice from an at-most comparison. -- -- >>> decide (from_word_le# 0## 1##) -- True @@ -295,28 +195,9 @@ from_word_le# x y = (Exts.or# (Exts.not# x) y) (Exts.or# (Exts.xor# x y) (Exts.not# (Exts.minusWord# y x)))) s - in from_word# bit + in from_bit# bit {-# INLINE from_word_le# #-} --- | Construct a 'Choice' from an at most comparison on a two-limb --- unboxed word. --- --- >>> decide (from_wide_le# (# 0##, 0## #) (# 1##, 0## #)) --- True --- >>> decide (from_wide_le# (# 1##, 0## #) (# 1##, 0## #)) --- True -from_wide_le# :: (# Word#, Word# #) -> (# Word#, Word# #) -> Choice -from_wide_le# x y = - let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1# - !mask = - (and_w# - (or_w# (not_w# x) y) - (or_w# (xor_w# x y) (not_w# (sub_w# y x)))) - !bit = case mask of - (# l, _ #) -> Exts.uncheckedShiftRL# l s - in from_word# bit -{-# INLINE from_wide_le# #-} - -- | Construct a 'Choice' from a less-than comparison. -- -- >>> decide (from_word_lt# 0## 1##) @@ -332,7 +213,7 @@ from_word_lt# x y = (Exts.and# (Exts.not# x) y) (Exts.and# (Exts.or# (Exts.not# x) y) (Exts.minusWord# x y))) s - in from_word# bit + in from_bit# bit {-# INLINE from_word_lt# #-} -- | Construct a 'Choice' from a greater-than comparison. @@ -348,31 +229,51 @@ from_word_gt# x y = from_word_lt# y x -- manipulation --------------------------------------------------------------- -- | Logically negate a 'Choice'. +-- +-- >>> C.decide (C.not (C.true# ())) +-- False +-- >>> C.decide (C.not (C.false# ())) +-- True not :: Choice -> Choice not (Choice w) = Choice (Exts.not# w) {-# INLINE not #-} -- | Logical disjunction on 'Choice' values. +-- +-- >>> C.decide (C.or (C.true# ()) (C.false# ())) +-- True or :: Choice -> Choice -> Choice or (Choice w0) (Choice w1) = Choice (Exts.or# w0 w1) {-# INLINE or #-} -- | Logical conjunction on 'Choice' values. +-- +-- >>> C.decide (C.and (C.true# ()) (C.false# ())) +-- False and :: Choice -> Choice -> Choice and (Choice w0) (Choice w1) = Choice (Exts.and# w0 w1) {-# INLINE and #-} -- | Logical inequality on 'Choice' values. +-- +-- >>> C.decide (C.xor (C.true# ()) (C.false# ())) +-- True xor :: Choice -> Choice -> Choice xor (Choice w0) (Choice w1) = Choice (Exts.xor# w0 w1) {-# INLINE xor #-} -- | Logical inequality on 'Choice' values. +-- +-- >>> C.decide (C.ne (C.true# ()) (C.false# ())) +-- True ne :: Choice -> Choice -> Choice ne c0 c1 = xor c0 c1 {-# INLINE ne #-} -- | Logical equality on 'Choice' values. +-- +-- >>> C.decide (C.eq (C.true# ()) (C.false# ())) +-- False eq :: Choice -> Choice -> Choice eq c0 c1 = not (ne c0 c1) {-# INLINE eq #-} @@ -380,6 +281,9 @@ eq c0 c1 = not (ne c0 c1) -- constant-time selection ---------------------------------------------------- -- | Select an unboxed word, given a 'Choice'. +-- +-- >>> let w = C.select_word# 0## 1## (C.true# ()) in GHC.Word.W# w +-- 1 select_word# :: Word# -> Word# -> Choice -> Word# select_word# a b (Choice c) = Exts.xor# a (Exts.and# c (Exts.xor# a b)) {-# INLINE select_word# #-} diff --git a/lib/Data/Word/Limb.hs b/lib/Data/Word/Limb.hs @@ -317,7 +317,7 @@ sub_s# -> Limb -- ^ difference sub_s# (Limb m) (Limb n) = let !(# d, b #) = Exts.subWordC# m n - !borrow = C.from_word# (Exts.int2Word# b) + !borrow = C.from_bit# (Exts.int2Word# b) in Limb (C.select_word# d 0## borrow) {-# INLINE sub_s# #-} diff --git a/lib/Data/Word/Wide.hs b/lib/Data/Word/Wide.hs @@ -25,6 +25,10 @@ module Data.Word.Wide ( , to_vartime , from_vartime + -- * Constant-time selection + , select + , select# + -- * Bit Manipulation , or , or# @@ -36,6 +40,7 @@ module Data.Word.Wide ( , not# -- * Comparison + , eq , eq_vartime -- * Arithmetic @@ -131,11 +136,43 @@ from_vartime (Wide (# Limb a, Limb b #)) = -- comparison ----------------------------------------------------------------- +-- | Compare 'Wide' words for equality in constant time. +eq :: Wide -> Wide -> C.Choice +eq (Wide (# Limb a0, Limb a1 #)) (Wide (# Limb b0, Limb b1 #)) = + C.eq_wide# (# a0, a1 #) (# b0, b1 #) + -- | Compare 'Wide' words for equality in variable time. eq_vartime :: Wide -> Wide -> Bool eq_vartime (Wide (# Limb a0, Limb b0 #)) (Wide (# Limb a1, Limb b1 #)) = isTrue# (andI# (eqWord# a0 a1) (eqWord# b0 b1)) +-- constant-time selection----------------------------------------------------- + +select# + :: (# Limb, Limb #) -- ^ a + -> (# Limb, Limb #) -- ^ b + -> C.Choice -- ^ c + -> (# Limb, Limb #) -- ^ result +select# a b c = + let !(# Limb a0, Limb a1 #) = a + !(# Limb b0, Limb b1 #) = b + !(# w0, w1 #) = + C.select_wide# (# a0, a1 #) (# b0, b1 #) c + in (# Limb w0, Limb w1 #) +{-# INLINE select# #-} + +-- | Return a if c is truthy, otherwise return b. +-- +-- >>> import qualified Data.Choice as C +-- >>> select 0 1 (C.true# ()) +-- 1 +select + :: Wide -- ^ a + -> Wide -- ^ b + -> C.Choice -- ^ c + -> Wide -- ^ result +select (Wide a) (Wide b) c = Wide (select# a b c) + -- bits ----------------------------------------------------------------------- or_w# :: (# Limb, Limb #) -> (# Limb, Limb #) -> (# Limb, Limb #) diff --git a/lib/Data/Word/Wider.hs b/lib/Data/Word/Wider.hs @@ -165,7 +165,7 @@ lt# -> C.Choice lt# a b = let !(# _, Limb bor #) = sub_b# a b - in C.from_word_mask# bor + in C.from_full_mask# bor {-# INLINE lt# #-} -- | Constant-time less-than comparison between 'Wider' values. @@ -184,7 +184,7 @@ gt# -> C.Choice gt# a b = let !(# _, Limb bor #) = sub_b# b a - in C.from_word_mask# bor + in C.from_full_mask# bor {-# INLINE gt# #-} -- | Constant-time greater-than comparison between 'Wider' values. @@ -314,7 +314,7 @@ shr1_c# (# w0, w1, w2, w3 #) = !(# s0, c0 #) = (# L.shr# w0 1#, L.shl# w0 s #) !r0 = L.or# s0 c1 !(Limb w) = L.shr# c0 s - in (# (# r0, r1, r2, r3 #), C.from_word# w #) + in (# (# r0, r1, r2, r3 #), C.from_bit# w #) {-# INLINE shr1_c# #-} -- | Constant-time 1-bit shift-right with carry, with a 'Choice' @@ -349,7 +349,7 @@ shl1_c# (# w0, w1, w2, w3 #) = !(# s3, c3 #) = (# L.shl# w3 1#, L.shr# w3 s #) !r3 = L.or# s3 c2 !(Limb w) = L.shl# c3 s - in (# (# r0, r1, r2, r3 #), C.from_word# w #) + in (# (# r0, r1, r2, r3 #), C.from_bit# w #) {-# INLINE shl1_c# #-} -- | Constant-time 1-bit shift-left with carry, with a 'Choice' indicating @@ -762,7 +762,7 @@ sqr (Wider w) = in (Wider l, Wider h) odd# :: (# Limb, Limb, Limb, Limb #) -> C.Choice -odd# (# Limb w, _, _, _ #) = C.from_word# (Exts.and# w 1##) +odd# (# Limb w, _, _, _ #) = C.from_bit# (Exts.and# w 1##) {-# INLINE odd# #-} -- | Check if a 'Wider' is odd, returning a 'Choice'.