commit ac6cb8602318ffb079681289c9f43c3408a29b90
parent e6a960955e8b67215ac3ee67446d468814ba9cad
Author: Jared Tobin <jared@jtobin.io>
Date: Wed, 29 Jan 2025 22:47:15 +0400
lib: further optimisation
Diffstat:
3 files changed, 134 insertions(+), 64 deletions(-)
diff --git a/bench/Weight.hs b/bench/Weight.hs
@@ -38,6 +38,8 @@ main = W.mainWith $ do
W.func "quotrem_2by1" (E.quotrem_2by1 8 4 0xffffffffffffffff) r
W.func "div (baseline)" (Prelude.div i2) i3
W.func "div" (E.div w2) w3
+ W.func "mod (baseline)" (Prelude.mod i2) i3
+ W.func "mod" (E.mod w2) w3
where
!r = E.recip_2by1 0xFFFF_FFFF_FFFF_FF00
diff --git a/lib/Data/Word/Extended.hs b/lib/Data/Word/Extended.hs
@@ -38,6 +38,7 @@ module Data.Word.Extended (
, sub
, mul
, div
+ , rem
-- * Modular Arithmetic
, mod
@@ -69,8 +70,8 @@ import qualified Data.Primitive.PrimArray as PA
import GHC.Exts
import GHC.Generics
import GHC.Word
-import Prelude hiding (div, mod, or, and, quot)
-import qualified Prelude (mod, quot)
+import Prelude hiding (div, mod, or, and, quot, rem)
+import qualified Prelude (quot, rem)
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
@@ -437,11 +438,6 @@ mul (Word256 (W64# a0) (W64# a1) (W64# a2) (W64# a3))
-- division -------------------------------------------------------------------
-newtype Memory m = Memory (PA.MutablePrimArray (PrimState m) Word64)
- deriving Generic
-
-instance PrimMonad m => NFData (Memory m)
-
-- quotient, remainder of (hi, lo) divided by y
-- translated from Div64 in go's math/bits package
--
@@ -715,6 +711,7 @@ quotrem_knuth quo u ulen d = do
PA.writePrimArray u (j + ld) (u2 - borrow)
if u2 < borrow
then do
+ -- rare case
let !qh = qhat - 1
r <- add_to u j d ld
PA.writePrimArray u (j + ld) r
@@ -724,10 +721,62 @@ quotrem_knuth quo u ulen d = do
loop (pred j)
loop (ulen - ld - 1)
+rem_knuth
+ :: PrimMonad m
+ => PA.MutablePrimArray (PrimState m) Word64 -- normalized dividend
+ -> Int -- size of normalize dividend
+ -> PA.PrimArray Word64 -- normalized divisor
+ -> m ()
+rem_knuth u ulen d = do
+ let !ld = PA.sizeofPrimArray d
+ !dh = PA.indexPrimArray d (ld - 1)
+ !dl = PA.indexPrimArray d (ld - 2)
+ !rec = recip_2by1 dh
+ loop !j
+ | j < 0 = pure ()
+ | otherwise = do
+ !u2 <- PA.readPrimArray u (j + ld)
+ !u1 <- PA.readPrimArray u (j + ld - 1)
+ !u0 <- PA.readPrimArray u (j + ld - 2)
+ let !qhat
+ | u2 >= dh = 0xffff_ffff_ffff_ffff
+ | otherwise =
+ let !(P qh rh) = quotrem_2by1 u2 u1 dh rec
+ !(P ph pl) = mul_c qh dl
+ in if ph > rh || (ph == rh && pl > u0)
+ then qh - 1
+ else qh
+
+ !borrow <- sub_mul u j d ld qhat
+ PA.writePrimArray u (j + ld) (u2 - borrow)
+ if u2 < borrow
+ then do
+ -- rare case
+ r <- add_to u j d ld
+ PA.writePrimArray u (j + ld) r
+ else
+ pure ()
+ loop (pred j)
+ loop (ulen - ld - 1)
+
+normalized_dividend_length
+ :: PrimMonad m
+ => PA.MutablePrimArray (PrimState m) Word64 -- dividend
+ -> m Int
+normalized_dividend_length u = do
+ !lu <- PA.getSizeofMutablePrimArray u
+ let loop !j
+ | j < 0 = pure 0
+ | otherwise = do
+ uj <- PA.readPrimArray u j
+ if uj /= 0 then pure (j + 1) else loop (j - 1)
+ loop (lu - 2) -- last word will be uninitialized, skip it
+{-# INLINE normalized_dividend_length #-}
+
normalize_divisor
:: PrimMonad m
=> Word256
- -> m (PA.PrimArray Word64, Int, Word64) -- XX more efficient
+ -> m (PA.PrimArray Word64, Int, Int, Word64) -- XX more efficient
normalize_divisor (Word256 d0 d1 d2 d3) = do
let (dlen, d_last, shift)
| d3 /= 0 = (4, d3, B.countLeadingZeros d3)
@@ -763,9 +812,29 @@ normalize_divisor (Word256 d0 d1 d2 d3) = do
norm (j - 1) dj_1
dn_0 <- norm (dlen - 1) d_last
d_final <- PA.unsafeFreezePrimArray dn
- pure (d_final, shift, dn_0)
+ pure (d_final, dlen, shift, dn_0)
{-# INLINE normalize_divisor #-}
+normalize_dividend
+ :: PrimMonad m
+ => PA.MutablePrimArray (PrimState m) Word64
+ -> Int
+ -> Int
+ -> m ()
+normalize_dividend u ulen s = do
+ u_hi <- PA.readPrimArray u (ulen - 1)
+ PA.writePrimArray u ulen (u_hi .>>. (64 - s))
+ let loop !j !uj
+ | j == 0 =
+ PA.writePrimArray u 0 (uj .<<. s)
+ | otherwise = do
+ !uj_1 <- PA.readPrimArray u (j - 1)
+ PA.writePrimArray u j $
+ (uj .<<. s) .|. (uj_1 .>>. (64 - s))
+ loop (j - 1) uj_1
+ loop (ulen - 1) u_hi
+{-# INLINE normalize_dividend #-}
+
quotrem
:: PrimMonad m
=> PA.MutablePrimArray (PrimState m) Word64 -- quotient (potentially large)
@@ -774,36 +843,14 @@ quotrem
-> m (PA.PrimArray Word64) -- remainder (256-bit)
quotrem quo u d = do
-- normalize divisor
- !(dn, shift, dn_0) <- normalize_divisor d
- let !dlen = PA.sizeofPrimArray dn
-
+ !(dn, dlen, shift, dn_0) <- normalize_divisor d
-- get size of normalized dividend
- !lu <- PA.getSizeofMutablePrimArray u
- !ulen <- let loop !j
- | j < 0 = pure 0
- | otherwise = do
- uj <- PA.readPrimArray u j
- if uj /= 0 then pure (j + 1) else loop (j - 1)
- in loop (lu - 2) -- don't touch the uninitialized word
+ !ulen <- normalized_dividend_length u
if ulen < dlen
- then do
- -- u always has size at least 4
- !r <- PA.newPrimArray 4
- PA.copyMutablePrimArray r 0 u 0 4
- PA.unsafeFreezePrimArray r
+ then PA.freezePrimArray u 0 4
else do
-- normalize dividend
- u_hi <- PA.readPrimArray u (ulen - 1)
- PA.writePrimArray u ulen (u_hi .>>. (64 - shift))
- let normalize_u !j !uj
- | j == 0 =
- PA.writePrimArray u 0 (uj .<<. shift)
- | otherwise = do
- !uj_1 <- PA.readPrimArray u (j - 1)
- PA.writePrimArray u j $
- (uj .<<. shift) .|. (uj_1 .>>. (64 - shift))
- normalize_u (j - 1) uj_1
- normalize_u (ulen - 1) u_hi
+ normalize_dividend u ulen shift
if dlen == 1
then do
-- normalized divisor is small
@@ -811,6 +858,7 @@ quotrem quo u d = do
!r <- quotrem_by1 quo un dn_0
pure $ PA.primArrayFromList [r .>>. shift, 0, 0, 0] -- XX
else do
+ -- quotrem of normalized dividend divided by normalized divisor
quotrem_knuth quo u (ulen + 1) dn
-- unnormalize remainder
let unn_rem !j !unj
@@ -837,34 +885,13 @@ quot
-> m Int -- length of quotient
quot quo u d = do
-- normalize divisor
- !(dn, shift, dn_0) <- normalize_divisor d
- let !dlen = PA.sizeofPrimArray dn
-
+ !(dn, dlen, shift, dn_0) <- normalize_divisor d
-- get size of normalized dividend
- -- XX extract this
- !lu <- PA.getSizeofMutablePrimArray u
- !ulen <- let loop !j
- | j < 0 = pure 0
- | otherwise = do
- uj <- PA.readPrimArray u j
- if uj /= 0 then pure (j + 1) else loop (j - 1)
- in loop (lu - 2) -- don't touch the uninitialized word
+ !ulen <- normalized_dividend_length u
if ulen < dlen
then pure 0
else do
- -- normalize dividend
- -- XX extract this
- u_hi <- PA.readPrimArray u (ulen - 1)
- PA.writePrimArray u ulen (u_hi .>>. (64 - shift))
- let normalize_u !j !uj
- | j == 0 =
- PA.writePrimArray u 0 (uj .<<. shift)
- | otherwise = do
- !uj_1 <- PA.readPrimArray u (j - 1)
- PA.writePrimArray u j $
- (uj .<<. shift) .|. (uj_1 .>>. (64 - shift))
- normalize_u (j - 1) uj_1
- normalize_u (ulen - 1) u_hi
+ normalize_dividend u ulen shift
if dlen == 1
then do
-- normalized divisor is small
@@ -876,6 +903,48 @@ quot quo u d = do
pure (ulen + 1 - dlen)
{-# INLINE quot #-}
+rem
+ :: PrimMonad m
+ => PA.MutablePrimArray (PrimState m) Word64 -- quotient (potentially large)
+ -> PA.MutablePrimArray (PrimState m) Word64 -- unnormalized dividend
+ -> Word256 -- unnormalized divisor
+ -> m (PA.PrimArray Word64) -- remainder (256-bit)
+rem quo u d = do
+ -- normalize divisor
+ !(dn, dlen, shift, dn_0) <- normalize_divisor d
+ -- get size of normalized dividend
+ !ulen <- normalized_dividend_length u
+ if ulen < dlen
+ then PA.freezePrimArray u 0 4
+ else do
+ -- normalize dividend
+ normalize_dividend u ulen shift
+ if dlen == 1
+ then do
+ -- normalized divisor is small
+ !un <- PA.unsafeFreezePrimArray u
+ !r <- quotrem_by1 quo un dn_0
+ pure $ PA.primArrayFromList [r .>>. shift, 0, 0, 0] -- XX
+ else do
+ -- quotrem of normalized dividend divided by normalized divisor
+ rem_knuth u (ulen + 1) dn
+ -- unnormalize remainder
+ let unn_rem !j !unj
+ | j == dlen = do
+ PA.unsafeFreezePrimArray u
+ | j + 1 == ulen = do
+ PA.writePrimArray u j (unj .>>. shift)
+ PA.unsafeFreezePrimArray u
+ | otherwise = do
+ !unj_1 <- PA.readPrimArray u (j + 1)
+ PA.writePrimArray u j $
+ (unj .>>. shift) .|. (unj_1 .<<. (64 - shift))
+ unn_rem (j + 1) unj_1
+
+ !un_0 <- PA.readPrimArray u 0
+ unn_rem 0 un_0
+{-# INLINE rem #-}
+
div :: Word256 -> Word256 -> Word256
div u@(Word256 u0 u1 u2 u3) d@(Word256 d0 _ _ _)
| is_zero d || d `gt` u = zero -- ?
@@ -921,11 +990,10 @@ mod :: Word256 -> Word256 -> Word256
mod u@(Word256 u0 u1 u2 u3) d@(Word256 d0 _ _ _)
| is_zero d || d `gt` u = zero -- ?
| u == d = one
- | is_word64 u = Word256 (u0 `Prelude.mod` d0) 0 0 0
+ | is_word64 u = Word256 (u0 `Prelude.rem` d0) 0 0 0
| otherwise = runST $ do
-- allocate quotient
quo <- PA.newPrimArray 4
- PA.setPrimArray quo 0 4 0 -- XX avoid
-- allocate dividend, leaving enough space for normalization
u_hot <- PA.newPrimArray 5
PA.writePrimArray u_hot 0 u0
@@ -933,7 +1001,7 @@ mod u@(Word256 u0 u1 u2 u3) d@(Word256 d0 _ _ _)
PA.writePrimArray u_hot 2 u2
PA.writePrimArray u_hot 3 u3
-- last index of u_hot intentionally unset
- r <- quotrem quo u_hot d
+ r <- rem quo u_hot d
let r0 = PA.indexPrimArray r 0
r1 = PA.indexPrimArray r 1
r2 = PA.indexPrimArray r 2
diff --git a/test/Main.hs b/test/Main.hs
@@ -13,7 +13,7 @@ import Data.Word.Extended
import GHC.Exts
import GHC.Word
import Prelude hiding (and, or, div, mod)
-import qualified Prelude (div)
+import qualified Prelude (div, rem)
import Test.Tasty
import qualified Test.Tasty.HUnit as H
import qualified Test.Tasty.QuickCheck as Q
@@ -106,7 +106,7 @@ div_matches (DivMonotonic (a, b)) =
mod_matches :: DivMonotonic -> Bool
mod_matches (DivMonotonic (a, b)) =
let !left = to_word256 a `mod` to_word256 b
- !rite = to_word256 (a `rem` b)
+ !rite = to_word256 (a `Prelude.rem` b)
in left == rite
quotrem_r_case0 :: H.Assertion