commit 4efb694413988ff7322930e0ea88adf87e9173b7
parent 75cdb8d4295292c8c5af41084f9771e02367e024
Author: Jared Tobin <jared@jtobin.io>
Date: Fri, 31 Oct 2025 11:26:29 +0400
lib: montgomery retrieval first stab
Diffstat:
2 files changed, 125 insertions(+), 12 deletions(-)
diff --git a/lib/Data/Word/Montgomery.hs b/lib/Data/Word/Montgomery.hs
@@ -19,12 +19,11 @@ import qualified Data.Word.Wider as L
import GHC.Exts
import Prelude hiding (div, mod, or, and, not, quot, rem, recip)
--- XX my eyes, it burns
-
+-- reference 'montgomery_reduction_inner'
redc_inner#
- :: (# Word#, Word#, Word#, Word# #) -- upper
- -> (# Word#, Word#, Word#, Word# #) -- lower
- -> (# Word#, Word#, Word#, Word# #) -- modulus
+ :: (# Word#, Word#, Word#, Word# #) -- upper
+ -> (# Word#, Word#, Word#, Word# #) -- lower
+ -> (# Word#, Word#, Word#, Word# #) -- modulus
-> Word#
-> (# (# Word#, Word#, Word#, Word# #), Word# #) -- upper, meta-carry
redc_inner#
@@ -83,4 +82,81 @@ redc_inner#
-- (# l0, l0_1, l1_1, l2_1 #)
-- (# u3_1, u3_2, u3_3, u_3 #)
in (# (# u3_1, u3_2, u3_3, u_3 #), mc_3 #)
+{-# INLINE redc_inner# #-}
+
+redc#
+ :: (# Word#, Word#, Word#, Word# #) -- lower
+ -> (# Word#, Word#, Word#, Word# #) -- upper
+ -> (# Word#, Word#, Word#, Word# #) -- modulus
+ -> Word# -- mod neg inv
+ -> (# Word#, Word#, Word#, Word# #)
+redc# l u m n =
+ let !(# nu, mc #) = redc_inner# u l m n
+ in L.sub_mod_c# nu mc m m
+{-# INLINE redc# #-}
+
+redc :: L.Wider -> L.Wider -> L.Wider -> Word -> L.Wider
+redc (L.Wider l) (L.Wider u) (L.Wider m) (W# n) =
+ let !res = redc# l u m n
+ in (L.Wider res)
+
+-- reference 'montgomery_retrieve_inner'
+retr_inner#
+ :: (# Word#, Word#, Word#, Word# #) -- x
+ -> (# Word#, Word#, Word#, Word# #) -- out
+ -> (# Word#, Word#, Word#, Word# #) -- modulus
+ -> Word# -- mod neg inv
+ -> (# Word#, Word#, Word#, Word# #)
+retr_inner#
+ (# x0, x1, x2, x3 #)
+ (# o, p, q, r #)
+ (# m0, m1, m2, m3 #)
+ mninv =
+ let -- outer loop, i == 0 ---------------------------------------------------
+ !u_0 = L.mul_w# (plusWord# o x0) mninv -- out state
+ !(# _, o0 #) = L.mul_add_c# u_0 m0 x0 o -- o0, p, q, r
+ -- inner loop
+ !(# o0_1, p0_1 #) = L.mul_add_c# u_0 m1 p o0 -- o0_1, p0_1, q, r
+ !(# p0_2, q0_2 #) = L.mul_add_c# u_0 m2 q p0_1 -- o0_1, p0_2, q0_2, r
+ !(# q0_3, r0_3 #) = L.mul_add_c# u_0 m3 r q0_2 -- o0_1, p0_2, q0_3, r0_3
+ -- end state: (# o0_1, p0_2, q0_3, r0_3 #)
+ -- outer loop, i == 1 ---------------------------------------------------
+ !u_1 = L.mul_w# (plusWord# o0_1 x1) mninv
+ !(# _, o1 #) = L.mul_add_c# u_1 m0 x1 o0_1 -- o1, p0_2, q0_3, r0_3
+ -- inner loop
+ !(# o1_1, p1_1 #) = L.mul_add_c# u_1 m1 p0_2 o1 -- o1_1, p1_1, q0_3, r0_3
+ !(# p1_2, q1_2 #) = L.mul_add_c# u_1 m2 q0_3 p1_1 -- o1_1, p1_2, q1_2, r0_3
+ !(# q1_3, r1_3 #) = L.mul_add_c# u_1 m3 r0_3 q1_2 -- o1_1, p1_2, q1_3, r1_3
+ -- end state: (# o1_1, p1_2, q1_3, r1_3 #)
+ -- outer loop, i == 2 ---------------------------------------------------
+ !u_2 = L.mul_w# (plusWord# o1_1 x2) mninv
+ !(# _, o2 #) = L.mul_add_c# u_2 m0 x2 o1_1 -- o2, p1_2, q1_3, r1_3
+ -- inner loop
+ !(# o2_1, p2_1 #) = L.mul_add_c# u_2 m1 p1_2 o2 -- o2_1, p2_1, q1_3, r1_3
+ !(# p2_2, q2_2 #) = L.mul_add_c# u_2 m2 q1_3 p2_1 -- o2_1, p2_2, q2_2, r1_3
+ !(# q2_3, r2_3 #) = L.mul_add_c# u_2 m3 r1_3 q2_2 -- o2_1, p2_2, q2_3, r2_3
+ -- end state: (# o2_1, p2_2, q2_3, r2_3 #)
+ -- outer loop, i == 3 ---------------------------------------------------
+ !u_3 = L.mul_w# (plusWord# o2_1 x3) mninv
+ !(# _, o3 #) = L.mul_add_c# u_3 m0 x3 o2_1 -- o3, p2_2, q2_3, r2_3
+ -- inner loop
+ !(# o3_1, p3_1 #) = L.mul_add_c# u_3 m1 p2_2 o3 -- o3_1, p3_1, q2_3, r2_3
+ !(# p3_2, q3_2 #) = L.mul_add_c# u_3 m2 q2_3 p3_1 -- o3_1, p3_2, q3_2, r2_3
+ !(# q3_3, r3_3 #) = L.mul_add_c# u_3 m3 r2_3 q3_2 -- o3_1, p3_2, q3_3, r3_3
+ -- final state: (# o3_1, p3_2, q3_3, r3_3 #)
+ in (# o3_1, p3_2, q3_3, r3_3 #)
+{-# INLINE retr_inner# #-}
+
+retr#
+ :: (# Word#, Word#, Word#, Word# #) -- montgomery form
+ -> (# Word#, Word#, Word#, Word# #) -- modulus
+ -> Word# -- mod neg inv
+ -> (# Word#, Word#, Word#, Word# #)
+retr# f m n = retr_inner# f (# 0##, 0##, 0##, 0## #) m n
+{-# INLINE retr# #-}
+
+retr :: L.Wider -> L.Wider -> Word -> L.Wider
+retr (L.Wider f) (L.Wider m) (W# n) =
+ let !res = retr# f m n
+ in (L.Wider res)
diff --git a/lib/Data/Word/Wider.hs b/lib/Data/Word/Wider.hs
@@ -66,19 +66,56 @@ from (Wider (# w0, w1, w2, w3 #)) =
where
!size = B.finiteBitSize (0 :: Word)
--- subtract-with-overflow
-sub_of#
+-- wider-add-with-carry, i.e. (# sum, carry bit #)
+add_wc#
:: (# Word#, Word#, Word#, Word# #)
-> (# Word#, Word#, Word#, Word# #)
-> (# Word#, Word#, Word#, Word#, Word# #)
-sub_of# (# a0, a1, a2, a3 #)
- (# b0, b1, b2, b3 #) =
+add_wc# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) =
+ let !(# s0, c0 #) = L.add_c# a0 b0 0##
+ !(# s1, c1 #) = L.add_c# a1 b1 c0
+ !(# s2, c2 #) = L.add_c# a2 b2 c1
+ !(# s3, c3 #) = L.add_c# a3 b3 c2
+ in (# s0, s1, s2, s3, c3 #)
+{-# INLINE add_wc# #-}
+
+-- wider addition (wrapping)
+add_w#
+ :: (# Word#, Word#, Word#, Word# #)
+ -> (# Word#, Word#, Word#, Word# #)
+ -> (# Word#, Word#, Word#, Word# #)
+add_w# a b =
+ let !(# c0, c1, c2, c3, _ #) = add_wc# a b
+ in (# c0, c1, c2, c3 #)
+{-# INLINE add_w# #-}
+
+-- reference: borrowing_sub
+sub_b#
+ :: (# Word#, Word#, Word#, Word# #)
+ -> (# Word#, Word#, Word#, Word# #)
+ -> (# Word#, Word#, Word#, Word#, Word# #)
+sub_b# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) =
let !(# s0, c0 #) = L.sub_b# a0 b0 0##
!(# s1, c1 #) = L.sub_b# a1 b1 c0
!(# s2, c2 #) = L.sub_b# a2 b2 c1
!(# s3, c3 #) = L.sub_b# a3 b3 c2
in (# s0, s1, s2, s3, c3 #)
-{-# INLINE sub_of# #-}
-
-
+{-# INLINE sub_b# #-}
+
+-- reference sub_mod_with_carry
+sub_mod_c#
+ :: (# Word#, Word#, Word#, Word# #) -- lhs
+ -> Word# -- carry
+ -> (# Word#, Word#, Word#, Word# #) -- rhs
+ -> (# Word#, Word#, Word#, Word# #) -- p
+ -> (# Word#, Word#, Word#, Word# #)
+sub_mod_c# a c b (# p0, p1, p2, p3 #) =
+ let !(# o0, o1, o2, o3, borrow #) = sub_b# a b
+ !mask = and# (not# (C.wrapping_neg# c)) borrow
+ !band = (# plusWord# p0 mask
+ , plusWord# p1 mask
+ , plusWord# p2 mask
+ , plusWord# p3 mask #)
+ in add_w# (# o0, o1, o2, o3 #) band
+{-# INLINE sub_mod_c# #-}