commit 9c417daa84b3c054aa0787da3782c3cdcb29df93
parent 417434b294e22a5f2da9c6c607747522844901a4
Author: Jared Tobin <jared@jtobin.io>
Date: Tue, 28 Jan 2025 12:15:48 +0400
lib: the quixotic vibes are thick
Diffstat:
4 files changed, 167 insertions(+), 126 deletions(-)
diff --git a/bench/Main.hs b/bench/Main.hs
@@ -7,7 +7,9 @@ module Main where
import Criterion.Main
import Data.Bits ((.|.), (.&.), (.^.))
import qualified Data.Bits as B
+import qualified Data.Primitive.PrimArray as PA
import qualified Data.Word.Extended as W
+import Data.Word (Word64)
import Prelude hiding (or, and, div, mod)
import qualified Prelude (div)
@@ -21,9 +23,11 @@ multiplication = bgroup "multiplication" [
]
division = bgroup "division" [
- quotrem_r
- , quot_r
+ quotrem_by1
+ , rem_by1
, quotrem_2by1
+ , quot_r
+ , quotrem_r
]
main :: IO ()
@@ -89,6 +93,34 @@ quotrem_2by1 = bench "quotrem_2by1" $
where
!r = W.recip_2by1 0xFFFF_FFFF_FFFF_FF00
+quotrem_by1 :: Benchmark
+quotrem_by1 = env setup $ \ ~(q, u, d) ->
+ bench "quotrem_by1" $
+ nfIO (W.quotrem_by1 q u d)
+ where
+ setup = do
+ qm <- PA.newPrimArray 2
+ PA.setPrimArray qm 0 2 0
+ let !u = PA.primArrayFromList [4, 8]
+ !d = B.complement 0xFF :: Word64
+ pure (qm, u, d)
+
+rem_by1 :: Benchmark
+rem_by1 = bench "rem_by1" $
+ nf (W.rem_by1 (PA.primArrayFromList [4, 8])) (B.complement 0xFF :: Word64)
+
+
+-- quotrem_by1_case0 :: H.Assertion
+-- quotrem_by1_case0 = do
+-- qm <- PA.newPrimArray 2
+-- PA.setPrimArray qm 0 2 0
+-- let !u = PA.primArrayFromList [4, 8]
+-- !d = B.complement 0xFF :: Word64
+-- r <- quotrem_by1 qm u d
+-- q <- PA.unsafeFreezePrimArray qm
+-- H.assertEqual "quotient" (PA.primArrayFromList [8, 0]) q
+-- H.assertEqual "remainder" 2052 r
+
-- or_baseline :: Benchmark
-- or_baseline = bench "or (baseline)" $ nf ((.|.) w0) w1 where
diff --git a/lib/Data/Word/Extended.hs b/lib/Data/Word/Extended.hs
@@ -469,6 +469,7 @@ quotrem_r# hi lo y_0
(if (isTrue# (s ==# 0#))
then wordToWord64# 0##
else uncheckedShiftRL64# lo (64# -# s))
+
!un10 = uncheckedShiftL64# lo s
!un1 = uncheckedShiftRL64# un10 32#
!un0 = and64# un10 mask32
@@ -606,112 +607,41 @@ recip_2by1# :: Word64# -> Word64#
recip_2by1# d = quot_r# (not64# d) (wordToWord64# 0xffffffffffffffff##) d
{-# INLINE recip_2by1# #-}
--- -- remainder by normalized word
--- rem_by_norm_word
--- :: PrimMonad m
--- => Memory m -- memory
--- -> Int -- normalized dividend offset
--- -> Int -- length of normalized dividend
--- -> Int -- normalized divisor offset
--- -> m Word64 -- remainder
--- rem_by_norm_word (Memory buf) un_offset lun dn_offset = do
--- d <- PA.readPrimArray buf dn_offset
--- let rec = recip_2by1 d
--- r0 <- PA.readPrimArray buf (un_offset + lun - 1)
--- let loop !j !racc
--- | j < 0 = pure racc
--- | otherwise = do
--- !uj <- PA.readPrimArray buf (un_offset + j)
--- let !(P _ rnex) = quotrem_2by1 racc uj d rec
--- -- PA.writePrimArray buf j qj
--- loop (j - 1) rnex
--- loop (lun - 2) r0
---
--- -- quotient & remainder by normalized word
--- quotrem_by_norm_word
--- :: PrimMonad m
--- => Memory m -- memory
--- -> Int -- normalized dividend offset
--- -> Int -- length of normalized dividend
--- -> Int -- normalized divisor offset
--- -> m Word64 -- remainder
--- quotrem_by_norm_word (Memory buf) un_offset lun dn_offset = do
--- d <- PA.readPrimArray buf dn_offset
--- let rec = recip_2by1 d
--- r0 <- PA.readPrimArray buf (un_offset + lun - 1)
--- let loop !j !racc
--- | j < 0 = pure racc
--- | otherwise = do
--- !uj <- PA.readPrimArray buf (un_offset + j)
--- let !(P _ rnex) = quotrem_2by1 racc uj d rec
--- PA.writePrimArray buf j qj
--- loop (j - 1) rnex
--- loop (lun - 2) r0
-
--- x =- y * m
--- requires (len x - x_offset) >= len y > 0
-sub_mul
+quotrem_by1
:: PrimMonad m
- => PA.MutablePrimArray (PrimState m) Word64
- -> Int
- -> PA.PrimArray Word64
- -> Int
- -> Word64
- -> m Word64
-sub_mul x x_offset y l m = do
- let loop !j !borrow
- | j == l = pure borrow
+ => PA.MutablePrimArray (PrimState m) Word64 -- quotient
+ -> PA.PrimArray Word64 -- variable-length dividend
+ -> Word64 -- divisor
+ -> m Word64 -- remainder
+quotrem_by1 q u d = do
+ let !rec = recip_2by1 d
+ loop !j !hj
+ | j < 0 = pure hj
| otherwise = do
- !x_j <- PA.readPrimArray x (j + x_offset)
- let !y_j = PA.indexPrimArray y j
- let !(P s carry1) = sub_b x_j borrow 0
- !(P ph pl) = mul_c y_j m
- !(P t carry2) = sub_b s pl 0
- PA.writePrimArray x (j + x_offset) t
- loop (succ j) (ph + carry1 + carry2)
- loop 0 0
+ let !lj = PA.indexPrimArray u j
+ !(P qj rj) = quotrem_2by1 hj lj d rec
+ PA.writePrimArray q j qj
+ loop (j - 1) rj
+ !l = PA.sizeofPrimArray u
+ !hl = PA.indexPrimArray u (l - 1)
+ loop (l - 2) hl
+
+rem_by1
+ :: PA.PrimArray Word64 -- variable-length dividend
+ -> Word64 -- divisor
+ -> Word64 -- remainder
+rem_by1 u d = do
+ let !rec = recip_2by1 d
+ loop !j !hj
+ | j < 0 = hj
+ | otherwise = do
+ let !lj = PA.indexPrimArray u j
+ !(P _ rj) = quotrem_2by1 hj lj d rec
+ loop (j - 1) rj
+ !l = PA.sizeofPrimArray u
+ !hl = PA.indexPrimArray u (l - 1)
+ loop (l - 2) hl
--- quotrem_by1
--- :: PrimMonad m
--- => PA.MutablePrimArray (PrimState m) Word64
--- -> PA.PrimArray Word64
--- -> Word64
--- -> m Word64
--- quotrem_by1 quo u d = do
--- let !rec = recip_2by1 d
--- !lu = PA.sizeofPrimArray u
--- !r0 = PA.indexPrimArray u (lu - 1)
--- loop !j !racc
--- | j < 0 = pure racc
--- | otherwise = do
--- let uj = PA.indexPrimArray u j
--- !(P qj rnex) = quotrem_2by1 racc uj d rec
--- PA.writePrimArray quo j qj
--- loop (pred j) rnex
--- loop (lu - 2) r0
---
--- add_to
--- :: PrimMonad m
--- => PA.MutablePrimArray (PrimState m) Word64
--- -> Int
--- -> Word256
--- -> Int
--- -> m Word64
--- add_to x x_offset (Word256 y0 y1 y2 y3) l = do
--- let loop !j !cacc
--- | j == l = pure cacc
--- | otherwise = do
--- xj <- PA.readPrimArray x (j + x_offset)
--- let !(P nex carry) = case j of
--- 0 -> add_c xj y0 cacc
--- 1 -> add_c xj y1 cacc
--- 2 -> add_c xj y2 cacc
--- 3 -> add_c xj y3 cacc
--- _ -> error "ppad-fixed (add_to): bad index"
--- PA.writePrimArray x (j + x_offset) nex
--- loop (succ j) carry
--- loop 0 0
---
-- quotrem
-- :: PrimMonad m
-- => PA.MutablePrimArray (PrimState m) Word64 -- quotient (potentially large)
@@ -779,6 +709,91 @@ sub_mul x x_offset y l m = do
--
-- !un_0 <- PA.readPrimArray un 0
-- {-# SCC "unn_rem" #-} unn_rem 0 un_0
+
+-- x =- y * m
+-- requires (len x - x_offset) >= len y > 0
+sub_mul
+ :: PrimMonad m
+ => PA.MutablePrimArray (PrimState m) Word64
+ -> Int
+ -> PA.PrimArray Word64
+ -> Int
+ -> Word64
+ -> m Word64
+sub_mul x x_offset y l m = do
+ let loop !j !borrow
+ | j == l = pure borrow
+ | otherwise = do
+ !x_j <- PA.readPrimArray x (j + x_offset)
+ let !y_j = PA.indexPrimArray y j
+ let !(P s carry1) = sub_b x_j borrow 0
+ !(P ph pl) = mul_c y_j m
+ !(P t carry2) = sub_b s pl 0
+ PA.writePrimArray x (j + x_offset) t
+ loop (succ j) (ph + carry1 + carry2)
+ loop 0 0
+
+-- -- quotrem of dividend divided by word
+-- quotrem_by1
+-- :: PrimMonad m
+-- => Memory m -- memory
+-- -> Int -- normalized dividend offset
+-- -> Int -- length of normalized dividend
+-- -> Int -- normalized divisor offset
+-- -> m Word64 -- remainder
+-- quotrem_by1 (Memory buf) un_offset lun dn_offset = do
+-- d <- PA.readPrimArray buf dn_offset
+-- let rec = recip_2by1 d
+-- r0 <- PA.readPrimArray buf (un_offset + lun - 1)
+-- let loop !j !racc
+-- | j < 0 = pure racc
+-- | otherwise = do
+-- !uj <- PA.readPrimArray buf (un_offset + j)
+-- let !(P qj rnex) = quotrem_2by1 racc uj d rec
+-- PA.writePrimArray buf j qj
+-- loop (j - 1) rnex
+-- loop (lun - 2) r0
+
+-- quotrem_by1
+-- :: PrimMonad m
+-- => PA.MutablePrimArray (PrimState m) Word64
+-- -> PA.PrimArray Word64
+-- -> Word64
+-- -> m Word64
+-- quotrem_by1 quo u d = do
+-- let !rec = recip_2by1 d
+-- !lu = PA.sizeofPrimArray u
+-- !r0 = PA.indexPrimArray u (lu - 1)
+-- loop !j !racc
+-- | j < 0 = pure racc
+-- | otherwise = do
+-- let uj = PA.indexPrimArray u j
+-- !(P qj rnex) = quotrem_2by1 racc uj d rec
+-- PA.writePrimArray quo j qj
+-- loop (pred j) rnex
+-- loop (lu - 2) r0
+--
+-- add_to
+-- :: PrimMonad m
+-- => PA.MutablePrimArray (PrimState m) Word64
+-- -> Int
+-- -> Word256
+-- -> Int
+-- -> m Word64
+-- add_to x x_offset (Word256 y0 y1 y2 y3) l = do
+-- let loop !j !cacc
+-- | j == l = pure cacc
+-- | otherwise = do
+-- xj <- PA.readPrimArray x (j + x_offset)
+-- let !(P nex carry) = case j of
+-- 0 -> add_c xj y0 cacc
+-- 1 -> add_c xj y1 cacc
+-- 2 -> add_c xj y2 cacc
+-- 3 -> add_c xj y3 cacc
+-- _ -> error "ppad-fixed (add_to): bad index"
+-- PA.writePrimArray x (j + x_offset) nex
+-- loop (succ j) carry
+-- loop 0 0
--
-- quotrem_knuth
-- :: PrimMonad m
@@ -828,27 +843,6 @@ sub_mul x x_offset y l m = do
-- PA.writePrimArray quo j qhat
-- loop (pred j)
-- loop (lu - ld - 1)
-
---
--- recip_2by1 :: Word64 -> Word64
--- recip_2by1 d = r where
--- !(P r _) = quotrem_r (B.complement d) 0xffffffffffffffff d
---
--- quotrem_2by1 :: Word64 -> Word64 -> Word64 -> Word64 -> Word128
--- quotrem_2by1 uh ul d rec =
--- let !(P qh_0 ql) = mul_c rec uh
--- !(P ql_0 c) = add_c ql ul 0
--- !(P (succ -> qh_1) _) = add_c qh_0 uh c
--- !r = ul - qh_1 * d
---
--- !(P qh_y r_y)
--- | r > ql_0 = P (qh_1 - 1) (r + d)
--- | otherwise = P qh_1 r
---
--- in if r_y >= d
--- then P (qh_y + 1) (r_y - d)
--- else P qh_y r_y
---
--
-- div :: Word256 -> Word256 -> Word256
-- div u@(Word256 u0 u1 u2 u3) d@(Word256 d0 _ _ _)
diff --git a/ppad-fixed.cabal b/ppad-fixed.cabal
@@ -42,6 +42,7 @@ test-suite fixed-tests
base
, bytestring
, ppad-fixed
+ , primitive
, tasty
, tasty-hunit
, tasty-quickcheck
diff --git a/test/Main.hs b/test/Main.hs
@@ -8,6 +8,7 @@ module Main where
import Data.Bits ((.|.), (.&.), (.>>.), (.^.))
import qualified Data.Bits as B
+import qualified Data.Primitive.PrimArray as PA
import Data.Word.Extended
import GHC.Exts
import GHC.Word
@@ -115,6 +116,18 @@ quotrem_2by1_case0 = do
!o = quotrem_2by1 8 4 d (recip_2by1 d)
H.assertEqual mempty (P 8 2052) o
+quotrem_by1_case0 :: H.Assertion
+quotrem_by1_case0 = do
+ qm <- PA.newPrimArray 2
+ PA.setPrimArray qm 0 2 0
+ let !u = PA.primArrayFromList [4, 8]
+ !d = B.complement 0xFF :: Word64
+ r <- quotrem_by1 qm u d
+ q <- PA.unsafeFreezePrimArray qm
+ H.assertEqual "quotient" (PA.primArrayFromList [8, 0]) q
+ H.assertEqual "remainder" 2052 r
+
+-- tests ----------------------------------------------------------------------
@@ -155,6 +168,7 @@ main = defaultMain $ testGroup "ppad-fixed" [
, H.testCase "recip_2by1 matches case0" recip_2by1_case0
, H.testCase "recip_2by1 matches case1" recip_2by1_case1
, H.testCase "quotrem_2by1 matches case0" quotrem_2by1_case0
+ , H.testCase "quotrem_by1 matches case0" quotrem_by1_case0
]
]