commit dd46cb2efd48fcfc0ceeef1684dbba0cda0b8699
parent 4efb694413988ff7322930e0ea88adf87e9173b7
Author: Jared Tobin <jared@jtobin.io>
Date: Fri, 31 Oct 2025 16:43:03 +0400
lib: sub_mod_c# fixes
Need bitwise and, and also to convert the borrow bit to a mask.
Diffstat:
2 files changed, 14 insertions(+), 17 deletions(-)
diff --git a/lib/Data/Word/Montgomery.hs b/lib/Data/Word/Montgomery.hs
@@ -30,10 +30,10 @@ redc_inner#
(# u0, u1, u2, u3 #)
(# l0, l1, l2, l3 #)
(# m0, m1, m2, m3 #)
- mninv =
+ n =
let -- outer loop, i == 0 ---------------------------------------------------
- !w_0 = L.mul_w# l0 mninv
- !(# _, c_00 #) = L.mul_add_c# w_0 m0 l0 0## -- m0, l0
+ !w_0 = L.mul_w# l0 n
+ !(# _, c_00 #) = L.mul_add_c# w_0 m0 l0 0## -- m0, l0
-- first inner loop (j < 4)
!(# l0_1, c_01 #) = L.mul_add_c# w_0 m1 l1 c_00 -- l<i idx>_<j idx>
!(# l0_2, c_02 #) = L.mul_add_c# w_0 m2 l2 c_01
@@ -44,7 +44,7 @@ redc_inner#
-- (# l0, l0_1, l0_2, l0_3 #)
-- (# u_0, u1, u2, u3 #)
-- outer loop, i == 1 ---------------------------------------------------
- !w_1 = L.mul_w# l0_1 mninv
+ !w_1 = L.mul_w# l0_1 n
!(# _, c_10 #) = L.mul_add_c# w_1 m0 l0_1 0##
-- first inner loop (j < 3)
!(# l1_1, c_11 #) = L.mul_add_c# w_1 m1 l0_2 c_10 -- j == 1
@@ -57,7 +57,7 @@ redc_inner#
-- (# l0, l0_1, l1_1, l1_2 #)
-- (# u1_3, u_1, u2, u3 #)
-- outer loop, i == 2 ---------------------------------------------------
- !w_2 = L.mul_w# l1_1 mninv
+ !w_2 = L.mul_w# l1_1 n
!(# _, c_20 #) = L.mul_add_c# w_2 m0 l1_1 0##
-- first inner loop (j < 2)
!(# l2_1, c_21 #) = L.mul_add_c# w_2 m1 l1_2 c_20 -- j == 1
@@ -70,7 +70,7 @@ redc_inner#
-- (# l0, l0_1, l1_1, l2_1 #)
-- (# u2_2, u2_3, u_2, u3 #)
-- outer loop, i == 3 ---------------------------------------------------
- !w_3 = L.mul_w# l2_1 mninv
+ !w_3 = L.mul_w# l2_1 n
!(# _, c_30 #) = L.mul_add_c# w_3 m0 l2_1 0##
-- second inner loop (j < 4)
!(# u3_1, c_31 #) = L.mul_add_c# w_3 m1 u2_2 c_30 -- j == 1
@@ -111,9 +111,9 @@ retr_inner#
(# x0, x1, x2, x3 #)
(# o, p, q, r #)
(# m0, m1, m2, m3 #)
- mninv =
+ n =
let -- outer loop, i == 0 ---------------------------------------------------
- !u_0 = L.mul_w# (plusWord# o x0) mninv -- out state
+ !u_0 = L.mul_w# (plusWord# o x0) n -- out state
!(# _, o0 #) = L.mul_add_c# u_0 m0 x0 o -- o0, p, q, r
-- inner loop
!(# o0_1, p0_1 #) = L.mul_add_c# u_0 m1 p o0 -- o0_1, p0_1, q, r
@@ -121,7 +121,7 @@ retr_inner#
!(# q0_3, r0_3 #) = L.mul_add_c# u_0 m3 r q0_2 -- o0_1, p0_2, q0_3, r0_3
-- end state: (# o0_1, p0_2, q0_3, r0_3 #)
-- outer loop, i == 1 ---------------------------------------------------
- !u_1 = L.mul_w# (plusWord# o0_1 x1) mninv
+ !u_1 = L.mul_w# (plusWord# o0_1 x1) n
!(# _, o1 #) = L.mul_add_c# u_1 m0 x1 o0_1 -- o1, p0_2, q0_3, r0_3
-- inner loop
!(# o1_1, p1_1 #) = L.mul_add_c# u_1 m1 p0_2 o1 -- o1_1, p1_1, q0_3, r0_3
@@ -129,7 +129,7 @@ retr_inner#
!(# q1_3, r1_3 #) = L.mul_add_c# u_1 m3 r0_3 q1_2 -- o1_1, p1_2, q1_3, r1_3
-- end state: (# o1_1, p1_2, q1_3, r1_3 #)
-- outer loop, i == 2 ---------------------------------------------------
- !u_2 = L.mul_w# (plusWord# o1_1 x2) mninv
+ !u_2 = L.mul_w# (plusWord# o1_1 x2) n
!(# _, o2 #) = L.mul_add_c# u_2 m0 x2 o1_1 -- o2, p1_2, q1_3, r1_3
-- inner loop
!(# o2_1, p2_1 #) = L.mul_add_c# u_2 m1 p1_2 o2 -- o2_1, p2_1, q1_3, r1_3
@@ -137,7 +137,7 @@ retr_inner#
!(# q2_3, r2_3 #) = L.mul_add_c# u_2 m3 r1_3 q2_2 -- o2_1, p2_2, q2_3, r2_3
-- end state: (# o2_1, p2_2, q2_3, r2_3 #)
-- outer loop, i == 3 ---------------------------------------------------
- !u_3 = L.mul_w# (plusWord# o2_1 x3) mninv
+ !u_3 = L.mul_w# (plusWord# o2_1 x3) n
!(# _, o3 #) = L.mul_add_c# u_3 m0 x3 o2_1 -- o3, p2_2, q2_3, r2_3
-- inner loop
!(# o3_1, p3_1 #) = L.mul_add_c# u_3 m1 p2_2 o3 -- o3_1, p3_1, q2_3, r2_3
diff --git a/lib/Data/Word/Wider.hs b/lib/Data/Word/Wider.hs
@@ -110,12 +110,9 @@ sub_mod_c#
-> (# Word#, Word#, Word#, Word# #) -- p
-> (# Word#, Word#, Word#, Word# #)
sub_mod_c# a c b (# p0, p1, p2, p3 #) =
- let !(# o0, o1, o2, o3, borrow #) = sub_b# a b
- !mask = and# (not# (C.wrapping_neg# c)) borrow
- !band = (# plusWord# p0 mask
- , plusWord# p1 mask
- , plusWord# p2 mask
- , plusWord# p3 mask #)
+ let !(# o0, o1, o2, o3, bb #) = sub_b# a b
+ !mask = and# (not# (C.wrapping_neg# c)) (C.wrapping_neg# bb)
+ !band = (# and# p0 mask, and# p1 mask, and# p2 mask, and# p3 mask #)
in add_w# (# o0, o1, o2, o3 #) band
{-# INLINE sub_mod_c# #-}