commit f2c8f595c79569338bce841103950dfe7b051f28
parent 208a7d6e3b506a77dbc06227cdf610c34edc6aef
Author: Jared Tobin <jared@jtobin.io>
Date: Wed, 29 Jan 2025 17:21:22 +0400
lib: working, but broken, refactoring snapshot
Diffstat:
1 file changed, 93 insertions(+), 73 deletions(-)
diff --git a/lib/Data/Word/Extended.hs b/lib/Data/Word/Extended.hs
@@ -646,22 +646,13 @@ quotrem_knuth
:: PrimMonad m
=> PA.MutablePrimArray (PrimState m) Word64 -- quotient (potentially large)
-> PA.MutablePrimArray (PrimState m) Word64 -- normalized dividend
- -> Word256 -- normalized divisor
- -> Int -- words in normalized divisor
+ -> Int -- size of normalize dividend
+ -> PA.PrimArray Word64 -- normalized divisor
-> m ()
-quotrem_knuth quo u d@(Word256 d0 d1 d2 d3) ld = do
- !lu <- PA.getSizeofMutablePrimArray u
- darr <- PA.newPrimArray 4
- PA.writePrimArray darr 0 d0
- PA.writePrimArray darr 1 d1
- PA.writePrimArray darr 2 d2
- PA.writePrimArray darr 3 d3
- d_final <- PA.unsafeFreezePrimArray darr
- let (dh, dl) = case ld of
- 4 -> (d3, d2)
- 3 -> (d2, d1)
- 2 -> (d1, d0)
- _ -> error "ppad-fixed (quotrem_knuth): bad index"
+quotrem_knuth quo u ulen d = do
+ let !ld = PA.sizeofPrimArray d
+ !dh = PA.indexPrimArray d (ld - 1)
+ !dl = PA.indexPrimArray d (ld - 2)
!rec = recip_2by1 dh
loop !j
| j < 0 = pure ()
@@ -678,7 +669,7 @@ quotrem_knuth quo u d@(Word256 d0 d1 d2 d3) ld = do
then qh - 1
else qh
- !borrow <- sub_mul u j d_final ld qhat
+ !borrow <- sub_mul u j d ld qhat
PA.writePrimArray u (j + ld) (u2 - borrow)
if u2 < borrow
then do
@@ -689,73 +680,107 @@ quotrem_knuth quo u d@(Word256 d0 d1 d2 d3) ld = do
else
PA.writePrimArray quo j qhat
loop (pred j)
- loop (lu - ld - 1)
+ loop (ulen - ld - 1)
+
+normalize_divisor
+ :: PrimMonad m
+ => Word256
+ -> m (PA.PrimArray Word64, Int, Word64) -- XX more efficient
+normalize_divisor (Word256 d0 d1 d2 d3) = do
+ let (dlen, d_last, shift)
+ | d3 /= 0 = (4, d3, B.countLeadingZeros d3)
+ | d2 /= 0 = (3, d2, B.countLeadingZeros d2)
+ | d1 /= 0 = (2, d1, B.countLeadingZeros d1)
+ | d0 /= 0 = (1, d0, B.countLeadingZeros d0)
+ | otherwise = error "ppad-fixed (normalize): invalid 256-bit word"
+ dn <- PA.newPrimArray dlen
+ case dlen of
+ 4 -> do
+ PA.writePrimArray dn 3 d3
+ PA.writePrimArray dn 2 d2
+ PA.writePrimArray dn 1 d1
+ PA.writePrimArray dn 0 d0
+ 3 -> do
+ PA.writePrimArray dn 2 d2
+ PA.writePrimArray dn 1 d1
+ PA.writePrimArray dn 0 d0
+ 2 -> do
+ PA.writePrimArray dn 1 d1
+ PA.writePrimArray dn 0 d0
+ _ -> do
+ PA.writePrimArray dn 0 d0
+ let norm !j !dj
+ | j == 0 = do
+ let !dn_0 = dj .<<. shift
+ PA.writePrimArray dn 0 dn_0
+ pure dn_0
+ | otherwise = do
+ dj_1 <- PA.readPrimArray dn (j - 1)
+ PA.writePrimArray dn j $
+ (dj .<<. shift) .|. (dj_1 .>>. (64 - shift))
+ norm (j - 1) dj_1
+ dn_0 <- norm (dlen - 1) d_last
+ d_final <- PA.unsafeFreezePrimArray dn
+ pure (d_final, shift, dn_0)
quotrem
:: PrimMonad m
=> PA.MutablePrimArray (PrimState m) Word64 -- quotient (potentially large)
- -> PA.PrimArray Word64 -- dividend (potentially large)
- -> Word256 -- divisor (256-bit)
+ -> PA.MutablePrimArray (PrimState m) Word64 -- unnormalized dividend
+ -> Word256 -- unnormalized divisor
-> m (PA.PrimArray Word64) -- remainder (256-bit)
-quotrem quo u (Word256 d0 d1 d2 d3) = do
- let -- normalize divisor
- (dlen, shift)
- | d3 /= 0 = (4, B.countLeadingZeros d3)
- | d2 /= 0 = (3, B.countLeadingZeros d2)
- | d1 /= 0 = (2, B.countLeadingZeros d1)
- | otherwise = (1, B.countLeadingZeros d0) -- zero not checked
- dn_3 = (d3 .<<. shift) .|. (d2 .>>. (64 - shift))
- dn_2 = (d2 .<<. shift) .|. (d1 .>>. (64 - shift))
- dn_1 = (d1 .<<. shift) .|. (d0 .>>. (64 - shift))
- dn_0 = d0 .<<. shift
- !dn = Word256 dn_0 dn_1 dn_2 dn_3
- -- get size of normalized dividend
- lu = PA.sizeofPrimArray u
- ulen = let loop !j
- | j < 0 = 0
- | PA.indexPrimArray u j /= 0 = j + 1
- | otherwise = loop (j - 1)
- in loop (lu - 1)
+quotrem quo u d = do
+ -- normalize divisor
+ !(dn, shift, dn_0) <- normalize_divisor d
+ let !dlen = PA.sizeofPrimArray dn
+
+ -- get size of normalized dividend
+ !lu <- PA.getSizeofMutablePrimArray u
+ !ulen <- let loop !j
+ | j < 0 = pure 0
+ | otherwise = do
+ uj <- PA.readPrimArray u j
+ if uj /= 0 then pure (j + 1) else loop (j - 1)
+ in loop (lu - 2) -- don't touch the uninitialized word
if ulen < dlen
then do
-- u always has size at least 4
!r <- PA.newPrimArray 4
- PA.copyPrimArray r 0 u 0 4
+ PA.copyMutablePrimArray r 0 u 0 4
PA.unsafeFreezePrimArray r
else do
-- normalize dividend
- !un <- PA.newPrimArray (ulen + 1)
- let u_hi = PA.indexPrimArray u (ulen - 1)
- PA.writePrimArray un ulen (u_hi .>>. (64 - shift))
+ u_hi <- PA.readPrimArray u (ulen - 1)
+ PA.writePrimArray u ulen (u_hi .>>. (64 - shift))
let normalize_u !j !uj
| j == 0 =
- PA.writePrimArray un 0 (PA.indexPrimArray u 0 .<<. shift)
+ PA.writePrimArray u 0 (uj .<<. shift)
| otherwise = do
- let !uj_1 = PA.indexPrimArray u (j - 1)
- !val = (uj .<<. shift) .|. (uj_1 .>>. (64 - shift))
- PA.writePrimArray un j val
- normalize_u (pred j) uj_1
+ !uj_1 <- PA.readPrimArray u (j - 1)
+ PA.writePrimArray u j $
+ (uj .<<. shift) .|. (uj_1 .>>. (64 - shift))
+ normalize_u (j - 1) uj_1
normalize_u (ulen - 1) u_hi
if dlen == 1
then do
-- normalized divisor is small
- !un_final <- PA.unsafeFreezePrimArray un
- !r <- quotrem_by1 quo un_final dn_0
+ !un <- PA.unsafeFreezePrimArray u
+ !r <- quotrem_by1 quo un dn_0
pure $ PA.primArrayFromList [r .>>. shift, 0, 0, 0] -- XX
else do
- quotrem_knuth quo un dn dlen
+ quotrem_knuth quo u (ulen + 1) dn
-- unnormalize remainder
let unn_rem !j !un_j
| j == dlen = do
- PA.writePrimArray un (j - 1) (un_j .>>. shift)
- PA.unsafeFreezePrimArray un
+ PA.writePrimArray u (j - 1) (un_j .>>. shift)
+ PA.unsafeFreezePrimArray u
| otherwise = do
- !un_j_1 <- PA.readPrimArray un (j + 1)
+ !un_j_1 <- PA.readPrimArray u (j + 1)
let !unn_j = (un_j .>>. shift) .|. (un_j_1 .<<. (64 - shift))
- PA.writePrimArray un j unn_j
+ PA.writePrimArray u j unn_j
unn_rem (j + 1) un_j_1
- !un_0 <- PA.readPrimArray un 0
+ !un_0 <- PA.readPrimArray u 0
unn_rem 0 un_0
-- x =- y * m
@@ -785,20 +810,16 @@ add_to
:: PrimMonad m
=> PA.MutablePrimArray (PrimState m) Word64
-> Int
- -> Word256
+ -> PA.PrimArray Word64
-> Int
-> m Word64
-add_to x x_offset (Word256 y0 y1 y2 y3) l = do
+add_to x x_offset y l = do
let loop !j !cacc
| j == l = pure cacc
| otherwise = do
xj <- PA.readPrimArray x (j + x_offset)
- let !(P nex carry) = case j of
- 0 -> add_c xj y0 cacc
- 1 -> add_c xj y1 cacc
- 2 -> add_c xj y2 cacc
- 3 -> add_c xj y3 cacc
- _ -> error "ppad-fixed (add_to): bad index"
+ let yj = PA.indexPrimArray y j
+ !(P nex carry) = add_c xj yj cacc
PA.writePrimArray x (j + x_offset) nex
loop (succ j) carry
loop 0 0
@@ -818,8 +839,7 @@ div u@(Word256 u0 u1 u2 u3) d@(Word256 d0 _ _ _)
PA.writePrimArray u_arr 1 u1
PA.writePrimArray u_arr 2 u2
PA.writePrimArray u_arr 3 u3
- u_final <- PA.unsafeFreezePrimArray u_arr
- _ <- quotrem quo u_final d
+ _ <- quotrem quo u_arr d
q0 <- PA.readPrimArray quo 0
q1 <- PA.readPrimArray quo 1
q2 <- PA.readPrimArray quo 2
@@ -839,14 +859,14 @@ mod u@(Word256 u0 u1 u2 u3) d@(Word256 d0 _ _ _)
-- allocate quotient
quo <- PA.newPrimArray 4
PA.setPrimArray quo 0 4 0
- -- allocate dividend
- u_arr <- PA.newPrimArray 4
- PA.writePrimArray u_arr 0 u0
- PA.writePrimArray u_arr 1 u1
- PA.writePrimArray u_arr 2 u2
- PA.writePrimArray u_arr 3 u3
- u_final <- PA.unsafeFreezePrimArray u_arr
- r <- quotrem quo u_final d
+ -- allocate dividend, leaving enough space for normalization
+ u_hot <- PA.newPrimArray 5
+ PA.writePrimArray u_hot 0 u0
+ PA.writePrimArray u_hot 1 u1
+ PA.writePrimArray u_hot 2 u2
+ PA.writePrimArray u_hot 3 u3
+ -- last index of u_hot intentionally unset
+ r <- quotrem quo u_hot d
let r0 = PA.indexPrimArray r 0
r1 = PA.indexPrimArray r 1
r2 = PA.indexPrimArray r 2