fixed

Pure Haskell large fixed-width integers and Montgomery arithmetic.
git clone git://git.ppad.tech/fixed.git
Log | Files | Refs | README | LICENSE

Choice.hs (10549B)


      1 {-# OPTIONS_HADDOCK prune #-}
      2 {-# LANGUAGE BangPatterns #-}
      3 {-# LANGUAGE MagicHash #-}
      4 {-# LANGUAGE UnliftedNewtypes #-}
      5 {-# LANGUAGE UnboxedTuples #-}
      6 {-# LANGUAGE ViewPatterns #-}
      7 
      8 -- |
      9 -- Module: Data.Choice
     10 -- Copyright: (c) 2025 Jared Tobin
     11 -- License: MIT
     12 -- Maintainer: Jared Tobin <jared@ppad.tech>
     13 --
     14 -- Primitives for constant-time choice.
     15 --
     16 -- The 'Choice' type encodes truthy and falsy values as unboxed 'Word#'
     17 -- bit masks.
     18 --
-- Use the standard logical primitives ('or', 'and', 'xor', 'not', 'eq')
     20 -- to manipulate in-flight 'Choice' values. Use one of the selection
     21 -- functions to use a 'Choice' to select a value in constant time,
     22 -- or 'decide' to reduce a 'Choice' to a 'Bool' at the /end/ of a
     23 -- sensitive computation.
     24 
     25 module Data.Choice (
     26   -- * Choice
     27     Choice
     28   , decide
     29   , true#
     30   , false#
     31   , to_word#
     32 
     33   -- * Construction
     34   , from_full_mask#
     35   , from_bit#
     36   , from_word_nonzero#
     37   , from_word_eq#
     38   , from_word_le#
     39   , from_word_lt#
     40   , from_word_gt#
     41 
     42   -- * Manipulation
     43   , or
     44   , and
     45   , xor
     46   , not
     47   , ne
     48   , eq
     49 
     50   -- * Constant-time Selection
     51   , select_word#
     52   , select_wide#
     53   , select_wider#
     54 
     55   -- * Constant-time Equality
     56   , eq_word#
     57   , eq_wide#
     58   , eq_wider#
     59   ) where
     60 
     61 import qualified Data.Bits as B
     62 import GHC.Exts (Word#, Int(..), Word(..))
     63 import qualified GHC.Exts as Exts
     64 import Prelude hiding (and, not, or)
     65 
     66 -- utilities ------------------------------------------------------------------
     67 
     68 type Limb2 = (# Word#, Word# #)
     69 
     70 type Limb4 = (# Word#, Word#, Word#, Word# #)
     71 
     72 -- wrapping negation
     73 neg_w# :: Word# -> Word#
     74 neg_w# w = Exts.plusWord# (Exts.not# w) 1##
     75 {-# INLINE neg_w# #-}
     76 
     77 hi# :: Word# -> Limb2
     78 hi# w = (# 0##, w #)
     79 {-# INLINE hi# #-}
     80 
     81 lo# :: Word# -> Limb2
     82 lo# w = (# w, 0## #)
     83 {-# INLINE lo# #-}
     84 
     85 or_w# :: Limb2 -> Limb2 -> Limb2
     86 or_w# (# a0, a1 #) (# b0, b1 #) = (# Exts.or# a0 b0, Exts.or# a1 b1 #)
     87 {-# INLINE or_w# #-}
     88 
     89 and_w# :: Limb2 -> Limb2 -> Limb2
     90 and_w# (# a0, a1 #) (# b0, b1 #) = (# Exts.and# a0 b0, Exts.and# a1 b1 #)
     91 {-# INLINE and_w# #-}
     92 
     93 xor_w# :: Limb2 -> Limb2 -> Limb2
     94 xor_w# (# a0, a1 #) (# b0, b1 #) = (# Exts.xor# a0 b0, Exts.xor# a1 b1 #)
     95 {-# INLINE xor_w# #-}
     96 
     97 -- choice ---------------------------------------------------------------------
     98 
     99 -- | Constant-time choice, encoded as a mask.
    100 --
    101 --   Note that 'Choice' is defined as an unlifted newtype, and so a
    102 --   'Choice' value cannot be bound at the top level. You should work
    103 --   with it locally in the context of a computation.
    104 --
    105 --   Use one of the selection functions to select a 'Choice' value in
    106 --   constant time, or 'decide' to reduce it to a 'Bool' at the /end/ of
    107 --   a sensitive computation.
    108 --
    109 --   >>> decide (or# (false# ()) (true# ()))
    110 --   True
    111 newtype Choice = Choice Word#
    112 
    113 -- | Construct the falsy 'Choice'.
    114 --
    115 --   >>> decide (false# ())
    116 --   False
    117 false# :: () -> Choice
    118 false# _ = Choice 0##
    119 {-# INLINE false# #-}
    120 
    121 -- | Construct the truthy 'Choice'.
    122 --
    123 --   >>> decide (true# ())
    124 --   True
    125 true# :: () -> Choice
    126 true# _ = case maxBound :: Word of
    127   W# w -> Choice w
    128 {-# INLINE true# #-}
    129 
    130 -- | Decide a 'Choice' by reducing it to a 'Bool'.
    131 --
    132 --   The 'decide' function itself runs in constant time, but once
    133 --   it reduces a 'Choice' to a 'Bool', any subsequent branching on
    134 --   the result is liable to introduce variable-time behaviour.
    135 --
    136 --   You should 'decide' only at the /end/ of a computation, after all
    137 --   security-sensitive computations have been carried out.
    138 --
    139 --   >>> decide (true# ())
    140 --   True
    141 decide :: Choice -> Bool
    142 decide (Choice c) = Exts.isTrue# (Exts.neWord# c 0##)
    143 {-# INLINE decide #-}
    144 
    145 -- | Convert a 'Choice' to an unboxed 'Word#'.
    146 --
    147 --   This essentially "unboxes" the 'Choice' for direct manipulation.
    148 --
    149 --   >>> import qualified GHC.Exts as Exts
    150 --   >>> Exts.isTrue# (Exts.eqWord# 0## (to_word# (false# ())))
    151 --   True
    152 to_word# :: Choice -> Word#
    153 to_word# (Choice c) = Exts.and# c 1##
    154 {-# INLINE to_word# #-}
    155 
    156 -- construction ---------------------------------------------------------------
    157 
    158 -- | Construct a 'Choice' from an unboxed full-word mask.
    159 --
    160 --   The input is /not/ checked to be a full-word mask.
    161 --
    162 --   >>> decide (from_full_mask# 0##)
    163 --   False
    164 --   >>> decide (from_full_mask# 0xFFFFFFFFF_FFFFFFFF##)
    165 --   True
    166 from_full_mask# :: Word# -> Choice
    167 from_full_mask# w = Choice w
    168 {-# INLINE from_full_mask# #-}
    169 
    170 -- | Construct a 'Choice' from an unboxed word, which should be either
    171 --   0## or 1##.
    172 --
    173 --   The input is /not/ checked to be a bit.
    174 --
    175 --   >>> decide (from_bit# 1##)
    176 --   True
    177 from_bit# :: Word# -> Choice
    178 from_bit# w = Choice (neg_w# w)
    179 {-# INLINE from_bit# #-}
    180 
    181 -- | Construct a 'Choice' from a /nonzero/ unboxed word.
    182 --
    183 --   The input is /not/ checked to be nonzero.
    184 --
    185 --   >>> decide (from_word_nonzero# 2##)
    186 --   True
    187 from_word_nonzero# :: Word# -> Choice
    188 from_word_nonzero# w =
    189   let !n = neg_w# w
    190       !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    191       !v = Exts.uncheckedShiftRL# (Exts.or# w n) s
    192   in  from_bit# v
    193 {-# INLINE from_word_nonzero# #-}
    194 
    195 -- | Construct a 'Choice' from an equality comparison.
    196 --
    197 --   >>> decide (from_word_eq# 0## 1##)
    198 --   False
    199 --   decide (from_word_eq# 1## 1##)
    200 --   True
    201 from_word_eq# :: Word# -> Word# -> Choice
    202 from_word_eq# x y = case from_word_nonzero# (Exts.xor# x y) of
    203   Choice w -> Choice (Exts.not# w)
    204 {-# INLINE from_word_eq# #-}
    205 
    206 -- | Construct a 'Choice from an at-most comparison.
    207 --
    208 --   >>> decide (from_word_le# 0## 1##)
    209 --   True
    210 --   >>> decide (from_word_le# 1## 1##)
    211 --   True
    212 from_word_le# :: Word# -> Word# -> Choice
    213 from_word_le# x y =
    214   let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    215       !bit =
    216         Exts.uncheckedShiftRL#
    217           (Exts.and#
    218             (Exts.or# (Exts.not# x) y)
    219             (Exts.or# (Exts.xor# x y) (Exts.not# (Exts.minusWord# y x))))
    220           s
    221   in  from_bit# bit
    222 {-# INLINE from_word_le# #-}
    223 
    224 -- | Construct a 'Choice' from a less-than comparison.
    225 --
    226 --   >>> decide (from_word_lt# 0## 1##)
    227 --   True
    228 --   >>> decide (from_word_lt# 1## 1##)
    229 --   False
    230 from_word_lt# :: Word# -> Word# -> Choice
    231 from_word_lt# x y =
    232   let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    233       !bit =
    234         Exts.uncheckedShiftRL#
    235           (Exts.or#
    236             (Exts.and# (Exts.not# x) y)
    237             (Exts.and# (Exts.or# (Exts.not# x) y) (Exts.minusWord# x y)))
    238           s
    239   in  from_bit# bit
    240 {-# INLINE from_word_lt# #-}
    241 
    242 -- | Construct a 'Choice' from a greater-than comparison.
    243 --
    244 --   >>> decide (from_word_gt# 0## 1##)
    245 --   False
    246 --   >>> decide (from_word_gt# 1## 1##)
    247 --   False
    248 from_word_gt# :: Word# -> Word# -> Choice
    249 from_word_gt# x y = from_word_lt# y x
    250 {-# INLINE from_word_gt# #-}
    251 
    252 -- manipulation ---------------------------------------------------------------
    253 
    254 -- | Logically negate a 'Choice'.
    255 --
    256 --   >>> decide (not (true# ()))
    257 --   False
    258 --   >>> decide (not (false# ()))
    259 --   True
    260 not :: Choice -> Choice
    261 not (Choice w) = Choice (Exts.not# w)
    262 {-# INLINE not #-}
    263 
    264 -- | Logical disjunction on 'Choice' values.
    265 --
    266 --   >>> decide (or (true# ()) (false# ()))
    267 --   True
    268 or :: Choice -> Choice -> Choice
    269 or (Choice w0) (Choice w1) = Choice (Exts.or# w0 w1)
    270 {-# INLINE or #-}
    271 
    272 -- | Logical conjunction on 'Choice' values.
    273 --
    274 --   >>> decide (and (true# ()) (false# ()))
    275 --   False
    276 and :: Choice -> Choice -> Choice
    277 and (Choice w0) (Choice w1) = Choice (Exts.and# w0 w1)
    278 {-# INLINE and #-}
    279 
    280 -- | Logical inequality on 'Choice' values.
    281 --
    282 --   >>> decide (xor (true# ()) (false# ()))
    283 --   True
    284 xor :: Choice -> Choice -> Choice
    285 xor (Choice w0) (Choice w1) = Choice (Exts.xor# w0 w1)
    286 {-# INLINE xor #-}
    287 
    288 -- | Logical inequality on 'Choice' values.
    289 --
    290 --   >>> decide (ne (true# ()) (false# ()))
    291 --   True
    292 ne :: Choice -> Choice -> Choice
    293 ne c0 c1 = xor c0 c1
    294 {-# INLINE ne #-}
    295 
    296 -- | Logical equality on 'Choice' values.
    297 --
    298 --   >>> decide (eq (true# ()) (false# ()))
    299 --   False
    300 eq :: Choice -> Choice -> Choice
    301 eq c0 c1 = not (ne c0 c1)
    302 {-# INLINE eq #-}
    303 
    304 -- constant-time selection ----------------------------------------------------
    305 
    306 -- | Select an unboxed word without branching, given a 'Choice'.
    307 --
    308 --   >>> let w = C.select_word# 0## 1## (C.true# ()) in GHC.Word.W# w
    309 --   1
    310 select_word# :: Word# -> Word# -> Choice -> Word#
    311 select_word# a b (Choice c) = Exts.xor# a (Exts.and# c (Exts.xor# a b))
    312 {-# INLINE select_word# #-}
    313 
    314 -- | Select an unboxed two-limb word without branching, given a 'Choice'.
    315 select_wide#
    316   :: Limb2
    317   -> Limb2
    318   -> Choice
    319   -> Limb2
    320 select_wide# a b (Choice w) =
    321   let !mask = or_w# (hi# w) (lo# w)
    322   in  xor_w# a (and_w# mask (xor_w# a b))
    323 {-# INLINE select_wide# #-}
    324 
    325 -- | Select an unboxed four-limb word without branching, given a 'Choice'.
    326 select_wider#
    327   :: Limb4
    328   -> Limb4
    329   -> Choice
    330   -> Limb4
    331 select_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) (Choice w) =
    332   let !w0 = Exts.xor# a0 (Exts.and# w (Exts.xor# a0 b0))
    333       !w1 = Exts.xor# a1 (Exts.and# w (Exts.xor# a1 b1))
    334       !w2 = Exts.xor# a2 (Exts.and# w (Exts.xor# a2 b2))
    335       !w3 = Exts.xor# a3 (Exts.and# w (Exts.xor# a3 b3))
    336   in  (# w0, w1, w2, w3 #)
    337 {-# INLINE select_wider# #-}
    338 
    339 -- constant-time equality -----------------------------------------------------
    340 
    341 -- | Compare unboxed words for equality in constant time.
    342 --
    343 --   >>> decide (eq_word# 0## 1##)
    344 --   False
    345 eq_word# :: Word# -> Word# -> Choice
    346 eq_word# a b =
    347   let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    348       !x = Exts.xor# a b
    349       !y = Exts.uncheckedShiftRL# (Exts.or# x (neg_w# x)) s
    350   in  Choice (Exts.xor# y 1##)
    351 {-# INLINE eq_word# #-}
    352 
    353 -- | Compare unboxed two-limb words for equality in constant time.
    354 --
    355 --   >>> decide (eq_wide (# 0##, 0## #) (# 0##, 0## #))
    356 --   True
    357 eq_wide#
    358   :: Limb2
    359   -> Limb2
    360   -> Choice
    361 eq_wide# (# a0, a1 #) (# b0, b1 #) =
    362   let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    363       !x = Exts.or# (Exts.xor# a0 b0) (Exts.xor# a1 b1)
    364       !y = Exts.uncheckedShiftRL# (Exts.or# x (neg_w# x)) s
    365   in  Choice (Exts.xor# y 1##)
    366 {-# INLINE eq_wide# #-}
    367 
    368 -- | Compare unboxed four-limb words for equality in constant time.
    369 --
    370 --   >>> let zero = (# 0##, 0##, 0##, 0## #) in decide (eq_wider# zero zero)
    371 --   True
    372 eq_wider#
    373   :: Limb4
    374   -> Limb4
    375   -> Choice
    376 eq_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) =
    377   let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    378       !x = Exts.or# (Exts.or# (Exts.xor# a0 b0) (Exts.xor# a1 b1))
    379                     (Exts.or# (Exts.xor# a2 b2) (Exts.xor# a3 b3))
    380       !y = Exts.uncheckedShiftRL# (Exts.or# x (neg_w# x)) s
    381   in  Choice (Exts.xor# y 1##)
    382 {-# INLINE eq_wider# #-}
    383