fixed

Pure Haskell large fixed-width integers.
git clone git://git.ppad.tech/fixed.git
Log | Files | Refs | README | LICENSE

Wider.hs (13618B)


      1 {-# LANGUAGE BangPatterns #-}
      2 {-# LANGUAGE MagicHash #-}
      3 {-# LANGUAGE NumericUnderscores #-}
      4 {-# LANGUAGE ViewPatterns #-}
      5 {-# LANGUAGE UnboxedSums #-}
      6 {-# LANGUAGE UnboxedTuples #-}
      7 {-# LANGUAGE UnliftedNewtypes #-}
      8 
      9 -- |
     10 -- Module: Data.Word.Wider
     11 -- Copyright: (c) 2025 Jared Tobin
     12 -- License: MIT
     13 -- Maintainer: Jared Tobin <jared@ppad.tech>
     14 --
     15 -- Wider words, consisting of four 'Limb's.
     16 
     17 module Data.Word.Wider where
     18 
     19 import Control.DeepSeq
     20 import Data.Bits ((.|.), (.&.), (.<<.), (.>>.))
     21 import qualified Data.Bits as B
     22 import qualified Data.Choice as C
     23 import Data.Word.Limb (Limb(..))
     24 import qualified Data.Word.Limb as L
     25 import GHC.Exts ( Word(..), Int(..), Int#
     26                 , (-#), (*#)
     27                 , word2Int#, eqWord#, andI#, isTrue#
     28                 )
     29 import Prelude hiding (div, mod, or, and, not, quot, rem, recip)
     30 
     31 -- utilities ------------------------------------------------------------------
     32 
     33 fi :: (Integral a, Num b) => a -> b
     34 fi = fromIntegral
     35 {-# INLINE fi #-}
     36 
     37 -- wider words ----------------------------------------------------------------
     38 
     39 -- | Little-endian wider words.
     40 data Wider = Wider !(# Limb, Limb, Limb, Limb #)
     41 
     42 instance Show Wider where
     43   show = show . from
     44 
     45 instance Num Wider where
     46   (+) = add
     47   (-) = sub
     48   (*) = mul
     49   abs = id
     50   fromInteger = to
     51   negate w = add (not w) (Wider (# Limb 1##, Limb 0##, Limb 0##, Limb 0## #))
     52   signum a = case a of
     53     Wider (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) -> 0
     54     _ -> 1
     55 
     56 instance NFData Wider where
     57   rnf (Wider a) = case a of
     58     (# _, _, _, _ #) -> ()
     59 
     60 -- comparison -----------------------------------------------------------------
     61 
     62 eq#
     63   :: (# Limb, Limb, Limb, Limb #)
     64   -> (# Limb, Limb, Limb, Limb #)
     65   -> C.Choice
     66 eq# a b =
     67   let !(# Limb a0, Limb a1, Limb a2, Limb a3 #) = a
     68       !(# Limb b0, Limb b1, Limb b2, Limb b3 #) = b
     69   in  C.ct_eq_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #)
     70 {-# INLINE eq# #-}
     71 
     72 -- | Compare 'Wider' words for equality in variable time.
     73 eq_vartime :: Wider -> Wider -> Bool
     74 eq_vartime a b =
     75   let !(Wider (# Limb a0, Limb a1, Limb a2, Limb a3 #)) = a
     76       !(Wider (# Limb b0, Limb b1, Limb b2, Limb b3 #)) = b
     77   in  isTrue# $
     78         andI#
     79           (andI# (eqWord# a0 b0) (eqWord# a1 b1))
     80           (andI# (eqWord# a2 b2) (eqWord# a3 b3))
     81 
     82 lt#
     83   :: (# Limb, Limb, Limb, Limb #)
     84   -> (# Limb, Limb, Limb, Limb #)
     85   -> C.Choice
     86 lt# a b =
     87   let !(# _, Limb bor #) = sub_b# a b
     88   in  C.from_word_mask# bor
     89 {-# INLINE lt# #-}
     90 
     91 gt#
     92   :: (# Limb, Limb, Limb, Limb #)
     93   -> (# Limb, Limb, Limb, Limb #)
     94   -> C.Choice
     95 gt# a b =
     96   let !(# _, Limb bor #) = sub_b# b a
     97   in  C.from_word_mask# bor
     98 {-# INLINE gt# #-}
     99 
    100 cmp#
    101   :: (# Limb, Limb, Limb, Limb #)
    102   -> (# Limb, Limb, Limb, Limb #)
    103   -> Int#
    104 cmp# (# l0, l1, l2, l3 #) (# r0, r1, r2, r3 #) =
    105   let !(# w0, b0 #) = L.sub_b# r0 l0 (Limb 0##)
    106       !d0 = L.or# (Limb 0##) w0
    107       !(# w1, b1 #) = L.sub_b# r1 l1 b0
    108       !d1 = L.or# d0 w1
    109       !(# w2, b2 #) = L.sub_b# r2 l2 b1
    110       !d2 = L.or# d1 w2
    111       !(# w3, b3 #) = L.sub_b# r3 l3 b2
    112       !d3 = L.or# d2 w3
    113       !(Limb w) = L.and# b3 (Limb 2##)
    114       !s = word2Int# w -# 1#
    115   in  (word2Int# (C.to_word# (L.nonzero# d3))) *# s
    116 {-# INLINE cmp# #-}
    117 
    118 -- | Constant-time comparison between 'Wider' words.
    119 cmp :: Wider -> Wider -> Ordering
    120 cmp (Wider a) (Wider b) = case cmp# a b of
    121   1#  -> GT
    122   0#  -> EQ
    123   _   -> LT
    124 
    125 -- construction / conversion --------------------------------------------------
    126 
    127 -- | Construct a 'Wider' word from four 'Words', provided in
    128 --   little-endian order.
    129 wider :: Word -> Word -> Word -> Word -> Wider
    130 wider (W# w0) (W# w1) (W# w2) (W# w3) = Wider
    131   (# Limb w0, Limb w1, Limb w2, Limb w3 #)
    132 
    133 -- | Convert an 'Integer' to a 'Wider' word.
    134 to :: Integer -> Wider
    135 to n =
    136   let !size = B.finiteBitSize (0 :: Word)
    137       !mask = fi (maxBound :: Word) :: Integer
    138       !(W# w0) = fi (n .&. mask)
    139       !(W# w1) = fi ((n .>>. size) .&. mask)
    140       !(W# w2) = fi ((n .>>. (2 * size)) .&. mask)
    141       !(W# w3) = fi ((n .>>. (3 * size)) .&. mask)
    142   in  Wider (# Limb w0, Limb w1, Limb w2, Limb w3 #)
    143 
    144 -- | Convert a 'Wider' word to an 'Integer'.
    145 from :: Wider -> Integer
    146 from (Wider (# Limb w0, Limb w1, Limb w2, Limb w3 #)) =
    147         fi (W# w3) .<<. (3 * size)
    148     .|. fi (W# w2) .<<. (2 * size)
    149     .|. fi (W# w1) .<<. size
    150     .|. fi (W# w0)
    151   where
    152     !size = B.finiteBitSize (0 :: Word)
    153 
    154 -- bit manipulation -----------------------------------------------------------
    155 
    156 -- | Constant-time 1-bit shift-right with carry, indicating whether the
    157 --   lowest bit was set.
    158 shr1_c#
    159   :: (# Limb, Limb, Limb, Limb #)                 -- ^ argument
    160   -> (# (# Limb, Limb, Limb, Limb #), C.Choice #) -- ^ result, carry
    161 shr1_c# (# w0, w1, w2, w3 #) =
    162   let !s = case B.finiteBitSize (0 :: Word) of I# m -> m -# 1#
    163       !(# s3, c3 #) = (# L.shr# w3 1#, L.shl# w3 s #)
    164       !r3 = L.or# s3 (Limb 0##)
    165       !(# s2, c2 #) = (# L.shr# w2 1#, L.shl# w2 s #)
    166       !r2 = L.or# s2 c3
    167       !(# s1, c1 #) = (# L.shr# w1 1#, L.shl# w1 s #)
    168       !r1 = L.or# s1 c2
    169       !(# s0, c0 #) = (# L.shr# w0 1#, L.shl# w0 s #)
    170       !r0 = L.or# s0 c1
    171       !(Limb w) = L.shr# c0 s
    172   in  (# (# r0, r1, r2, r3 #), C.from_word_lsb# w #)
    173 {-# INLINE shr1_c# #-}
    174 
    175 shr1_c :: Wider -> (Wider, Bool)
    176 shr1_c (Wider w) =
    177   let !(# r, c #) = shr1_c# w
    178   in  (Wider r, C.decide c)
    179 
    180 and_w#
    181   :: (# Limb, Limb, Limb, Limb #)
    182   -> (# Limb, Limb, Limb, Limb #)
    183   -> (# Limb, Limb, Limb, Limb #)
    184 and_w# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) =
    185   (# L.and# a0 b0, L.and# a1 b1, L.and# a2 b2, L.and# a3 b3 #)
    186 {-# INLINE and_w# #-}
    187 
    188 and :: Wider -> Wider -> Wider
    189 and (Wider a) (Wider b) = Wider (and_w# a b)
    190 
    191 or_w#
    192   :: (# Limb, Limb, Limb, Limb #)
    193   -> (# Limb, Limb, Limb, Limb #)
    194   -> (# Limb, Limb, Limb, Limb #)
    195 or_w# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) =
    196   (# L.or# a0 b0, L.or# a1 b1, L.or# a2 b2, L.or# a3 b3 #)
    197 {-# INLINE or_w# #-}
    198 
    199 or :: Wider -> Wider -> Wider
    200 or (Wider a) (Wider b) = Wider (or_w# a b)
    201 
    202 not#
    203   :: (# Limb, Limb, Limb, Limb #)
    204   -> (# Limb, Limb, Limb, Limb #)
    205 not# (# l0, l1, l2, l3 #) = (# L.not# l0, L.not# l1, L.not# l2, L.not# l3 #)
    206 {-# INLINE not# #-}
    207 
    208 not
    209   :: Wider
    210   -> Wider
    211 not (Wider w) = Wider (not# w)
    212 
    213 -- conditional_shr#
    214 --   :: (# Word#, Word#, Word#, Word# #)
    215 --   -> Int#
    216 --   -> C.Choice
    217 --   -> (# (# Word#, Word#, Word#, Word# #), Word# #)
    218 -- conditional_shr# (# a0, a1, a2, a3 #) s c =
    219 --   let !size = case B.finiteBitSize (0 :: Word) of I# m -> m
    220 --       !rs = s
    221 --       !ls = size -# s
    222 --       !(# l3, c3 #) =
    223 --         (# C.ct_select_word# a3 (uncheckedShiftRL# a3 rs)
    224 
    225 -- addition, subtraction ------------------------------------------------------
    226 
    227 -- | Overflowing addition, computing 'a + b', returning the sum and a
    228 --   carry bit.
    229 add_o#
    230   :: (# Limb, Limb, Limb, Limb #)             -- ^ augend
    231   -> (# Limb, Limb, Limb, Limb #)             -- ^ addend
    232   -> (# (# Limb, Limb, Limb, Limb #), Limb #) -- ^ (# sum, carry bit #)
    233 add_o# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) =
    234   let !(# s0, c0 #) = L.add_o# a0 b0
    235       !(# s1, c1 #) = L.add_c# a1 b1 c0
    236       !(# s2, c2 #) = L.add_c# a2 b2 c1
    237       !(# s3, c3 #) = L.add_c# a3 b3 c2
    238   in  (# (# s0, s1, s2, s3 #), c3 #)
    239 {-# INLINE add_o# #-}
    240 
    241 -- | Overflowing addition, computing 'a + b', returning the sum and a
    242 --   carry bit.
    243 add_o
    244   :: Wider
    245   -> Wider
    246   -> (Wider, Word)
    247 add_o (Wider a) (Wider b) =
    248   let !(# s, Limb c #) = add_o# a b
    249   in  (Wider s, W# c)
    250 
    251 -- | Wrapping addition, computing 'a + b'.
    252 add_w#
    253   :: (# Limb, Limb, Limb, Limb #) -- ^ augend
    254   -> (# Limb, Limb, Limb, Limb #) -- ^ addend
    255   -> (# Limb, Limb, Limb, Limb #) -- ^ sum
    256 add_w# a b =
    257   let !(# c, _ #) = add_o# a b
    258   in  c
    259 {-# INLINE add_w# #-}
    260 
    261 -- | Wrapping addition, computing 'a + b'.
    262 add
    263   :: Wider
    264   -> Wider
    265   -> Wider
    266 add (Wider a) (Wider b) = Wider (add_w# a b)
    267 {-# INLINE add #-}
    268 
    269 -- | Modular addition.
    270 add_mod#
    271   :: (# Limb, Limb, Limb, Limb #) -- ^ augend
    272   -> (# Limb, Limb, Limb, Limb #) -- ^ addend
    273   -> (# Limb, Limb, Limb, Limb #) -- ^ modulus
    274   -> (# Limb, Limb, Limb, Limb #) -- ^ sum
    275 add_mod# a b m =
    276   let !(# w, c #) = add_o# a b
    277   in  sub_mod_c# w c m m
    278 {-# INLINE add_mod# #-}
    279 
    280 -- | Borrowing subtraction, computing 'a - b' and returning the
    281 --   difference with a borrow mask.
    282 sub_b#
    283   :: (# Limb, Limb, Limb, Limb #)              -- ^ minuend
    284   -> (# Limb, Limb, Limb, Limb #)              -- ^ subtrahend
    285   -> (# (# Limb, Limb, Limb, Limb #), Limb #) -- ^ (# diff, borrow mask #)
    286 sub_b# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) =
    287   let !(# s0, c0 #) = L.sub_b# a0 b0 (Limb 0##)
    288       !(# s1, c1 #) = L.sub_b# a1 b1 c0
    289       !(# s2, c2 #) = L.sub_b# a2 b2 c1
    290       !(# s3, c3 #) = L.sub_b# a3 b3 c2
    291   in  (# (# s0, s1, s2, s3 #), c3 #)
    292 {-# INLINE sub_b# #-}
    293 
    294 sub_b
    295   :: Wider
    296   -> Wider
    297   -> (Wider, Word)
    298 sub_b (Wider l) (Wider r) =
    299   let !(# d, Limb b #) = sub_b# l r
    300   in  (Wider d, W# b)
    301 
    302 sub
    303   :: Wider
    304   -> Wider
    305   -> Wider
    306 sub (Wider a) (Wider b) =
    307   let !(# d, _ #) = sub_b# a b
    308   in  Wider d
    309 
    310 -- | Modular subtraction. Computes a - b mod m.
    311 sub_mod#
    312   :: (# Limb, Limb, Limb, Limb #) -- ^ minuend
    313   -> (# Limb, Limb, Limb, Limb #) -- ^ subtrahend
    314   -> (# Limb, Limb, Limb, Limb #) -- ^ modulus
    315   -> (# Limb, Limb, Limb, Limb #) -- ^ difference
    316 sub_mod# a b (# p0, p1, p2, p3 #) =
    317   let !(# (# o0, o1, o2, o3 #), m #) = sub_b# a b
    318       !ba = (# L.and# p0 m, L.and# p1 m, L.and# p2 m, L.and# p3 m #)
    319   in  add_w# (# o0, o1, o2, o3 #) ba
    320 {-# INLINE sub_mod# #-}
    321 
    322 sub_mod
    323   :: Wider
    324   -> Wider
    325   -> Wider
    326   -> Wider
    327 sub_mod (Wider a) (Wider b) (Wider p) = Wider (sub_mod# a b p)
    328 
    329 -- | Modular subtraction with carry. Computes (# a, c #) - b mod m.
    330 sub_mod_c#
    331   :: (# Limb, Limb, Limb, Limb #) -- ^ minuend
    332   -> Limb                         -- ^ carry bit
    333   -> (# Limb, Limb, Limb, Limb #) -- ^ subtrahend
    334   -> (# Limb, Limb, Limb, Limb #) -- ^ modulus
    335   -> (# Limb, Limb, Limb, Limb #) -- ^ difference
    336 sub_mod_c# a c b (# p0, p1, p2, p3 #) =
    337   let !(# (# o0, o1, o2, o3 #), bb #) = sub_b# a b
    338       !(# _, m #) = L.sub_b# c (Limb 0##) bb
    339       !ba = (# L.and# p0 m, L.and# p1 m, L.and# p2 m, L.and# p3 m #)
    340   in  add_w# (# o0, o1, o2, o3 #) ba
    341 {-# INLINE sub_mod_c# #-}
    342 
    343 -- multiplication -------------------------------------------------------------
    344 
    345 -- widening multiplication
    346 mul_c#
    347   :: (# Limb, Limb, Limb, Limb #)
    348   -> (# Limb, Limb, Limb, Limb #)
    349   -> (# (# Limb, Limb, Limb, Limb #), (# Limb, Limb, Limb, Limb #) #)
    350 mul_c# (# x0, x1, x2, x3 #) (# y0, y1, y2, y3 #) =
    351   let !(# z0, c0_0 #)   = L.mac# x0 y0 (Limb 0##) (Limb 0##)
    352       !(# s1_0, c1_0 #) = L.mac# x0 y1 (Limb 0##) c0_0
    353       !(# z1, c1_1 #)   = L.mac# x1 y0 s1_0 (Limb 0##)
    354       !(# s2_0, c2_0 #) = L.mac# x0 y2 (Limb 0##) c1_0
    355       !(# s2_1, c2_1 #) = L.mac# x1 y1 s2_0 c1_1
    356       !(# z2, c2_2 #)   = L.mac# x2 y0 s2_1 (Limb 0##)
    357       !(# s3_0, c3_0 #) = L.mac# x0 y3 (Limb 0##) c2_0
    358       !(# s3_1, c3_1 #) = L.mac# x1 y2 s3_0 c2_1
    359       !(# s3_2, c3_2 #) = L.mac# x2 y1 s3_1 c2_2
    360       !(# z3, c3_3 #)   = L.mac# x3 y0 s3_2 (Limb 0##)
    361       !(# s4_0, c4_0 #) = L.mac# x1 y3 (Limb 0##) c3_0
    362       !(# s4_1, c4_1 #) = L.mac# x2 y2 s4_0 c3_1
    363       !(# s4_2, c4_2 #) = L.mac# x3 y1 s4_1 c3_2
    364       !(# w4, c4_3 #)   = L.add_c# s4_2 c3_3 (Limb 0##)
    365       !(# s5_0, c5_0 #) = L.mac# x2 y3 (Limb 0##) c4_0
    366       !(# s5_1, c5_1 #) = L.mac# x3 y2 s5_0 c4_1
    367       !(# w5, c5_2 #)   = L.add_c# s5_1 c4_2 (Limb 0##)
    368       !(# w5f, c5_3 #)  = L.add_c# w5 c4_3 (Limb 0##)
    369       !(# s6_0, c6_0 #) = L.mac# x3 y3 (Limb 0##) c5_0
    370       !(# w6, c6_1 #)   = L.add_c# s6_0 c5_1 (Limb 0##)
    371       !(# w6f, c6_2 #)  = L.add_c# w6 c5_2 (Limb 0##)
    372       !(# w6ff, c6_3 #) = L.add_c# w6f c5_3 (Limb 0##)
    373       !(# w7, _ #)      = L.add_c# c6_0 c6_1 (Limb 0##)
    374       !(# w7f, _ #)     = L.add_c# w7 c6_2 (Limb 0##)
    375       !(# w7ff, _ #)    = L.add_c# w7f c6_3 (Limb 0##)
    376   in  (# (# z0, z1, z2, z3 #), (# w4, w5f, w6ff, w7ff #) #)
    377 {-# INLINE mul_c# #-}
    378 
    379 mul
    380   :: Wider
    381   -> Wider
    382   -> Wider
    383 mul (Wider a) (Wider b) =
    384   let !(# l, _ #) = mul_c# a b
    385   in  Wider l
    386 
    387 sqr#
    388   :: (# Limb, Limb, Limb, Limb #)
    389   -> (# (# Limb, Limb, Limb, Limb #), (# Limb, Limb, Limb, Limb #) #)
    390 sqr# (# x0, x1, x2, x3 #) =
    391   let !sh = case B.finiteBitSize (0 :: Word) of I# m -> m -# 1#
    392       !(# q1_0, c1_0 #) = L.mac# x1 x0 (Limb 0##) (Limb 0##)
    393       !r1 = c1_0
    394       !(# r2_0, c2_0 #) = L.mac# x2 x0 r1 (Limb 0##)
    395       !(# s2_1, c2_1 #) = L.mac# x2 x1 (Limb 0##) c2_0
    396       !t2 = c2_1
    397       !(# s3_0, c3_0 #) = L.mac# x3 x0 s2_1 (Limb 0##)
    398       !(# t3, c3_1 #) = L.mac# x3 x1 t2 c3_0
    399       !(# u3, c3_2 #) = L.mac# x3 x2 (Limb 0##) c3_1
    400       !v3 = c3_2
    401       !(# lo1, car0_1 #) = (# L.shl# q1_0 1#, L.shr# q1_0 sh #)
    402       !(# lo2, car0_2 #) = (# L.or# (L.shl# r2_0 1#) car0_1, L.shr# r2_0 sh #)
    403       !(# lo3, car0_3 #) = (# L.or# (L.shl# s3_0 1#) car0_2, L.shr# s3_0 sh #)
    404       !(# hi0, car1_0 #) = (# L.or# (L.shl# t3 1#) car0_3, L.shr# t3 sh #)
    405       !(# hi1, car1_1 #) = (# L.or# (L.shl# u3 1#) car1_0, L.shr# u3 sh #)
    406       !(# hi2, car1_2 #) = (# L.or# (L.shl# v3 1#) car1_1, L.shr# v3 sh #)
    407       !hi3 = car1_2
    408       !(# pf, car2_0 #) = L.mac# x0 x0 (Limb 0##) (Limb 0##)
    409       !(# qf, car2_1 #) = L.add_c# lo1 car2_0 (Limb 0##)
    410       !(# rf, car2_2 #) = L.mac# x1 x1 lo2 car2_1
    411       !(# sf, car2_3 #) = L.add_c# lo3 car2_2 (Limb 0##)
    412       !(# tf, car2_4 #) = L.mac# x2 x2 hi0 car2_3
    413       !(# uf, car2_5 #) = L.add_c# hi1 car2_4 (Limb 0##)
    414       !(# vf, car2_6 #) = L.mac# x3 x3 hi2 car2_5
    415       !(# wf, _      #) = L.add_c# hi3 car2_6 (Limb 0##)
    416   in  (# (# pf, qf, rf, sf #), (# tf, uf, vf, wf #) #)
    417 {-# INLINE sqr# #-}
    418 
    419 sqr :: Wider -> (Wider, Wider)
    420 sqr (Wider w) =
    421   let !(# l, h #) = sqr# w
    422   in  (Wider l, Wider h)
    423