fixed

Pure Haskell large fixed-width integers and Montgomery arithmetic.
git clone git://git.ppad.tech/fixed.git
Log | Files | Refs | README | LICENSE

Choice.hs (12605B)


      1 {-# LANGUAGE BangPatterns #-}
      2 {-# LANGUAGE MagicHash #-}
      3 {-# LANGUAGE UnliftedNewtypes #-}
      4 {-# LANGUAGE UnboxedTuples #-}
      5 {-# LANGUAGE ViewPatterns #-}
      6 
      7 -- |
      8 -- Module: Data.Choice
      9 -- Copyright: (c) 2025 Jared Tobin
     10 -- License: MIT
     11 -- Maintainer: Jared Tobin <jared@ppad.tech>
     12 --
     13 -- Constant-time choice.
     14 
     15 module Data.Choice (
     16   -- * Choice
     17     Choice
     18   , true#
     19   , false#
     20   , decide
     21   , to_word#
     22 
     23   -- * MaybeWord#
     24   , MaybeWord#(..)
     25   , some_word#
     26   , none_word#
     27 
     28   -- * MaybeWide#
     29   , MaybeWide#(..)
     30   , some_wide#
     31   , just_wide#
     32   , none_wide#
     33   , expect_wide#
     34   , expect_wide_or#
     35 
     36   -- * Construction
     37   , from_word_mask#
     38   , from_word#
     39   , from_word_nonzero#
     40   , from_word_eq#
     41   , from_word_le#
     42   , from_word_lt#
     43   , from_word_gt#
     44 
     45   , from_wide#
     46   , from_wide_le#
     47 
     48   -- * Manipulation
     49   , or#
     50   , and#
     51   , xor#
     52   , not#
     53   , ne#
     54   , eq#
     55 
     56   -- * Constant-time Selection
     57   , select_word#
     58   , select_wide#
     59   , select_wider#
     60 
     61   -- * Constant-time Equality
     62   , eq_word#
     63   , eq_wide#
     64   , eq_wider#
     65   ) where
     66 
     67 import qualified Data.Bits as B
     68 import GHC.Exts (Word#, Int(..), Word(..))
     69 import qualified GHC.Exts as Exts
     70 
     71 -- utilities ------------------------------------------------------------------
     72 
     73 -- wrapping negation
     74 neg_w# :: Word# -> Word#
     75 neg_w# w = Exts.plusWord# (Exts.not# w) 1##
     76 {-# INLINE neg_w# #-}
     77 
     78 hi# :: Word# -> (# Word#, Word# #)
     79 hi# w = (# 0##, w #)
     80 {-# INLINE hi# #-}
     81 
     82 lo# :: Word# -> (# Word#, Word# #)
     83 lo# w = (# w, 0## #)
     84 {-# INLINE lo# #-}
     85 
     86 not_w# :: (# Word#, Word# #) -> (# Word#, Word# #)
     87 not_w# (# a0, a1 #) = (# Exts.not# a0, Exts.not# a1 #)
     88 {-# INLINE not_w# #-}
     89 
     90 or_w# :: (# Word#, Word# #) -> (# Word#, Word# #) -> (# Word#, Word# #)
     91 or_w# (# a0, a1 #) (# b0, b1 #) = (# Exts.or# a0 b0, Exts.or# a1 b1 #)
     92 {-# INLINE or_w# #-}
     93 
     94 and_w# :: (# Word#, Word# #) -> (# Word#, Word# #) -> (# Word#, Word# #)
     95 and_w# (# a0, a1 #) (# b0, b1 #) = (# Exts.and# a0 b0, Exts.and# a1 b1 #)
     96 {-# INLINE and_w# #-}
     97 
     98 xor_w# :: (# Word#, Word# #) -> (# Word#, Word# #) -> (# Word#, Word# #)
     99 xor_w# (# a0, a1 #) (# b0, b1 #) = (# Exts.xor# a0 b0, Exts.xor# a1 b1 #)
    100 {-# INLINE xor_w# #-}
    101 
    102 -- subtract-with-borrow
    103 sub_b# :: Word# -> Word# -> Word# -> (# Word#, Word# #)
    104 sub_b# m n b =
    105   let !(# d0, b0 #) = Exts.subWordC# m n
    106       !(#  d, b1 #) = Exts.subWordC# d0 b
    107       !c = Exts.int2Word# (Exts.orI# b0 b1)
    108   in  (# d, c #)
    109 {-# INLINE sub_b# #-}
    110 
    111 -- wide subtract-with-borrow
    112 sub_wb#
    113   :: (# Word#, Word# #)
    114   -> (# Word#, Word# #)
    115   -> (# Word#, Word#, Word# #)
    116 sub_wb# (# a0, a1 #) (# b0, b1 #) =
    117   let !(# s0, c0 #) = sub_b# a0 b0 0##
    118       !(# s1, c1 #) = sub_b# a1 b1 c0
    119   in  (# s0, s1, c1 #)
    120 {-# INLINE sub_wb# #-}
    121 
    122 -- wide subtraction (wrapping)
    123 sub_w#
    124   :: (# Word#, Word# #)
    125   -> (# Word#, Word# #)
    126   -> (# Word#, Word# #)
    127 sub_w# a b =
    128   let !(# c0, c1, _ #) = sub_wb# a b
    129   in  (# c0, c1 #)
    130 {-# INLINE sub_w# #-}
    131 
    132 -- choice ---------------------------------------------------------------------
    133 
-- | Constant-time choice, encoded as a mask.
--
--   A 'Choice' wraps a machine-word mask. The canonical true value
--   (see 'true#') is the all-ones word and the canonical false value
--   (see 'false#') is zero. The selection functions ('select_word#' and
--   friends) combine the mask bitwise with their inputs, so they rely on
--   this all-or-nothing encoding; 'decide' treats any nonzero mask as
--   true.
--
--   Note that 'Choice' is defined as an unboxed newtype, and so a
--   'Choice' value cannot be bound at the top level. You should work
--   with it locally in the context of a computation.
--
--   It's safe to 'decide' a choice, reducing it to a 'Bool', at any
--   time, but the general encouraged pattern is to do that only at the
--   end of a computation.
--
--   >>> decide (or# (false# ()) (true# ()))
--   True
newtype Choice = Choice Word#
    147 
    148 -- | Construct the falsy value.
    149 --
    150 --   >>> decide (false# ())
    151 --   False
    152 false# :: () -> Choice
    153 false# _ = Choice 0##
    154 {-# INLINE false# #-}
    155 
    156 -- | Construct the truthy value.
    157 --
    158 --   >>> decide (true# ())
    159 --   True
    160 true# :: () -> Choice
    161 true# _ = case maxBound :: Word of
    162   W# w -> Choice w
    163 {-# INLINE true# #-}
    164 
    165 -- | Decide a 'Choice' by reducing it to a 'Bool'.
    166 --
    167 --   >>> decide (true# ())
    168 --   True
    169 decide :: Choice -> Bool
    170 decide (Choice c) = Exts.isTrue# (Exts.neWord# c 0##)
    171 {-# INLINE decide #-}
    172 
    173 -- | Convert a 'Choice' to an unboxed 'Word#'.
    174 to_word# :: Choice -> Word#
    175 to_word# (Choice c) = Exts.and# c 1##
    176 {-# INLINE to_word# #-}
    177 
-- | Constant-time analogue of @Maybe Word#@: a word paired with a
--   'Choice' recording whether it is present. The word is carried even
--   when the 'Choice' is false (see 'none_word#').
newtype MaybeWord# = MaybeWord# (# Word#, Choice #)
    180 
    181 some_word# :: Word# -> MaybeWord#
    182 some_word# w = MaybeWord# (# w, true# () #)
    183 {-# INLINE some_word# #-}
    184 
    185 none_word# :: Word# -> MaybeWord#
    186 none_word# w = MaybeWord# (# w, false# () #)
    187 {-# INLINE none_word# #-}
    188 
-- | Constant-time analogue of @Maybe (# Word#, Word# #)@: a two-limb
--   value paired with a 'Choice' recording whether it is present. The
--   value is carried even when the 'Choice' is false (see 'none_wide#').
newtype MaybeWide# = MaybeWide# (# (# Word#, Word# #), Choice #)
    191 
    192 just_wide# :: (# Word#, Word# #) -> Choice -> MaybeWide#
    193 just_wide# w c = MaybeWide# (# w, c #)
    194 {-# INLINE just_wide# #-}
    195 
    196 some_wide# :: (# Word#, Word# #) -> MaybeWide#
    197 some_wide# w = MaybeWide# (# w, true# () #)
    198 {-# INLINE some_wide# #-}
    199 
    200 none_wide# :: (# Word#, Word# #) -> MaybeWide#
    201 none_wide# w = MaybeWide# (# w, false# () #)
    202 {-# INLINE none_wide# #-}
    203 
    204 expect_wide# :: MaybeWide# -> String -> (# Word#, Word# #)
    205 expect_wide# (MaybeWide# (# w, Choice c #)) msg
    206     | Exts.isTrue# (Exts.eqWord# c t#) = w
    207     | otherwise = error $ "ppad-fixed (expect_wide#): " <> msg
    208   where
    209     !(Choice t#) = true# ()
    210 {-# INLINE expect_wide# #-}
    211 
    212 expect_wide_or# :: MaybeWide# -> (# Word#, Word# #) -> (# Word#, Word# #)
    213 expect_wide_or# (MaybeWide# (# w, Choice c #)) alt
    214     | Exts.isTrue# (Exts.eqWord# c t#) = w
    215     | otherwise = alt
    216   where
    217     !(Choice t#) = true# ()
    218 {-# INLINE expect_wide_or# #-}
    219 
    220 -- construction ---------------------------------------------------------------
    221 
    222 -- | Construct a 'Choice' from an unboxed mask.
    223 --
    224 --   The input is /not/ checked.
    225 --
    226 --   >>> decide (from_word_mask# 0##)
    227 --   False
    228 --   >>> decide (from_word_mask# 0xFFFFFFFFF_FFFFFFFF##)
    229 --   True
    230 from_word_mask# :: Word# -> Choice
    231 from_word_mask# w = Choice w
    232 {-# INLINE from_word_mask# #-}
    233 
    234 -- | Construct a 'Choice' from an unboxed word, which should be either
    235 --   0## or 1##.
    236 --
    237 --   The input is /not/ checked.
    238 --
    239 --   >>> decide (from_word# 1##)
    240 --   True
    241 from_word# :: Word# -> Choice
    242 from_word# w = Choice (neg_w# w)
    243 {-# INLINE from_word# #-}
    244 
    245 -- | Construct a 'Choice' from a two-limb word, constructing a mask from
    246 --   the lower limb, which should be 0## or 1##.
    247 --
    248 --   The input is /not/ checked.
    249 --
    250 --   >>> decide (from_wide# (# 0##, 1## #))
    251 --   False
    252 from_wide# :: (# Word#, Word# #) -> Choice
    253 from_wide# (# l, _ #) = from_word# l
    254 {-# INLINE from_wide# #-}
    255 
    256 -- | Construct a 'Choice' from a /nonzero/ unboxed word.
    257 --
    258 --   The input is /not/ checked.
    259 --
    260 --   >>> decide (from_word_nonzero# 2##)
    261 --   True
    262 from_word_nonzero# :: Word# -> Choice
    263 from_word_nonzero# w =
    264   let !n = neg_w# w
    265       !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    266       !v = Exts.uncheckedShiftRL# (Exts.or# w n) s
    267   in  from_word# v
    268 {-# INLINE from_word_nonzero# #-}
    269 
    270 -- | Construct a 'Choice' from an equality comparison.
    271 --
    272 --   >>> decide (from_word_eq# 0## 1##)
    273 --   False
    274 --   decide (from_word_eq# 1## 1##)
    275 --   True
    276 from_word_eq# :: Word# -> Word# -> Choice
    277 from_word_eq# x y = case from_word_nonzero# (Exts.xor# x y) of
    278   Choice w -> Choice (Exts.not# w)
    279 {-# INLINE from_word_eq# #-}
    280 
    281 -- | Construct a 'Choice from an at most comparison.
    282 --
    283 --   >>> decide (from_word_le# 0## 1##)
    284 --   True
    285 --   >>> decide (from_word_le# 1## 1##)
    286 --   True
    287 from_word_le# :: Word# -> Word# -> Choice
    288 from_word_le# x y =
    289   let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    290       !bit =
    291         Exts.uncheckedShiftRL#
    292           (Exts.and#
    293             (Exts.or# (Exts.not# x) y)
    294             (Exts.or# (Exts.xor# x y) (Exts.not# (Exts.minusWord# y x))))
    295           s
    296   in  from_word# bit
    297 {-# INLINE from_word_le# #-}
    298 
    299 -- | Construct a 'Choice' from an at most comparison on a two-limb
    300 --   unboxed word.
    301 --
    302 --   >>> decide (from_wide_le# (# 0##, 0## #) (# 1##, 0## #))
    303 --   True
    304 --   >>> decide (from_wide_le# (# 1##, 0## #) (# 1##, 0## #))
    305 --   True
    306 from_wide_le# :: (# Word#, Word# #) -> (# Word#, Word# #) -> Choice
    307 from_wide_le# x y =
    308   let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    309       !mask =
    310         (and_w#
    311           (or_w# (not_w# x) y)
    312           (or_w# (xor_w# x y) (not_w# (sub_w# y x))))
    313       !bit = case mask of
    314         (# l, _ #) -> Exts.uncheckedShiftRL# l s
    315   in  from_word# bit
    316 {-# INLINE from_wide_le# #-}
    317 
    318 -- | Construct a 'Choice' from a less-than comparison.
    319 --
    320 --   >>> decide (from_word_lt# 0## 1##)
    321 --   True
    322 --   >>> decide (from_word_lt# 1## 1##)
    323 --   False
    324 from_word_lt# :: Word# -> Word# -> Choice
    325 from_word_lt# x y =
    326   let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    327       !bit =
    328         Exts.uncheckedShiftRL#
    329           (Exts.or#
    330             (Exts.and# (Exts.not# x) y)
    331             (Exts.and# (Exts.or# (Exts.not# x) y) (Exts.minusWord# x y)))
    332           s
    333   in  from_word# bit
    334 {-# INLINE from_word_lt# #-}
    335 
    336 -- | Construct a 'Choice' from a greater-than comparison.
    337 --
    338 --   >>> decide (from_word_gt# 0## 1##)
    339 --   False
    340 --   >>> decide (from_word_gt# 1## 1##)
    341 --   False
    342 from_word_gt# :: Word# -> Word# -> Choice
    343 from_word_gt# x y = from_word_lt# y x
    344 {-# INLINE from_word_gt# #-}
    345 
    346 -- manipulation ---------------------------------------------------------------
    347 
    348 -- | Logically negate a 'Choice'.
    349 not# :: Choice -> Choice
    350 not# (Choice w) = Choice (Exts.not# w)
    351 {-# INLINE not# #-}
    352 
    353 -- | Logical disjunction on 'Choice' values.
    354 or# :: Choice -> Choice -> Choice
    355 or# (Choice w0) (Choice w1) = Choice (Exts.or# w0 w1)
    356 {-# INLINE or# #-}
    357 
    358 -- | Logical conjunction on 'Choice' values.
    359 and# :: Choice -> Choice -> Choice
    360 and# (Choice w0) (Choice w1) = Choice (Exts.and# w0 w1)
    361 {-# INLINE and# #-}
    362 
    363 -- | Logical inequality on 'Choice' values.
    364 xor# :: Choice -> Choice -> Choice
    365 xor# (Choice w0) (Choice w1) = Choice (Exts.xor# w0 w1)
    366 {-# INLINE xor# #-}
    367 
    368 -- | Logical inequality on 'Choice' values.
    369 ne# :: Choice -> Choice -> Choice
    370 ne# c0 c1 = xor# c0 c1
    371 {-# INLINE ne# #-}
    372 
    373 -- | Logical equality on 'Choice' values.
    374 eq# :: Choice -> Choice -> Choice
    375 eq# c0 c1 = not# (ne# c0 c1)
    376 {-# INLINE eq# #-}
    377 
    378 -- constant-time selection ----------------------------------------------------
    379 
    380 -- | Select an unboxed word, given a 'Choice'.
    381 select_word# :: Word# -> Word# -> Choice -> Word#
    382 select_word# a b (Choice c) = Exts.xor# a (Exts.and# c (Exts.xor# a b))
    383 {-# INLINE select_word# #-}
    384 
    385 -- | Select an unboxed two-limb word, given a 'Choice'.
    386 select_wide#
    387   :: (# Word#, Word# #)
    388   -> (# Word#, Word# #)
    389   -> Choice
    390   -> (# Word#, Word# #)
    391 select_wide# a b (Choice w) =
    392   let !mask = or_w# (hi# w) (lo# w)
    393   in  xor_w# a (and_w# mask (xor_w# a b))
    394 {-# INLINE select_wide# #-}
    395 
    396 -- | Select an unboxed four-limb word, given a 'Choice'.
    397 select_wider#
    398   :: (# Word#, Word#, Word#, Word# #)
    399   -> (# Word#, Word#, Word#, Word# #)
    400   -> Choice
    401   -> (# Word#, Word#, Word#, Word# #)
    402 select_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) (Choice w) =
    403   let !w0 = Exts.xor# a0 (Exts.and# w (Exts.xor# a0 b0))
    404       !w1 = Exts.xor# a1 (Exts.and# w (Exts.xor# a1 b1))
    405       !w2 = Exts.xor# a2 (Exts.and# w (Exts.xor# a2 b2))
    406       !w3 = Exts.xor# a3 (Exts.and# w (Exts.xor# a3 b3))
    407   in  (# w0, w1, w2, w3 #)
    408 {-# INLINE select_wider# #-}
    409 
    410 -- constant-time equality -----------------------------------------------------
    411 
    412 -- | Compare unboxed words for equality in constant time.
    413 --
    414 --   >>> decide (eq_word# 0## 1##)
    415 --   False
    416 eq_word# :: Word# -> Word# -> Choice
    417 eq_word# a b =
    418   let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    419       !x = Exts.xor# a b
    420       !y = Exts.uncheckedShiftRL# (Exts.or# x (neg_w# x)) s
    421   in  Choice (Exts.xor# y 1##)
    422 {-# INLINE eq_word# #-}
    423 
    424 -- | Compare unboxed two-limb words for equality in constant time.
    425 --
    426 --   >>> decide (eq_wide (# 0##, 0## #) (# 0##, 0## #))
    427 --   True
    428 eq_wide#
    429   :: (# Word#, Word# #)
    430   -> (# Word#, Word# #)
    431   -> Choice
    432 eq_wide# (# a0, a1 #) (# b0, b1 #) =
    433   let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    434       !x = Exts.or# (Exts.xor# a0 b0) (Exts.xor# a1 b1)
    435       !y = Exts.uncheckedShiftRL# (Exts.or# x (neg_w# x)) s
    436   in  Choice (Exts.xor# y 1##)
    437 {-# INLINE eq_wide# #-}
    438 
    439 -- | Compare unboxed four-limb words for equality in constant time.
    440 --
    441 --   >>> let zero = (# 0##, 0##, 0##, 0## #) in decide (eq_wider# zero zero)
    442 --   True
    443 eq_wider#
    444   :: (# Word#, Word#, Word#, Word# #)
    445   -> (# Word#, Word#, Word#, Word# #)
    446   -> Choice
    447 eq_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) =
    448   let !s = case B.finiteBitSize (0 :: Word) of I# m -> m Exts.-# 1#
    449       !x = Exts.or# (Exts.or# (Exts.xor# a0 b0) (Exts.xor# a1 b1))
    450                     (Exts.or# (Exts.xor# a2 b2) (Exts.xor# a3 b3))
    451       !y = Exts.uncheckedShiftRL# (Exts.or# x (neg_w# x)) s
    452   in  Choice (Exts.xor# y 1##)
    453 {-# INLINE eq_wider# #-}
    454