commit 417434b294e22a5f2da9c6c607747522844901a4
parent dbfca99d6d99325c846ab35fbde595a34d355852
Author: Jared Tobin <jared@jtobin.io>
Date: Tue, 28 Jan 2025 10:55:15 +0400
lib: the unboxing will continue
Diffstat:
4 files changed, 111 insertions(+), 186 deletions(-)
diff --git a/bench/Main.hs b/bench/Main.hs
@@ -1,10 +1,12 @@
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE NumericUnderscores #-}
module Main where
import Criterion.Main
import Data.Bits ((.|.), (.&.), (.^.))
+import qualified Data.Bits as B
import qualified Data.Word.Extended as W
import Prelude hiding (or, and, div, mod)
import qualified Prelude (div)
@@ -19,7 +21,9 @@ multiplication = bgroup "multiplication" [
]
division = bgroup "division" [
- -- quotrem_r#
+ quotrem_r
+ , quot_r
+ , quotrem_2by1
]
main :: IO ()
@@ -27,6 +31,8 @@ main = defaultMain [
division
]
+-- addition and subtraction ---------------------------------------------------
+
add_baseline :: Benchmark
add_baseline = bench "add (baseline)" $ nf ((+) w0) w1 where
w0, w1 :: Integer
@@ -53,13 +59,13 @@ sub = bench "sub" $ nf (W.sub w0) w1 where
!w1 = W.to_word256
0x7fffffffffffffffffffffffffffffffffffffffffffffbfffffffffffffffed
+-- multiplication -------------------------------------------------------------
mul_baseline :: Benchmark
mul_baseline = bench "mul (baseline)" $ nf ((*) w0) w1 where
w0, w1 :: Integer
!w0 = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed
!w1 = 0x7fffffffffffffffffffffffffffffffffffffffffffffbfffffffffffffffed
--- XX overflows; unsure if valid comparison
mul :: Benchmark
mul = bench "mul" $ nf (W.mul w0) w1 where
!w0 = W.to_word256
@@ -67,6 +73,23 @@ mul = bench "mul" $ nf (W.mul w0) w1 where
!w1 = W.to_word256
0x7fffffffffffffffffffffffffffffffffffffffffffffbfffffffffffffffed
+-- division -------------------------------------------------------------------
+
+quotrem_r :: Benchmark
+quotrem_r = bench "quotrem_r" $
+ nf (W.quotrem_r 4 0xffffffffffffffff) (B.complement 4)
+
+quot_r :: Benchmark
+quot_r = bench "quot_r" $
+ nf (W.quot_r 4 0xffffffffffffffff) (B.complement 4)
+
+quotrem_2by1 :: Benchmark
+quotrem_2by1 = bench "quotrem_2by1" $
+ nf (W.quotrem_2by1 8 4 0xFFFF_FFFF_FFFF_FF00) r
+ where
+ !r = W.recip_2by1 0xFFFF_FFFF_FFFF_FF00
+
+
-- or_baseline :: Benchmark
-- or_baseline = bench "or (baseline)" $ nf ((.|.) w0) w1 where
-- w0, w1 :: Integer
diff --git a/bench/Weight.hs b/bench/Weight.hs
@@ -1,10 +1,12 @@
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PackageImports #-}
module Main where
+import qualified Data.Bits as B
import qualified Data.Word.Extended as E
import qualified Weigh as W
@@ -26,13 +28,16 @@ w3 = E.to_word256 i3
main :: IO ()
main = W.mainWith $ do
- W.func "add (baseline)" ((+) i0) i1
- W.func "add" (E.add w0) w1
- W.func "sub (baseline)" ((-) i0) i1
- W.func "sub" (E.sub w0) w1
- W.func "mul (baseline)" ((*) i0) i1
- W.func "mul" (E.mul w0) w1
- W.func "quotrem_r" (E.quotrem_r 2 4) 4
+ W.func "add (baseline)" ((+) i0) i1
+ W.func "add" (E.add w0) w1
+ W.func "sub (baseline)" ((-) i0) i1
+ W.func "sub" (E.sub w0) w1
+ W.func "mul (baseline)" ((*) i0) i1
+ W.func "mul" (E.mul w0) w1
+ W.func "quotrem_r" (E.quotrem_r 4 0xffffffffffffffff) (B.complement 4)
+ W.func "quotrem_2by1" (E.quotrem_2by1 8 4 0xFFFF_FFFF_FFFF_FF00) r
+ where
+ !r = E.recip_2by1 0xFFFF_FFFF_FFFF_FF00 -- reciprocal of the divisor passed to quotrem_2by1
-- main :: IO ()
-- main = W.mainWith $ do
diff --git a/lib/Data/Word/Extended.hs b/lib/Data/Word/Extended.hs
@@ -263,6 +263,8 @@ xor :: Word256 -> Word256 -> Word256
xor (Word256 a0 a1 a2 a3) (Word256 b0 b1 b2 b3) =
Word256 (a0 .^. b0) (a1 .^. b1) (a2 .^. b2) (a3 .^. b3)
+-- XX rename add_c, sub_b, mul_c to something that conveys hi/lo bits stuff
+
-- addition, subtraction ------------------------------------------------------
-- add-with-carry
@@ -273,11 +275,9 @@ xor (Word256 a0 a1 a2 a3) (Word256 b0 b1 b2 b3) =
-- ARM ADDS
-- ADC
add_c :: Word64 -> Word64 -> Word64 -> Word128
-add_c w64_0 w64_1 c =
- let !s = w64_0 + w64_1 + c
- !n | s < w64_0 || s < w64_1 = 1
- | otherwise = 0
- in P s n
+add_c (W64# a) (W64# b) (W64# c) =
+ let !(# s, n #) = add_c# a b c
+ in P (W64# s) (W64# n)
add_c# :: Word64# -> Word64# -> Word64# -> (# Word64#, Word64# #)
add_c# w64_0 w64_1 c =
@@ -285,15 +285,7 @@ add_c# w64_0 w64_1 c =
!n | isTrue# (orI# (ltWord64# s w64_0) (ltWord64# s w64_1)) = 1#
| otherwise = 0#
in (# s, wordToWord64# (int2Word# n) #)
-
--- addition with overflow indication
-add_of :: Word256 -> Word256 -> Word320
-add_of (Word256 a0 a1 a2 a3) (Word256 b0 b1 b2 b3) =
- let !(P s0 c0) = add_c a0 b0 0
- !(P s1 c1) = add_c a1 b1 c0
- !(P s2 c2) = add_c a2 b2 c1
- !(P s3 c3) = add_c a3 b3 c2
- in Word320 (Word256 s0 s1 s2 s3) c3
+{-# INLINE add_c# #-}
add_of#
:: (# Word64#, Word64#, Word64#, Word64# #)
@@ -306,6 +298,7 @@ add_of# (# a0, a1, a2, a3 #)
!(# s2, c2 #) = add_c# a2 b2 c1
!(# s3, c3 #) = add_c# a3 b3 c2
in (# s0, s1, s2, s3, c3 #)
+{-# INLINE add_of# #-}
-- | Addition on 'Word256' values, with overflow.
--
@@ -327,11 +320,9 @@ add (Word256 (W64# a0) (W64# a1) (W64# a2) (W64# a3))
-- ARM SUBS
-- SBC
sub_b :: Word64 -> Word64 -> Word64 -> Word128
-sub_b w64_0 w64_1 b =
- let !d = w64_0 - w64_1 - b
- !n | w64_0 < w64_1 + b = 1
- | otherwise = 0
- in P d n
+sub_b (W64# wa) (W64# wb) (W64# b) =
+ let !(# d, n #) = sub_b# wa wb b
+ in P (W64# d) (W64# n)
sub_b# :: Word64# -> Word64# -> Word64# -> (# Word64#, Word64# #)
sub_b# w64_0 w64_1 b =
@@ -339,14 +330,7 @@ sub_b# w64_0 w64_1 b =
!n | isTrue# (ltWord64# w64_0 (plusWord64# w64_1 b)) = wordToWord64# 1##
| otherwise = wordToWord64# 0##
in (# d, n #)
-
-sub_of :: Word256 -> Word256 -> Word320
-sub_of (Word256 a0 a1 a2 a3) (Word256 b0 b1 b2 b3) =
- let !(P s0 c0) = sub_b a0 b0 0
- !(P s1 c1) = sub_b a1 b1 c0
- !(P s2 c2) = sub_b a2 b2 c1
- !(P s3 c3) = sub_b a3 b3 c2
- in Word320 (Word256 s0 s1 s2 s3) c3
+{-# INLINE sub_b# #-}
sub_of#
:: (# Word64#, Word64#, Word64#, Word64# #)
@@ -359,6 +343,7 @@ sub_of# (# a0, a1, a2, a3 #)
!(# s2, c2 #) = sub_b# a2 b2 c1
!(# s3, c3 #) = sub_b# a3 b3 c2
in (# s0, s1, s2, s3, c3 #)
+{-# INLINE sub_of# #-}
-- | Subtraction on 'Word256' values.
--
@@ -381,22 +366,9 @@ sub (Word256 (W64# a0) (W64# a1) (W64# a2) (W64# a3))
--
-- translated from Mul64 in go's math/bits package
mul_c :: Word64 -> Word64 -> Word128
-mul_c x y =
- let !mask32 = 0xffffffff
- !x0 = x .&. mask32
- !y0 = y .&. mask32
- !x1 = x .>>. 32
- !y1 = y .>>. 32
-
- !w0 = x0 * y0
- !t = x1 * y0 + w0 .>>. 32
- !w1 = t .&. mask32
- !w2 = t .>>. 32
- !w1_1 = w1 + x0 * y1
-
- !hi = x1 * y1 + w2 + w1_1 .>>. 32
- !lo = x * y
- in P hi lo
+mul_c (W64# x) (W64# y) =
+ let !(# hi, lo #) = mul_c# x y
+ in P (W64# hi) (W64# lo)
mul_c# :: Word64# -> Word64# -> (# Word64#, Word64# #)
mul_c# x y =
@@ -417,14 +389,7 @@ mul_c# x y =
(plusWord64# w2 (uncheckedShiftRL64# w1_1 32#))
!lo = timesWord64# x y
in (# hi, lo #)
-
--- (hi * 2 ^ 64 + lo) = z + (x * y)
-umul_hop :: Word64 -> Word64 -> Word64 -> Word128
-umul_hop z x y =
- let !(P hi_0 lo_0) = mul_c x y
- !(P lo c) = add_c lo_0 z 0
- !(P hi _) = add_c hi_0 0 c
- in P hi lo
+{-# INLINE mul_c# #-}
umul_hop# :: Word64# -> Word64# -> Word64# -> (# Word64#, Word64# #)
umul_hop# z x y =
@@ -432,16 +397,7 @@ umul_hop# z x y =
!(# lo, c #) = add_c# lo_0 z (wordToWord64# 0##)
!(# hi, _ #) = add_c# hi_0 (wordToWord64# 0##) c
in (# hi, lo #)
-
--- (hi * 2 ^ 64 + lo) = z + (x * y) + c
-umul_step :: Word64 -> Word64 -> Word64 -> Word64 -> Word128
-umul_step z x y c =
- let !(P hi_0 lo_0) = mul_c x y
- !(P lo_1 c_0) = add_c lo_0 c 0
- !(P hi_1 _) = add_c hi_0 0 c_0
- !(P lo c_1) = add_c lo_1 z 0
- !(P hi _) = add_c hi_1 0 c_1
- in P hi lo
+{-# INLINE umul_hop# #-}
umul_step#
:: Word64#
@@ -456,6 +412,7 @@ umul_step# z x y c =
!(# lo, c_1 #) = add_c# lo_1 z (wordToWord64# 0##)
!(# hi, _ #) = add_c# hi_1 (wordToWord64# 0##) c_1
in (# hi, lo #)
+{-# INLINE umul_step# #-}
-- | Multiplication on 'Word256' values, with overflow.
--
@@ -489,77 +446,9 @@ instance PrimMonad m => NFData (Memory m)
--
-- x86-64 (RDX:RAX) DIV
quotrem_r :: Word64 -> Word64 -> Word64 -> Word128
-quotrem_r hi lo y_0
- | y_0 == 0 = error "ppad-fixed: division by zero"
- | y_0 <= hi = error "ppad-fixed: overflow"
- | hi == 0 = P (lo `quot` y_0) (lo `rem` y_0)
- | otherwise =
- let !s = B.countLeadingZeros y_0
- !y = y_0 .<<. s
-
- !yn1 = y .>>. 32
- !yn0 = y .&. mask32
- !un32 = (hi .<<. s) .|. (lo .>>. (64 - s))
- !un10 = lo .<<. s
- !un1 = un10 .>>. 32
- !un0 = un10 .&. mask32
- !q1 = un32 `quot` yn1
- !rhat = un32 - q1 * yn1
-
- !q1_l = q_loop q1 rhat yn0 yn1 un1
-
- !un21 = un32 * two32 + un1 - q1_l * y
- !q0 = un21 `quot` yn1
- !rhat_n = un21 - q0 * yn1
-
- !q0_l = q_loop q0 rhat_n yn0 yn1 un0
- in P
- (q1_l * two32 + q0_l)
- ((un21 * two32 + un0 - q0_l * y) .>>. s)
- where
- !two32 = 0x100000000
- !mask32 = 0x0ffffffff
-
- q_loop !q_acc !rhat_acc !yn_0 !yn_1 !un =
- let go !qa !rha
- | qa >= two32 || qa * yn_0 > two32 * rha + un =
- let !qn = qa - 1
- !rhn = rha + yn_1
- in if rhn >= two32
- then qn
- else go qn rhn
- | otherwise = qa
- in go q_acc rhat_acc
-
-quotrem_2by1#
- :: Word64# -> Word64# -> Word64# -> Word64# -> (# Word64#, Word64# #)
-quotrem_2by1# uh ul d rec =
- let !(# qh_0, ql #) = mul_c# rec uh
- !(# ql_0, c #) = add_c# ql ul (wordToWord64# 0##)
- !(# qh_1_l, _ #) = add_c# qh_0 uh c
- !qh_1 = plusWord64# qh_1_l (wordToWord64# 1##)
- !r = subWord64# ul (timesWord64# qh_1 d)
-
- !(# qh_y, r_y #)
- | isTrue# (geWord64# r ql_0) = (# qh_1_l, plusWord64# r d #)
- | otherwise = (# qh_1, r #)
-
- in if isTrue# (geWord64# r_y d)
- then (# plusWord64# qh_y (wordToWord64# 1##), subWord64# r_y d #)
- else (# qh_y, r_y #)
-
-recip_2by1' :: Word64 -> Word64
-recip_2by1' (W64# d) = W64# (recip_2by1# d)
-
-recip_2by1# :: Word64# -> Word64#
-recip_2by1# d =
- let !(# r, _ #) =
- quotrem_r# (not64# d) (wordToWord64# 0xffffffffffffffff##) d
- in r
-
-recip_2by1 :: Word64 -> Word64
-recip_2by1 d = r where
- !(P r _) = quotrem_r (B.complement d) 0xffffffffffffffff d
+quotrem_r (W64# hi) (W64# lo) (W64# y) =
+ let !(# q, r #) = quotrem_r# hi lo y
+ in P (W64# q) (W64# r)
quotrem_r# :: Word64# -> Word64# -> Word64# -> (# Word64#, Word64# #)
quotrem_r# hi lo y_0
@@ -622,12 +511,13 @@ quotrem_r# hi lo y_0
| otherwise = qa
in go# q_acc rhat_acc
{-# INLINE q_loop# #-}
+{-# INLINE quotrem_r# #-}
--- uses manually-unboxed internals
-quotrem_r' :: Word64 -> Word64 -> Word64 -> Word128
-quotrem_r' (W64# hi) (W64# lo) (W64# y_0) =
- let !(# q, r #) = quotrem_r# hi lo y_0
- in P (W64# q) (W64# r)
+-- same as quotrem_r, except only computes quotient
+quot_r :: Word64 -> Word64 -> Word64 -> Word64
+quot_r (W64# hi) (W64# lo) (W64# y) =
+ let !q = quot_r# hi lo y
+ in W64# q
quot_r# :: Word64# -> Word64# -> Word64# -> Word64#
quot_r# hi lo y_0
@@ -635,7 +525,8 @@ quot_r# hi lo y_0
error "ppad-fixed (quotrem_r): division by zero"
| isTrue# (leWord64# y_0 hi) =
error "ppad-fixed: overflow"
- | isTrue# (eqWord64# hi (wordToWord64# 0##)) = quotWord64# lo y_0
+ | isTrue# (eqWord64# hi (wordToWord64# 0##)) =
+ quotWord64# lo y_0
| otherwise =
let !s = int64ToInt# (word64ToInt64# (wordToWord64# (clz64# y_0)))
!y = uncheckedShiftL64# y_0 s
@@ -644,7 +535,9 @@ quot_r# hi lo y_0
!yn0 = and64# y mask32
!un32 = or64#
(uncheckedShiftL64# hi s)
- (uncheckedShiftRL64# lo (64# -# s))
+ (if (isTrue# (s ==# 0#))
+ then wordToWord64# 0##
+ else uncheckedShiftRL64# lo (64# -# s))
!un10 = uncheckedShiftL64# lo s
!un1 = uncheckedShiftRL64# un10 32#
!un0 = and64# un10 mask32
@@ -681,6 +574,37 @@ quot_r# hi lo y_0
| otherwise = qa
in go# q_acc rhat_acc
{-# INLINE q_loop# #-}
+{-# INLINE quot_r# #-}
+
+quotrem_2by1 :: Word64 -> Word64 -> Word64 -> Word64 -> Word128
+quotrem_2by1 (W64# uh) (W64# ul) (W64# d) (W64# rec) =
+ let !(# q, r #) = quotrem_2by1# uh ul d rec
+ in P (W64# q) (W64# r)
+
+quotrem_2by1#
+ :: Word64# -> Word64# -> Word64# -> Word64# -> (# Word64#, Word64# #)
+quotrem_2by1# uh ul d rec =
+ let !(# qh_0, ql #) = mul_c# rec uh
+ !(# ql_0, c #) = add_c# ql ul (wordToWord64# 0##)
+ !(# qh_1_l, _ #) = add_c# qh_0 uh c
+ !qh_1 = plusWord64# qh_1_l (wordToWord64# 1##)
+ !r = subWord64# ul (timesWord64# qh_1 d)
+
+ !(# qh_y, r_y #)
+ | isTrue# (geWord64# r ql_0) = (# qh_1_l, plusWord64# r d #)
+ | otherwise = (# qh_1, r #)
+
+ in if isTrue# (geWord64# r_y d)
+ then (# plusWord64# qh_y (wordToWord64# 1##), subWord64# r_y d #)
+ else (# qh_y, r_y #)
+{-# INLINE quotrem_2by1# #-}
+
+recip_2by1 :: Word64 -> Word64
+recip_2by1 (W64# d) = W64# (recip_2by1# d)
+
+recip_2by1# :: Word64# -> Word64#
+recip_2by1# d = quot_r# (not64# d) (wordToWord64# 0xffffffffffffffff##) d
+{-# INLINE recip_2by1# #-}
-- -- remainder by normalized word
-- rem_by_norm_word
diff --git a/test/Main.hs b/test/Main.hs
@@ -99,31 +99,6 @@ quotrem_r_case2 = do
let !(P q r) = quotrem_r 4 0xffffffffffffffff (B.complement 4)
H.assertEqual mempty (P 5 24) (P q r)
-quotrem_r_case0# :: H.Assertion
-quotrem_r_case0# = do
- let !(# q, r #) =
- quotrem_r# (wordToWord64# 2##) (wordToWord64# 4##) (wordToWord64# 4##)
- H.assertEqual mempty (P 9223372036854775809 0) (P (W64# q) (W64# r))
-
-quotrem_r_case1# :: H.Assertion
-quotrem_r_case1# = do
- let !(# q, r #) =
- quotrem_r# (wordToWord64# 0##) (wordToWord64# 4##) (wordToWord64# 2##)
- H.assertEqual mempty (P 2 0) (P (W64# q) (W64# r))
-
-quotrem_r_case2# :: H.Assertion
-quotrem_r_case2# = do
- let !(# q, r #) =
- quotrem_r#
- (wordToWord64# 4##)
- (wordToWord64# 0xffffffffffffffff##)
- (not64# (wordToWord64# 4##))
- H.assertEqual mempty (P 5 24) (P (W64# q) (W64# r))
-
--- recip_2by1 :: Word64 -> Word64
--- recip_2by1 d = r where
--- !(P r _) = quotrem_r (B.complement d) 0xffffffffffffffff d
-
recip_2by1_case0 :: H.Assertion
recip_2by1_case0 = do
let !q = recip_2by1 (B.complement 4)
@@ -134,6 +109,15 @@ recip_2by1_case1 = do
let !q = recip_2by1 (B.complement 0xff)
H.assertEqual mempty 256 q
+quotrem_2by1_case0 :: H.Assertion
+quotrem_2by1_case0 = do
+ let !d = B.complement 0xFF :: Word64
+ !o = quotrem_2by1 8 4 d (recip_2by1 d)
+ H.assertEqual mempty (P 8 2052) o
+
+
+
+
add_sub :: TestTree
add_sub = testGroup "addition & subtraction" [
Q.testProperty "addition matches (nonneg)" $
@@ -168,14 +152,9 @@ main = defaultMain $ testGroup "ppad-fixed" [
H.testCase "quotrem_r matches case0" quotrem_r_case0
, H.testCase "quotrem_r matches case1" quotrem_r_case1
, H.testCase "quotrem_r matches case2" quotrem_r_case2
- , H.testCase "quotrem_r# matches case0" quotrem_r_case0#
- , H.testCase "quotrem_r# matches case1" quotrem_r_case1#
- , H.testCase "quotrem_r# matches case2" quotrem_r_case2#
- -- , H.testCase "quotrem_r# matches case2" quotrem_r_case2
- -- , H.testCase "quotrem_r' matches case2" quotrem_r_case2'
- -- , H.testCase "quotrem_r_recip_case0 matches case0" quotrem_r_recip_case0
, H.testCase "recip_2by1 matches case0" recip_2by1_case0
, H.testCase "recip_2by1 matches case1" recip_2by1_case1
+ , H.testCase "quotrem_2by1 matches case0" quotrem_2by1_case0
]
]
@@ -280,13 +259,7 @@ main = defaultMain $ testGroup "ppad-fixed" [
-- recip_2by1_case1 = do
-- let !q = recip_2by1 (B.complement 0xff)
-- H.assertEqual mempty 256 q
---
--- quotrem_2by1_case0 :: H.Assertion
--- quotrem_2by1_case0 = do
--- let !d = B.complement 0xFF :: Word64
--- !o = quotrem_2by1 8 4 d (recip_2by1 d)
--- H.assertEqual mempty (P 8 2052) o
---
+
-- -- main -----------------------------------------------------------------------
--
-- comparison :: TestTree