commit 5868a462096497298532599f7672dc72dc1be1d9
parent fd0bc00ecdd659696372384992feb4db399c517e
Author: Jared Tobin <jared@jtobin.io>
Date: Sat, 20 Dec 2025 16:12:09 -0330
lib: add choice docs
Diffstat:
2 files changed, 138 insertions(+), 26 deletions(-)
diff --git a/lib/Data/Choice.hs b/lib/Data/Choice.hs
@@ -4,6 +4,14 @@
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ViewPatterns #-}
+-- |
+-- Module: Data.Choice
+-- Copyright: (c) 2025 Jared Tobin
+-- License: MIT
+-- Maintainer: Jared Tobin <jared@ppad.tech>
+--
+-- Constant-time choice.
+
module Data.Choice (
-- * Choice
Choice
@@ -27,14 +35,14 @@ module Data.Choice (
-- * Construction
, from_word_mask#
- , from_word_lsb#
+ , from_word#
, from_word_nonzero#
, from_word_eq#
, from_word_le#
, from_word_lt#
, from_word_gt#
- , from_wide_lsb#
+ , from_wide#
, from_wide_le#
-- * Manipulation
@@ -62,10 +70,10 @@ import qualified GHC.Exts as Exts
-- utilities ------------------------------------------------------------------
--- make a mask from a bit (0 -> 0, 1 -> maxBound)
-wrapping_neg# :: Word# -> Word#
-wrapping_neg# w = Exts.plusWord# (Exts.not# w) 1##
-{-# INLINE wrapping_neg# #-}
+-- wrapping negation
+neg_w# :: Word# -> Word#
+neg_w# w = Exts.plusWord# (Exts.not# w) 1##
+{-# INLINE neg_w# #-}
hi# :: Word# -> (# Word#, Word# #)
hi# w = (# 0##, w #)
@@ -123,22 +131,46 @@ sub_w# a b =
-- choice ---------------------------------------------------------------------
--- constant-time choice, encoded as a mask
+-- | Constant-time choice, encoded as a mask.
+--
+-- Note that 'Choice' is defined as an unboxed newtype, and so a
+-- 'Choice' value cannot be bound at the top level. You should work
+-- with it locally in the context of a computation.
+--
+-- It's safe to 'decide' a choice, reducing it to a 'Bool', at any
+-- time, but the general encouraged pattern is to do that only at the
+-- end of a computation.
+--
+-- >>> decide (or# (false# ()) (true# ()))
+-- True
newtype Choice = Choice Word#
+-- | Construct the falsy value.
+--
+-- >>> decide (false# ())
+-- False
false# :: () -> Choice
false# _ = Choice 0##
{-# INLINE false# #-}
+-- | Construct the truthy value.
+--
+-- >>> decide (true# ())
+-- True
true# :: () -> Choice
true# _ = case maxBound :: Word of
W# w -> Choice w
{-# INLINE true# #-}
+-- | Decide a 'Choice' by reducing it to a 'Bool'.
+--
+-- >>> decide (true# ())
+-- True
decide :: Choice -> Bool
decide (Choice c) = Exts.isTrue# (Exts.neWord# c 0##)
{-# INLINE decide #-}
+-- | Convert a 'Choice' to an unboxed 'Word#'.
to_word# :: Choice -> Word#
to_word# (Choice c) = Exts.and# c 1##
{-# INLINE to_word# #-}
@@ -187,31 +219,71 @@ expect_wide_or# (MaybeWide# (# w, Choice c #)) alt
-- construction ---------------------------------------------------------------
+-- | Construct a 'Choice' from an unboxed mask.
+--
+-- The input is /not/ checked.
+--
+-- >>> decide (from_word_mask# 0##)
+-- False
+-- >>> decide (from_word_mask# 0xFFFFFFFFF_FFFFFFFF##)
+-- True
from_word_mask# :: Word# -> Choice
from_word_mask# w = Choice w
{-# INLINE from_word_mask# #-}
-from_word_lsb# :: Word# -> Choice
-from_word_lsb# w = Choice (wrapping_neg# w)
-{-# INLINE from_word_lsb# #-}
-
-from_wide_lsb# :: (# Word#, Word# #) -> Choice
-from_wide_lsb# (# l, _ #) = from_word_lsb# l
-{-# INLINE from_wide_lsb# #-}
-
+-- | Construct a 'Choice' from an unboxed word, which should be either
+-- 0## or 1##.
+--
+-- The input is /not/ checked.
+--
+-- >>> decide (from_word# 1##)
+-- True
+from_word# :: Word# -> Choice
+from_word# w = Choice (neg_w# w)
+{-# INLINE from_word# #-}
+
+-- | Construct a 'Choice' from a two-limb word, constructing a mask from
+-- the lower limb, which should be 0## or 1##.
+--
+-- The input is /not/ checked.
+--
+-- >>> decide (from_wide# (# 0##, 1## #))
+-- False
+from_wide# :: (# Word#, Word# #) -> Choice
+from_wide# (# l, _ #) = from_word# l
+{-# INLINE from_wide# #-}
+
+-- | Construct a 'Choice' from a /nonzero/ unboxed word.
+--
+-- The input is /not/ checked.
+--
+-- >>> decide (from_word_nonzero# 2##)
+-- True
from_word_nonzero# :: Word# -> Choice
from_word_nonzero# w =
- let !n = wrapping_neg# w
+ let !n = neg_w# w
!s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
!v = Exts.uncheckedShiftRL# (Exts.or# w n) s
- in from_word_lsb# v
+ in from_word# v
{-# INLINE from_word_nonzero# #-}
+-- | Construct a 'Choice' from an equality comparison.
+--
+-- >>> decide (from_word_eq# 0## 1##)
+-- False
+-- decide (from_word_eq# 1## 1##)
+-- True
from_word_eq# :: Word# -> Word# -> Choice
from_word_eq# x y = case from_word_nonzero# (Exts.xor# x y) of
Choice w -> Choice (Exts.not# w)
{-# INLINE from_word_eq# #-}
+-- | Construct a 'Choice from an at most comparison.
+--
+-- >>> decide (from_word_le# 0## 1##)
+-- True
+-- >>> decide (from_word_le# 1## 1##)
+-- True
from_word_le# :: Word# -> Word# -> Choice
from_word_le# x y =
let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
@@ -221,9 +293,16 @@ from_word_le# x y =
(Exts.or# (Exts.not# x) y)
(Exts.or# (Exts.xor# x y) (Exts.not# (Exts.minusWord# y x))))
s
- in from_word_lsb# bit
+ in from_word# bit
{-# INLINE from_word_le# #-}
+-- | Construct a 'Choice' from an at most comparison on a two-limb
+-- unboxed word.
+--
+-- >>> decide (from_wide_le# (# 0##, 0## #) (# 1##, 0## #))
+-- True
+-- >>> decide (from_wide_le# (# 1##, 0## #) (# 1##, 0## #))
+-- True
from_wide_le# :: (# Word#, Word# #) -> (# Word#, Word# #) -> Choice
from_wide_le# x y =
let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
@@ -233,9 +312,15 @@ from_wide_le# x y =
(or_w# (xor_w# x y) (not_w# (sub_w# y x))))
!bit = case mask of
(# l, _ #) -> Exts.uncheckedShiftRL# l s
- in from_word_lsb# bit
+ in from_word# bit
{-# INLINE from_wide_le# #-}
+-- | Construct a 'Choice' from a less-than comparison.
+--
+-- >>> decide (from_word_lt# 0## 1##)
+-- True
+-- >>> decide (from_word_lt# 1## 1##)
+-- False
from_word_lt# :: Word# -> Word# -> Choice
from_word_lt# x y =
let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
@@ -245,45 +330,59 @@ from_word_lt# x y =
(Exts.and# (Exts.not# x) y)
(Exts.and# (Exts.or# (Exts.not# x) y) (Exts.minusWord# x y)))
s
- in from_word_lsb# bit
+ in from_word# bit
{-# INLINE from_word_lt# #-}
+-- | Construct a 'Choice' from a greater-than comparison.
+--
+-- >>> decide (from_word_gt# 0## 1##)
+-- False
+-- >>> decide (from_word_gt# 1## 1##)
+-- False
from_word_gt# :: Word# -> Word# -> Choice
from_word_gt# x y = from_word_lt# y x
{-# INLINE from_word_gt# #-}
-- manipulation ---------------------------------------------------------------
+-- | Logically negate a 'Choice'.
not# :: Choice -> Choice
not# (Choice w) = Choice (Exts.not# w)
{-# INLINE not# #-}
+-- | Logical disjunction on 'Choice' values.
or# :: Choice -> Choice -> Choice
or# (Choice w0) (Choice w1) = Choice (Exts.or# w0 w1)
{-# INLINE or# #-}
+-- | Logical conjunction on 'Choice' values.
and# :: Choice -> Choice -> Choice
and# (Choice w0) (Choice w1) = Choice (Exts.and# w0 w1)
{-# INLINE and# #-}
+-- | Logical inequality on 'Choice' values.
xor# :: Choice -> Choice -> Choice
xor# (Choice w0) (Choice w1) = Choice (Exts.xor# w0 w1)
{-# INLINE xor# #-}
+-- | Logical inequality on 'Choice' values.
ne# :: Choice -> Choice -> Choice
ne# c0 c1 = xor# c0 c1
{-# INLINE ne# #-}
+-- | Logical equality on 'Choice' values.
eq# :: Choice -> Choice -> Choice
eq# c0 c1 = not# (ne# c0 c1)
{-# INLINE eq# #-}
-- constant-time selection ----------------------------------------------------
+-- | Select an unboxed word, given a 'Choice'.
select_word# :: Word# -> Word# -> Choice -> Word#
select_word# a b (Choice c) = Exts.xor# a (Exts.and# c (Exts.xor# a b))
{-# INLINE select_word# #-}
+-- | Select an unboxed two-limb word, given a 'Choice'.
select_wide#
:: (# Word#, Word# #)
-> (# Word#, Word# #)
@@ -294,6 +393,7 @@ select_wide# a b (Choice w) =
in xor_w# a (and_w# mask (xor_w# a b))
{-# INLINE select_wide# #-}
+-- | Select an unboxed four-limb word, given a 'Choice'.
select_wider#
:: (# Word#, Word#, Word#, Word# #)
-> (# Word#, Word#, Word#, Word# #)
@@ -309,14 +409,22 @@ select_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) (Choice w) =
-- constant-time equality -----------------------------------------------------
+-- | Compare unboxed words for equality in constant time.
+--
+-- >>> decide (eq_word# 0## 1##)
+-- False
eq_word# :: Word# -> Word# -> Choice
eq_word# a b =
let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
!x = Exts.xor# a b
- !y = Exts.uncheckedShiftRL# (Exts.or# x (wrapping_neg# x)) s
+ !y = Exts.uncheckedShiftRL# (Exts.or# x (neg_w# x)) s
in Choice (Exts.xor# y 1##)
{-# INLINE eq_word# #-}
+-- | Compare unboxed two-limb words for equality in constant time.
+--
+-- >>> decide (eq_wide (# 0##, 0## #) (# 0##, 0## #))
+-- True
eq_wide#
:: (# Word#, Word# #)
-> (# Word#, Word# #)
@@ -324,10 +432,14 @@ eq_wide#
eq_wide# (# a0, a1 #) (# b0, b1 #) =
let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
!x = Exts.or# (Exts.xor# a0 b0) (Exts.xor# a1 b1)
- !y = Exts.uncheckedShiftRL# (Exts.or# x (wrapping_neg# x)) s
+ !y = Exts.uncheckedShiftRL# (Exts.or# x (neg_w# x)) s
in Choice (Exts.xor# y 1##)
{-# INLINE eq_wide# #-}
+-- | Compare unboxed four-limb words for equality in constant time.
+--
+-- >>> let zero = (# 0##, 0##, 0##, 0## #) in decide (eq_wider# zero zero)
+-- True
eq_wider#
:: (# Word#, Word#, Word#, Word# #)
-> (# Word#, Word#, Word#, Word# #)
@@ -336,7 +448,7 @@ eq_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) =
let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
!x = Exts.or# (Exts.or# (Exts.xor# a0 b0) (Exts.xor# a1 b1))
(Exts.or# (Exts.xor# a2 b2) (Exts.xor# a3 b3))
- !y = Exts.uncheckedShiftRL# (Exts.or# x (wrapping_neg# x)) s
+ !y = Exts.uncheckedShiftRL# (Exts.or# x (neg_w# x)) s
in Choice (Exts.xor# y 1##)
{-# INLINE eq_wider# #-}
diff --git a/lib/Data/Word/Wider.hs b/lib/Data/Word/Wider.hs
@@ -281,7 +281,7 @@ shr1_c# (# w0, w1, w2, w3 #) =
!(# s0, c0 #) = (# L.shr# w0 1#, L.shl# w0 s #)
!r0 = L.or# s0 c1
!(Limb w) = L.shr# c0 s
- in (# (# r0, r1, r2, r3 #), C.from_word_lsb# w #)
+ in (# (# r0, r1, r2, r3 #), C.from_word# w #)
{-# INLINE shr1_c# #-}
-- | Constant-time 1-bit shift-right with carry, indicating whether the
@@ -321,7 +321,7 @@ shl1_c# (# w0, w1, w2, w3 #) =
!(# s3, c3 #) = (# L.shl# w3 1#, L.shr# w3 s #)
!r3 = L.or# s3 c2
!(Limb w) = L.shl# c3 s
- in (# (# r0, r1, r2, r3 #), C.from_word_lsb# w #)
+ in (# (# r0, r1, r2, r3 #), C.from_word# w #)
{-# INLINE shl1_c# #-}
-- | Constant-time 1-bit shift-left with carry, indicating whether the
@@ -739,7 +739,7 @@ sqr (Wider w) =
in (Wider l, Wider h)
odd# :: (# Limb, Limb, Limb, Limb #) -> C.Choice
-odd# (# Limb w, _, _, _ #) = C.from_word_lsb# (Exts.and# w 1##)
+odd# (# Limb w, _, _, _ #) = C.from_word# (Exts.and# w 1##)
{-# INLINE odd# #-}
-- | Check if a 'Wider' is odd.