commit b71b4c1a2430922f3debe4d8053e22082110ce1c
parent 5bc84d5ba1a0990fbdd9fdc843de04bd489ea86a
Author: Jared Tobin <jared@jtobin.io>
Date: Fri, 14 Nov 2025 16:18:28 +0400
lib: montgomery multiplication first stab
Diffstat:
2 files changed, 149 insertions(+), 11 deletions(-)
diff --git a/lib/Data/Word/Montgomery.hs b/lib/Data/Word/Montgomery.hs
@@ -9,10 +9,31 @@
module Data.Word.Montgomery where
import qualified Data.Word.Limb as L
-import qualified Data.Word.Wider as W
+import qualified Data.Word.Wide as W
+import Data.Word.Wider (Wider(..))
+import qualified Data.Word.Wider as WW
import GHC.Exts
import Prelude hiding (div, mod, or, and, not, quot, rem, recip)
+-- utilities ------------------------------------------------------------------
+
+-- Convert an unboxed word to an unboxed wide word.
+w :: Word# -> (# Word#, Word# #)
+w m = (# m, 0## #)
+{-# INLINE w #-}
+
+-- Truncate an unboxed wide word to an unboxed word.
+lo :: (# Word#, Word# #) -> Word#
+lo (# l, _ #) = l
+{-# INLINE lo #-}
+
+-- Capture the high bits of an unboxed wide word.
+hi :: (# Word#, Word# #) -> Word#
+hi (# _, h #) = h
+{-# INLINE hi #-}
+
+-- innards --------------------------------------------------------------------
+
redc_inner#
:: (# Word#, Word#, Word#, Word# #) -- ^ upper words
-> (# Word#, Word#, Word#, Word# #) -- ^ lower words
@@ -82,13 +103,13 @@ redc#
-> (# Word#, Word#, Word#, Word# #) -- ^ result
redc# l u m n =
let !(# nu, mc #) = redc_inner# u l m n
- in W.sub_mod_c# nu mc m m
+ in WW.sub_mod_c# nu mc m m
{-# INLINE redc# #-}
-redc :: W.Wider -> W.Wider -> W.Wider -> Word -> W.Wider
-redc (W.Wider l) (W.Wider u) (W.Wider m) (W# n) =
+redc :: Wider -> Wider -> Wider -> Word -> Wider
+redc (Wider l) (Wider u) (Wider m) (W# n) =
let !res = redc# l u m n
- in (W.Wider res)
+ in (Wider res)
retr_inner#
:: (# Word#, Word#, Word#, Word# #) -- ^ value in montgomery form
@@ -140,11 +161,119 @@ retr# f m n = retr_inner# f m n
{-# INLINE retr# #-}
retr
- :: W.Wider -- ^ value in montgomery form
- -> W.Wider -- ^ modulus
- -> Word -- ^ mod neg inv
- -> W.Wider -- ^ retrieved value
-retr (W.Wider f) (W.Wider m) (W# n) =
+ :: Wider -- ^ value in montgomery form
+ -> Wider -- ^ modulus
+ -> Word -- ^ mod neg inv
+ -> Wider -- ^ retrieved value
+retr (Wider f) (Wider m) (W# n) =
let !res = retr# f m n
- in (W.Wider res)
+ in (Wider res)
+
+-- | Montgomery multiplication.
+mul_inner#
+ :: (# Word#, Word#, Word#, Word# #) -- ^ x
+ -> (# Word#, Word#, Word#, Word# #) -- ^ y
+ -> (# Word#, Word#, Word#, Word# #) -- ^ modulus
+ -> Word# -- ^ mod neg inv
+ -> (# (# Word#, Word#, Word#, Word# #), Word# #) -- ^ product, meta-carry
+mul_inner# (# x0, x1, x2, x3 #) (# y0, y1, y2, y3 #) (# m0, m1, m2, m3 #) n =
+ let -- outer loop, i == 0 ---------------------------------------------------
+ !axy0 = W.mul_w# (w x0) (w y0) -- out state
+ !u0 = L.mul_w# (lo axy0) n -- 0, 0, 0, 0
+ !(# (# _, a0 #), c0 #) = W.add_c# (W.mul_w# (w u0) (w m0)) axy0
+ !carry0 = (# a0, c0 #)
+ -- inner loop, j == 1
+ !axy0_1 = W.mul_w# (w x0) (w y1)
+ !umc0_1 = W.add_w# (W.mul_w# (w u0) (w m1)) carry0
+ !(# (# o0, ab0_1 #), c0_1 #) = W.add_c# axy0_1 umc0_1 -- o0, 0, 0, 0
+ !carry0_1 = (# ab0_1, c0_1 #)
+ -- inner loop, j == 2
+ !axy0_2 = W.mul_w# (w x0) (w y2)
+ !umc0_2 = W.add_w# (W.mul_w# (w u0) (w m2)) carry0_1
+ !(# (# p0, ab0_2 #), c0_2 #) = W.add_c# axy0_2 umc0_2 -- o0, p0, 0, 0
+ !carry0_2 = (# ab0_2, c0_2 #)
+ -- inner loop, j == 3
+ !axy0_3 = W.mul_w# (w x0) (w y3)
+ !umc0_3 = W.add_w# (W.mul_w# (w u0) (w m3)) carry0_2
+ !(# (# q0, ab0_3 #), c0_3 #) = W.add_c# axy0_3 umc0_3 -- o0, p0, q0, 0
+ !carry0_3 = (# ab0_3, c0_3 #)
+ -- final stanza
+ !(# r0, mc0 #) = carry0_3 -- o0, p0, q0, r0
+ -- outer loop, i == 1 ---------------------------------------------------
+ !axy1 = W.add_w# (W.mul_w# (w x1) (w y0)) (w o0)
+ !u1 = L.mul_w# (lo axy1) n
+ !(# (# _, a1 #), c1 #) = W.add_c# (W.mul_w# (w u1) (w m0)) axy1
+ !carry1 = (# a1, c1 #)
+ -- inner loop, j == 1
+ !axy1_1 = W.add_w# (W.mul_w# (w x1) (w y1)) (w p0)
+ !umc1_1 = W.add_w# (W.mul_w# (w u1) (w m1)) carry1
+ !(# (# o1, ab1_1 #), c1_1 #) = W.add_c# axy1_1 umc1_1 -- o1, p0, q0, r0
+ !carry1_1 = (# ab1_1, c1_1 #)
+ -- inner loop, j == 2
+ !axy1_2 = W.add_w# (W.mul_w# (w x1) (w y2)) (w q0)
+ !umc1_2 = W.add_w# (W.mul_w# (w u1) (w m2)) carry1_1
+ !(# (# p1, ab1_2 #), c1_2 #) = W.add_c# axy1_2 umc1_2 -- o1, p1, q0, r0
+ !carry1_2 = (# ab1_2, c1_2 #)
+ -- inner loop, j == 3
+ !axy1_3 = W.add_w# (W.mul_w# (w x1) (w y3)) (w r0)
+ !umc1_3 = W.add_w# (W.mul_w# (w u1) (w m3)) carry1_2
+ !(# (# q1, ab1_3 #), c1_3 #) = W.add_c# axy1_3 umc1_3 -- o1, p1, q1, r0
+ !carry1_3 = (# ab1_3, c1_3 #)
+ -- final stanza
+ !(# r1, mc1 #) = W.add_w# carry1_3 (w mc0) -- o1, p1, q1, r1
+ -- outer loop, i == 2 ---------------------------------------------------
+ !axy2 = W.add_w# (W.mul_w# (w x2) (w y0)) (w o1)
+ !u2 = L.mul_w# (lo axy2) n
+ !(# (# _, a2 #), c2 #) = W.add_c# (W.mul_w# (w u2) (w m0)) axy2
+ !carry2 = (# a2, c2 #)
+ -- inner loop, j == 1
+ !axy2_1 = W.add_w# (W.mul_w# (w x2) (w y1)) (w p1)
+ !umc2_1 = W.add_w# (W.mul_w# (w u2) (w m1)) carry2
+ !(# (# o2, ab2_1 #), c2_1 #) = W.add_c# axy2_1 umc2_1 -- o2, p1, q1, r1
+ !carry2_1 = (# ab2_1, c2_1 #)
+ -- inner loop, j == 2
+ !axy2_2 = W.add_w# (W.mul_w# (w x2) (w y2)) (w q1)
+ !umc2_2 = W.add_w# (W.mul_w# (w u2) (w m2)) carry2_1
+ !(# (# p2, ab2_2 #), c2_2 #) = W.add_c# axy2_2 umc2_2 -- o2, p2, q1, r1
+ !carry2_2 = (# ab2_2, c2_2 #)
+ -- inner loop, j == 3
+ !axy2_3 = W.add_w# (W.mul_w# (w x2) (w y3)) (w r1)
+ !umc2_3 = W.add_w# (W.mul_w# (w u2) (w m3)) carry2_2
+ !(# (# q2, ab2_3 #), c2_3 #) = W.add_c# axy2_3 umc2_3 -- o2, p2, q2, r1
+ !carry2_3 = (# ab2_3, c2_3 #)
+ -- final stanza
+ !(# r2, mc2 #) = W.add_w# carry2_3 (w mc1) -- o2, p2, q2, r2
+ -- outer loop, i == 3 ---------------------------------------------------
+ !axy3 = W.add_w# (W.mul_w# (w x3) (w y0)) (w o2)
+ !u3 = L.mul_w# (lo axy3) n
+ !(# (# _, a3 #), c3 #) = W.add_c# (W.mul_w# (w u3) (w m0)) axy3
+ !carry3 = (# a3, c3 #)
+ -- inner loop, j == 1
+ !axy3_1 = W.add_w# (W.mul_w# (w x3) (w y1)) (w p2)
+ !umc3_1 = W.add_w# (W.mul_w# (w u3) (w m1)) carry3
+ !(# (# o3, ab3_1 #), c3_1 #) = W.add_c# axy3_1 umc3_1 -- o3, p2, q2, r2
+ !carry3_1 = (# ab3_1, c3_1 #)
+ -- inner loop, j == 2
+ !axy3_2 = W.add_w# (W.mul_w# (w x3) (w y2)) (w q2)
+ !umc3_2 = W.add_w# (W.mul_w# (w u3) (w m2)) carry3_1
+ !(# (# p3, ab3_2 #), c3_2 #) = W.add_c# axy3_2 umc3_2 -- o3, p3, q1, r2
+ !carry3_2 = (# ab3_2, c3_2 #)
+ -- inner loop, j == 3
+ !axy3_3 = W.add_w# (W.mul_w# (w x3) (w y3)) (w r2)
+ !umc3_3 = W.add_w# (W.mul_w# (w u3) (w m3)) carry3_2
+ !(# (# q3, ab3_3 #), c3_3 #) = W.add_c# axy3_3 umc3_3 -- o3, p3, q3, r2
+ !carry3_3 = (# ab3_3, c3_3 #)
+ -- final stanza
+ !(# r3, mc3 #) = W.add_w# carry3_3 (w mc2) -- o3, p3, q3, r3
+ in (# (# o3, p3, q3, r3 #), mc3 #)
+
+mul
+ :: Wider -- ^ lhs in montgomery form
+ -> Wider -- ^ rhs in montgomery form
+ -> Wider -- ^ modulus
+ -> Word -- ^ mod neg inv
+ -> Wider -- ^ montgomery product
+mul (Wider a) (Wider b) (Wider m) (W# n) =
+ let !(# nu, mc #) = mul_inner# a b m n
+ in Wider (WW.sub_mod_c# nu mc m m)
diff --git a/lib/Data/Word/Wide.hs b/lib/Data/Word/Wide.hs
@@ -20,7 +20,9 @@ module Data.Word.Wide (
, from
, lo#
+ , get_lo#
, hi#
+ , get_hi#
-- * Bit Manipulation
, or
@@ -33,6 +35,13 @@ module Data.Word.Wide (
, sub
, mul
+ -- * Unboxed Arithmetic
+ , add_c#
+ , add_w#
+ , sub_b#
+ , sub_w#
+ , mul_w#
+
, add_w#
, mul_w#
) where