commit 75cdb8d4295292c8c5af41084f9771e02367e024
parent 75f67e89d47355dd979e84db5cd439951bf7394a
Author: Jared Tobin <jared@jtobin.io>
Date: Mon, 14 Jul 2025 12:03:10 -0230
lib: montgomery reduction initial stab
Diffstat:
6 files changed, 215 insertions(+), 8 deletions(-)
diff --git a/bench/Main.hs b/bench/Main.hs
@@ -48,10 +48,12 @@ division_utils = bgroup "division utilities" [
main :: IO ()
main = defaultMain [
- add_sub
- , multiplication
- , division
- , division_utils
+ bgroup "extended" [
+ add_sub
+ , multiplication
+ , division
+ , division_utils
+ ]
, Wide.benches
, Limb.benches
]
diff --git a/lib/Data/Word/Limb.hs b/lib/Data/Word/Limb.hs
@@ -11,8 +11,10 @@ module Data.Word.Limb (
add_c#
, sub_b#
, mul_c#
+ , mul_w#
, recip#
, quot#
+ , mul_add_c#
-- * Reciprocal
, Reciprocal(..)
@@ -56,6 +58,37 @@ mul_c# a b =
in (# l, h #)
{-# INLINE mul_c# #-}
+-- wrapping multiplication
+mul_w# :: Word# -> Word# -> Word#
+mul_w# a b =
+ let !(# _, l #) = timesWord2# a b
+ in l
+{-# INLINE mul_w# #-}
+
+mul_add_c# :: Word# -> Word# -> Word# -> Word# -> (# Word#, Word# #)
+mul_add_c# lhs rhs addend carry =
+ let !(# l_0, h_0 #) = add_w# (mul_c# lhs rhs) (# addend, 0## #)
+ !(# l_1, c #) = add_c# l_0 carry 0##
+ !h_1 = plusWord# h_0 c
+ in (# l_1, h_1 #)
+ where
+ -- duplicated w/Data.Word.Wide to avoid awkward module structuring
+ add_wc#
+ :: (# Word#, Word# #)
+ -> (# Word#, Word# #)
+ -> (# Word#, Word#, Word# #)
+ add_wc# (# a0, a1 #) (# b0, b1 #) =
+ let !(# s0, c0 #) = add_c# a0 b0 0##
+ !(# s1, c1 #) = add_c# a1 b1 c0
+ in (# s0, s1, c1 #)
+ {-# INLINE add_wc# #-}
+
+ add_w# :: (# Word#, Word# #) -> (# Word#, Word# #) -> (# Word#, Word# #)
+ add_w# a b =
+ let !(# c0, c1, _ #) = add_wc# a b
+ in (# c0, c1 #)
+ {-# INLINE add_w# #-}
+
-- division -------------------------------------------------------------------
-- normalized divisor, shift, reciprocal
diff --git a/lib/Data/Word/Montgomery.hs b/lib/Data/Word/Montgomery.hs
@@ -0,0 +1,86 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE UnboxedSums #-}
+{-# LANGUAGE UnboxedTuples #-}
+{-# LANGUAGE UnliftedNewtypes #-}
+
+-- XX this should probably be its own library
+
+module Data.Word.Montgomery where
+
+import Control.DeepSeq
+import qualified Data.Choice as C
+import Data.Bits ((.|.), (.&.), (.<<.), (.>>.))
+import qualified Data.Bits as B
+import qualified Data.Word.Limb as L
+import qualified Data.Word.Wider as L
+import GHC.Exts
+import Prelude hiding (div, mod, or, and, not, quot, rem, recip)
+
+-- XX my eyes, it burns
+
+redc_inner#
+ :: (# Word#, Word#, Word#, Word# #) -- upper
+ -> (# Word#, Word#, Word#, Word# #) -- lower
+ -> (# Word#, Word#, Word#, Word# #) -- modulus
+ -> Word#
+ -> (# (# Word#, Word#, Word#, Word# #), Word# #) -- upper, meta-carry
+redc_inner#
+ (# u0, u1, u2, u3 #)
+ (# l0, l1, l2, l3 #)
+ (# m0, m1, m2, m3 #)
+ mninv =
+ let -- outer loop, i == 0 ---------------------------------------------------
+ !w_0 = L.mul_w# l0 mninv
+ !(# _, c_00 #) = L.mul_add_c# w_0 m0 l0 0## -- m0, l0
+ -- first inner loop (j < 4)
+ !(# l0_1, c_01 #) = L.mul_add_c# w_0 m1 l1 c_00 -- l<i idx>_<j idx>
+ !(# l0_2, c_02 #) = L.mul_add_c# w_0 m2 l2 c_01
+ !(# l0_3, c_03 #) = L.mul_add_c# w_0 m3 l3 c_02
+ -- final stanza
+ !(# u_0, mc_0 #) = L.add_c# u0 c_03 0##
+ -- end states
+ -- (# l0, l0_1, l0_2, l0_3 #)
+ -- (# u_0, u1, u2, u3 #)
+ -- outer loop, i == 1 ---------------------------------------------------
+ !w_1 = L.mul_w# l0_1 mninv
+ !(# _, c_10 #) = L.mul_add_c# w_1 m0 l0_1 0##
+ -- first inner loop (j < 3)
+ !(# l1_1, c_11 #) = L.mul_add_c# w_1 m1 l0_2 c_10 -- j == 1
+ !(# l1_2, c_12 #) = L.mul_add_c# w_1 m2 l0_3 c_11 -- j == 2
+ -- second inner loop (j < 4)
+ !(# u1_3, c_13 #) = L.mul_add_c# w_1 m3 u_0 c_12 -- j == 3
+ -- final stanza
+ !(# u_1, mc_1 #) = L.add_c# u1 c_13 mc_0
+ -- end states
+ -- (# l0, l0_1, l1_1, l1_2 #)
+ -- (# u1_3, u_1, u2, u3 #)
+ -- outer loop, i == 2 ---------------------------------------------------
+ !w_2 = L.mul_w# l1_1 mninv
+ !(# _, c_20 #) = L.mul_add_c# w_2 m0 l1_1 0##
+ -- first inner loop (j < 2)
+ !(# l2_1, c_21 #) = L.mul_add_c# w_2 m1 l1_2 c_20 -- j == 1
+ -- second inner loop (j < 4)
+ !(# u2_2, c_22 #) = L.mul_add_c# w_2 m2 u1_3 c_21 -- j == 2
+ !(# u2_3, c_23 #) = L.mul_add_c# w_2 m3 u_1 c_22 -- j == 3
+ -- final stanza
+ !(# u_2, mc_2 #) = L.add_c# u2 c_23 mc_1
+ -- end states
+ -- (# l0, l0_1, l1_1, l2_1 #)
+ -- (# u2_2, u2_3, u_2, u3 #)
+ -- outer loop, i == 3 ---------------------------------------------------
+ !w_3 = L.mul_w# l2_1 mninv
+ !(# _, c_30 #) = L.mul_add_c# w_3 m0 l2_1 0##
+ -- second inner loop (j < 4)
+ !(# u3_1, c_31 #) = L.mul_add_c# w_3 m1 u2_2 c_30 -- j == 1
+ !(# u3_2, c_32 #) = L.mul_add_c# w_3 m2 u2_3 c_31 -- j == 2
+ !(# u3_3, c_33 #) = L.mul_add_c# w_3 m3 u_2 c_32 -- j == 3
+ -- final stanza
+ !(# u_3, mc_3 #) = L.add_c# u3 c_33 mc_2
+ -- end states
+ -- (# l0, l0_1, l1_1, l2_1 #)
+ -- (# u3_1, u3_2, u3_3, u_3 #)
+ in (# (# u3_1, u3_2, u3_3, u_3 #), mc_3 #)
+
diff --git a/lib/Data/Word/Wide.hs b/lib/Data/Word/Wide.hs
@@ -177,12 +177,12 @@ shr_of# (# l, h #) s =
!shift = remWord# (int2Word# s) (int2Word# wide_size)
loop !j !res
| isTrue# (j <# shift_bits) =
- let !bit = C.from_word_lsb#
+ let !bit = C.from_word_lsb# -- XX not inlined
(and# (uncheckedShiftRL# shift j) 1##)
- !nres = C.ct_select_wide#
+ !nres = C.ct_select_wide# -- XX
res
- (C.expect_wide#
- (shr_of_vartime#
+ (C.expect_wide# -- XX
+ (shr_of_vartime# -- XX
res
(word2Int# (uncheckedShiftL# 1## j)))
"shift within range")
diff --git a/lib/Data/Word/Wider.hs b/lib/Data/Word/Wider.hs
@@ -0,0 +1,84 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE NumericUnderscores #-}
+{-# LANGUAGE ViewPatterns #-}
+{-# LANGUAGE UnboxedSums #-}
+{-# LANGUAGE UnboxedTuples #-}
+{-# LANGUAGE UnliftedNewtypes #-}
+
+module Data.Word.Wider where
+
+import Control.DeepSeq
+import qualified Data.Choice as C
+import Data.Bits ((.|.), (.&.), (.<<.), (.>>.))
+import qualified Data.Bits as B
+import qualified Data.Word.Limb as L
+import GHC.Exts
+import Prelude hiding (div, mod, or, and, not, quot, rem, recip)
+
+-- utilities ------------------------------------------------------------------
+
+fi :: (Integral a, Num b) => a -> b
+fi = fromIntegral
+{-# INLINE fi #-}
+
+-- wide words -----------------------------------------------------------------
+
+-- little-endian, i.e. (# w0, w1, w2, w3 #)
+data Wider = Wider (# Word#, Word#, Word#, Word# #)
+
+instance Show Wider where
+ show (Wider (# a, b, c, d #)) =
+ "(" <> show (W# a) <> ", " <> show (W# b) <> ", "
+ <> show (W# c) <> ", " <> show (W# d) <> ")"
+
+instance Eq Wider where
+ Wider (# a0, b0, c0, d0 #) == Wider (# a1, b1, c1, d1 #) =
+ isTrue# (andI#
+ ((andI# (eqWord# a0 a1) (eqWord# b0 b1)))
+ ((andI# (eqWord# c0 c1) (eqWord# d0 d1))))
+
+instance NFData Wider where
+ rnf (Wider a) = case a of (# _, _, _, _ #) -> ()
+
+-- construction / conversion --------------------------------------------------
+
+-- construct from lo, hi
+wider :: Word -> Word -> Word -> Word -> Wider
+wider (W# w0) (W# w1) (W# w2) (W# w3) = Wider (# w0, w1, w2, w3 #)
+
+to :: Integer -> Wider
+to n =
+ let !size = B.finiteBitSize (0 :: Word)
+ !mask = fi (maxBound :: Word) :: Integer
+ !(W# w0) = fi (n .&. mask)
+ !(W# w1) = fi ((n .>>. size) .&. mask)
+ !(W# w2) = fi ((n .>>. (2 * size)) .&. mask)
+ !(W# w3) = fi ((n .>>. (3 * size)) .&. mask)
+ in Wider (# w0, w1, w2, w3 #)
+
+from :: Wider -> Integer
+from (Wider (# w0, w1, w2, w3 #)) =
+ fi (W# w3) .<<. (3 * size)
+ .|. fi (W# w2) .<<. (2 * size)
+ .|. fi (W# w1) .<<. size
+ .|. fi (W# w0)
+ where
+ !size = B.finiteBitSize (0 :: Word)
+
+-- subtract-with-overflow
+sub_of#
+ :: (# Word#, Word#, Word#, Word# #)
+ -> (# Word#, Word#, Word#, Word# #)
+ -> (# Word#, Word#, Word#, Word#, Word# #)
+sub_of# (# a0, a1, a2, a3 #)
+ (# b0, b1, b2, b3 #) =
+ let !(# s0, c0 #) = L.sub_b# a0 b0 0##
+ !(# s1, c1 #) = L.sub_b# a1 b1 c0
+ !(# s2, c2 #) = L.sub_b# a2 b2 c1
+ !(# s3, c3 #) = L.sub_b# a3 b3 c2
+ in (# s0, s1, s2, s3, c3 #)
+{-# INLINE sub_of# #-}
+
+
+
diff --git a/ppad-fixed.cabal b/ppad-fixed.cabal
@@ -27,6 +27,8 @@ library
, Data.Word.Extended
, Data.Word.Limb
, Data.Word.Wide
+ , Data.Word.Wider
+ , Data.Word.Montgomery
build-depends:
base >= 4.9 && < 5
, deepseq >= 1.5 && < 1.6