commit b4eddf48496f541180b1b9aae4e96635b0de4247
parent a3f6aba9cad92e4fc8e59e4ba759a40631f41278
Author: Jared Tobin <jared@jtobin.io>
Date: Thu, 23 Jan 2025 09:34:26 +0400
lib: division preliminaries
Diffstat:
2 files changed, 203 insertions(+), 10 deletions(-)
diff --git a/lib/Data/Word/Extended.hs b/lib/Data/Word/Extended.hs
@@ -5,8 +5,11 @@
module Data.Word.Extended where
+import Control.Monad.Primitive
+import Control.Monad.ST
import Data.Bits ((.|.), (.&.), (.<<.), (.>>.), (.^.))
import qualified Data.Bits as B
+import qualified Data.Primitive.PrimArray as PA
import Data.Word (Word64)
import GHC.Generics
@@ -84,6 +87,19 @@ to_word512 n =
!w7 = fi ((n .>>. 448) .&. mask64)
in Word512 w0 w1 w2 w3 w4 w5 w6 w7
+-- comparison -----------------------------------------------------------------
+
+-- strict less-than: subtracts the second word from the first limb by
+-- limb (in constructor order), threading the borrow, and reports
+-- whether a borrow is left over at the end.
+-- NOTE(review): assumes sub_b x y b yields P difference borrow-out.
+lt :: Word256 -> Word256 -> Bool
+lt (Word256 x0 x1 x2 x3) (Word256 y0 y1 y2 y3) = k3 /= 0
+  where
+    -- only the borrow out of each limb subtraction is of interest
+    !(P _ k0) = sub_b x0 y0 0
+    !(P _ k1) = sub_b x1 y1 k0
+    !(P _ k2) = sub_b x2 y2 k1
+    !(P _ k3) = sub_b x3 y3 k2
+
+-- strict greater-than, i.e. lt with its arguments swapped
+gt :: Word256 -> Word256 -> Bool
+gt = flip lt
+
-- bits -----------------------------------------------------------------------
or :: Word256 -> Word256 -> Word256
@@ -246,7 +262,6 @@ mul_512 (Word256 x0 x1 x2 x3) (Word256 y0 y1 y2 y3) =
-- division -------------------------------------------------------------------
--- XX make this work on variable-length x, y
-- sub_mul x y m = (x - y * m, rem)
sub_mul :: Word256 -> Word256 -> Word64 -> Word256WithOverflow
sub_mul (Word256 x0 x1 x2 x3) (Word256 y0 y1 y2 y3) m =
@@ -271,6 +286,48 @@ sub_mul (Word256 x0 x1 x2 x3) (Word256 y0 y1 y2 y3) m =
!b3 = ph3 + c5 + c6
in Word256WithOverflow (Word256 z0 z1 z2 z3) b3
+-- x -= y * m, in place
+--
+-- the limbs of x starting at x_offset are updated one at a time, for
+-- each limb of y; the value returned is the borrow left over after the
+-- last limb of y has been processed.
+--
+-- requires (len x - x_offset) >= len y > 0
+sub_mul_to
+  :: PrimMonad m
+  => PA.MutablePrimArray (PrimState m) Word64
+  -> Int
+  -> PA.PrimArray Word64
+  -> Word64
+  -> m Word64
+sub_mul_to x x_offset y m = go 0 0
+  where
+    len_y = PA.sizeofPrimArray y
+
+    go !i !bor
+      | i == len_y = pure bor
+      | otherwise = do
+          let !ix = i + x_offset
+          !x_i <- PA.readPrimArray x ix
+          -- first take off the incoming borrow, then the low product word
+          let !(P d0 k0) = sub_b x_i bor 0
+              !(P hi lo) = mul_c (PA.indexPrimArray y i) m
+              !(P d1 k1) = sub_b d0 lo 0
+          PA.writePrimArray x ix d1
+          -- the high product word and both borrows move on to the next limb
+          go (succ i) (hi + k0 + k1)
+
+-- x += y, in place
+--
+-- the limbs of x starting at x_offset have the limbs of y added to
+-- them one at a time; the value returned is the carry left over after
+-- the last limb of y has been processed.
+-- NOTE(review): assumes add_c a b c yields P sum carry-out.
+--
+-- requires (len x - x_offset) >= len y > 0
+add_to
+  :: PrimMonad m
+  => PA.MutablePrimArray (PrimState m) Word64
+  -> Int
+  -> PA.PrimArray Word64
+  -> m Word64
+add_to x x_offset y = do
+  let l = PA.sizeofPrimArray y
+      loop !j !cacc
+        | j == l = pure cacc
+        | otherwise = do
+            xj <- PA.readPrimArray x (j + x_offset)
+            let yj = PA.indexPrimArray y j
+                -- the carry fed in must be the accumulated one (cacc);
+                -- using the freshly bound carry here would make the
+                -- binding depend on its own result
+                !(P nex carry) = add_c xj yj cacc
+            PA.writePrimArray x (j + x_offset) nex
+            loop (succ j) carry
+  loop 0 0
+
-- quotient, remainder of (hi, lo) divided by y
-- translated from Div64 in go's math/bits package
--
@@ -336,13 +393,149 @@ quotrem_2by1 uh ul d rec =
then P (qh_y + 1) (r_y - d)
else P qh_y r_y
--- XX make this work on variable-length x, y (udivremBy1)
-quotrem_by1 :: Word256 -> Word64 -> Word256WithOverflow
-quotrem_by1 (Word256 u0 u1 u2 u3) d =
+-- quotrem_by1 :: Word256 -> Word64 -> Word256WithOverflow
+-- quotrem_by1 (Word256 u0 u1 u2 u3) d =
+-- let !rec = recip_2by1 d
+-- !r0 = u3
+-- !(P q2 r1) = quotrem_2by1 r0 u2 d rec
+-- !(P q1 r2) = quotrem_2by1 r1 u1 d rec
+-- !(P q0 r3) = quotrem_2by1 r2 u0 d rec
+-- in Word256WithOverflow (Word256 q0 q1 q2 0) r3
+
+-- quotient, remainder of a variable-length u divided by a single word d
+--
+-- the top limb of u (index len u - 1) is taken as the initial
+-- remainder; the remaining limbs are then consumed from index
+-- (len u - 2) down to 0, each step dividing (remainder, limb) by d with
+-- quotrem_2by1 and storing the resulting quotient limb in quo at the
+-- same index. the final remainder is returned.
+--
+-- requires len u > 0; quo must have room for indices 0 .. len u - 2
+quotrem_by1
+ :: PrimMonad m
+ => PA.MutablePrimArray (PrimState m) Word64
+ -> PA.PrimArray Word64
+ -> Word64
+ -> m Word64
+quotrem_by1 quo u d = do
let !rec = recip_2by1 d
- !r0 = u3
- !(P q2 r1) = quotrem_2by1 r0 u2 d rec
- !(P q1 r2) = quotrem_2by1 r1 u1 d rec
- !(P q0 r3) = quotrem_2by1 r2 u0 d rec
- in Word256WithOverflow (Word256 q0 q1 q2 0) r3
+ !lu = PA.sizeofPrimArray u
+ !r0 = PA.indexPrimArray u (lu - 1)
+ loop !j !racc
+ | j < 0 = pure racc
+ | otherwise = do
+ let uj = PA.indexPrimArray u j
+ !(P qj rnex) = quotrem_2by1 racc uj d rec
+ PA.writePrimArray quo j qj
+ loop (pred j) rnex
+ loop (lu - 2) r0
+
+-- long division of u by a multi-limb divisor d
+--
+-- quotient limbs are written to quo (indices len u - len d - 1 down to
+-- 0); u is overwritten as the division proceeds and holds whatever is
+-- left of the dividend when the loop finishes.
+--
+-- each step estimates a quotient limb (qhat) from the top limbs of the
+-- current window of u and the top two limbs of d, subtracts qhat * d
+-- from the window, and, if that subtracted too much, adds d back once
+-- and decrements the estimate.
+--
+-- requires len d >= 2 (d is indexed at len d - 2) and len u > len d.
+-- NOTE(review): presumably d is expected to be normalized (top bit of
+-- its top limb set), as the caller in this module arranges -- confirm.
+quotrem_knuth
+  :: PrimMonad m
+  => PA.MutablePrimArray (PrimState m) Word64
+  -> PA.MutablePrimArray (PrimState m) Word64
+  -> PA.PrimArray Word64
+  -> m ()
+quotrem_knuth quo u d = do
+  !lu <- PA.getSizeofMutablePrimArray u
+  let !ld  = PA.sizeofPrimArray d
+      !dh  = PA.indexPrimArray d (ld - 1)
+      !dl  = PA.indexPrimArray d (ld - 2)
+      !rec = recip_2by1 dh
+      loop !j
+        | j < 0 = pure ()
+        | otherwise = do
+            !u2 <- PA.readPrimArray u (j + ld)
+            !u1 <- PA.readPrimArray u (j + ld - 1)
+            !u0 <- PA.readPrimArray u (j + ld - 2)
+            let !qhat
+                  -- the 2-by-1 division would overflow; use the largest limb
+                  | u2 >= dh = 0xffffffffffffffff
+                  | otherwise =
+                      let !(P qh rh) = quotrem_2by1 u2 u1 dh rec
+                          !(P ph pl) = mul_c qh dl
+                      in  if   ph > rh || (ph == rh && pl > u0)
+                          then qh - 1
+                          else qh
+
+            -- multiply and subtract
+            borrow <- sub_mul_to u j d qhat
+            let !top = u2 - borrow
+            if   u2 < borrow
+              then do
+                -- too much was subtracted: add d back once. the carry
+                -- out of the add-back must be added to the (wrapped) top
+                -- limb, not stored in place of it.
+                carry <- add_to u j d
+                PA.writePrimArray u (j + ld) (top + carry)
+                PA.writePrimArray quo j (qhat - 1)
+              else do
+                PA.writePrimArray u (j + ld) top
+                PA.writePrimArray quo j qhat
+            loop (pred j)
+  loop (lu - ld - 1)
+
+-- quotient and (optionally) remainder of u divided by d
+--
+-- u and d are limb arrays indexed from the least significant limb. the
+-- quotient limbs are written to quo; if mr is 'Just' an array, the
+-- remainder is written to it.
+--
+-- outline:
+--
+--   * dlen, ulen are the limb counts of d and u with high zero limbs
+--     dropped.
+--   * both operands are shifted left by the number of leading zero bits
+--     in d's top significant limb (dn, un); un gets one extra limb to
+--     hold the bits shifted out of u.
+--   * if u has fewer significant limbs than d, the quotient is left
+--     untouched and the remainder (if requested) is a copy of u.
+--   * a single-limb divisor goes through quotrem_by1, anything longer
+--     through quotrem_knuth; the remainder is shifted back right.
+--
+-- requires d to have at least one non-zero limb (d is indexed at
+-- dlen - 1). in the single-limb case the remainder array is assumed to
+-- have four limbs. limbs of quo and of the remainder that are not
+-- written here are left as the caller supplied them.
+quotrem
+  :: PrimMonad m
+  => PA.MutablePrimArray (PrimState m) Word64
+  -> PA.PrimArray Word64
+  -> PA.PrimArray Word64
+  -> Maybe (PA.MutablePrimArray (PrimState m) Word64)
+  -> m ()
+quotrem quo u d mr = do
+  let !ld    = PA.sizeofPrimArray d
+      !lu    = PA.sizeofPrimArray u
+      !dlen  = len_loop d (ld - 1)
+      !shift = B.countLeadingZeros (PA.indexPrimArray d (dlen - 1))
+  dn <- PA.newPrimArray dlen
+  -- limbs dlen - 1 .. 1 combine two neighbouring source limbs; limb 0
+  -- has no lower neighbour and is written separately below, so the loop
+  -- must stop before reaching it
+  let go_dn !j
+        | j <= 0 = pure ()
+        | otherwise = do
+            let !dj   = PA.indexPrimArray d j
+                !dj_1 = PA.indexPrimArray d (j - 1)
+                !val  = (dj .<<. shift) .|. (dj_1 .>>. (64 - shift))
+            PA.writePrimArray dn j val
+            go_dn (pred j)
+  go_dn (dlen - 1)
+  PA.writePrimArray dn 0 (PA.indexPrimArray d 0 .<<. shift)
+  let !ulen = len_loop u (lu - 1)
+  if   ulen < dlen
+    then case mr of
+      Nothing -> pure ()
+      Just !r -> PA.copyPrimArray r 0 u 0 lu
+    else do
+      un <- PA.newPrimArray (ulen + 1)
+      let u_ulen = PA.indexPrimArray u (ulen - 1)
+      PA.writePrimArray un ulen (u_ulen .>>. (64 - shift))
+      -- duplicated, but easy to handle mutableprimarrays this way
+      let go_un !j
+            | j <= 0 = pure ()
+            | otherwise = do
+                let !uj   = PA.indexPrimArray u j
+                    !uj_1 = PA.indexPrimArray u (j - 1)
+                    !val  = (uj .<<. shift) .|. (uj_1 .>>. (64 - shift))
+                PA.writePrimArray un j val
+                go_un (pred j)
+      go_un (ulen - 1)
+      PA.writePrimArray un 0 (PA.indexPrimArray u 0 .<<. shift)
+      if   dlen == 1
+        then do
+          dn_0 <- PA.readPrimArray dn 0
+          un_c <- PA.freezePrimArray un 0 (ulen + 1)
+          r <- quotrem_by1 quo un_c dn_0
+          case mr of
+            Nothing -> pure ()
+            Just !re -> do
+              -- the remainder is a single word: it belongs in the
+              -- least significant limb, with the limbs above it zeroed
+              PA.writePrimArray re 0 (r .>>. shift)
+              PA.writePrimArray re 1 0
+              PA.writePrimArray re 2 0
+              PA.writePrimArray re 3 0
+        else do
+          dnf <- PA.unsafeFreezePrimArray dn
+          quotrem_knuth quo un dnf
+          case mr of
+            Nothing -> pure ()
+            Just !r -> do
+              -- undo the normalization shift; the top remainder limb
+              -- (dlen - 1) is written separately after the loop
+              let go_r !j
+                    | j >= dlen - 1 = pure ()
+                    | otherwise = do
+                        un_j   <- PA.readPrimArray un j
+                        un_j_1 <- PA.readPrimArray un (j + 1)
+                        let !val = (un_j .>>. shift)
+                               .|. (un_j_1 .<<. (64 - shift))
+                        PA.writePrimArray r j val
+                        go_r (succ j)
+              go_r 0
+              un_dlen_1 <- PA.readPrimArray un (dlen - 1)
+              PA.writePrimArray r (dlen - 1) (un_dlen_1 .>>. shift)
+  where
+    -- number of limbs up to and including the highest non-zero one,
+    -- searching downwards from index j; 0 if every limb is zero
+    len_loop !arr !j
+      | j < 0 = 0
+      | PA.indexPrimArray arr j /= 0 = j + 1
+      | otherwise = len_loop arr (pred j)
diff --git a/ppad-fixed.cabal b/ppad-fixed.cabal
@@ -26,7 +26,7 @@ library
Data.Word.Extended
build-depends:
base >= 4.9 && < 5
- , primitive >= 0.8 && > 0.10
+ , primitive >= 0.8 && < 0.10
test-suite fixed-tests
type: exitcode-stdio-1.0