commit c1cf18ddc2d718f6d7fedefaa34e4b3d15261fb4
parent ed8b10400ae32927c139d8e63b8749489f6250e3
Author: Jared Tobin <jared@jtobin.io>
Date: Sun, 30 Nov 2025 16:02:39 +0400
lib: borrow handling fixes
These handled the previous bits returned by borrowing subtraction calls,
which have since been changed to masks.
Diffstat:
3 files changed, 101 insertions(+), 26 deletions(-)
diff --git a/lib/Data/Choice.hs b/lib/Data/Choice.hs
@@ -26,6 +26,7 @@ module Data.Choice (
, expect_wide_or#
-- * Construction
+ , from_word_mask#
, from_word_lsb#
, from_word_nonzero#
, from_word_eq#
@@ -129,7 +130,6 @@ true# _ = case maxBound :: Word of
W# w -> Choice w
{-# INLINE true# #-}
--- XX this is probably stupid. check
decide :: Choice -> Bool
decide (Choice c) = isTrue# (neWord# c 0##)
{-# INLINE decide #-}
@@ -182,6 +182,10 @@ expect_wide_or# (MaybeWide# (# w, Choice c #)) alt
-- construction ---------------------------------------------------------------
+from_word_mask# :: Word# -> Choice
+from_word_mask# w = Choice w
+{-# INLINE from_word_mask# #-}
+
from_word_lsb# :: Word# -> Choice
from_word_lsb# w = Choice (wrapping_neg# w)
{-# INLINE from_word_lsb# #-}
diff --git a/lib/Data/Word/Wider.hs b/lib/Data/Word/Wider.hs
@@ -42,18 +42,6 @@ data Wider = Wider !(# Limb, Limb, Limb, Limb #)
instance Show Wider where
show = show . from
-instance Eq Wider where
- Wider a == Wider b =
- let !(# Limb a0, Limb a1, Limb a2, Limb a3 #) = a
- !(# Limb b0, Limb b1, Limb b2, Limb b3 #) = b
- in C.decide (C.ct_eq_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #)) -- XX sane?
-
-instance Ord Wider where
- compare (Wider a) (Wider b) = case cmp# a b of -- XX sane?
- 1# -> GT
- 0# -> EQ
- _ -> LT
-
instance Num Wider where
(+) = add
(-) = sub
@@ -61,9 +49,9 @@ instance Num Wider where
abs = id
fromInteger = to
negate w = add (not w) (Wider (# Limb 1##, Limb 0##, Limb 0##, Limb 0## #))
- signum a
- | a == Wider (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) = 0
- | otherwise = 1
+ signum a = case a of
+ Wider (# Limb 0##, Limb 0##, Limb 0##, Limb 0## #) -> 0
+ _ -> 1
instance NFData Wider where
rnf (Wider a) = case a of
@@ -71,6 +59,16 @@ instance NFData Wider where
-- comparison -----------------------------------------------------------------
+eq#
+ :: (# Limb, Limb, Limb, Limb #)
+ -> (# Limb, Limb, Limb, Limb #)
+ -> C.Choice
+eq# a b =
+ let !(# Limb a0, Limb a1, Limb a2, Limb a3 #) = a
+ !(# Limb b0, Limb b1, Limb b2, Limb b3 #) = b
+ in C.ct_eq_wider# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #)
+{-# INLINE eq# #-}
+
-- | Compare 'Wider' words for equality in variable time.
eq_vartime :: Wider -> Wider -> Bool
eq_vartime a b =
@@ -86,16 +84,18 @@ lt#
-> (# Limb, Limb, Limb, Limb #)
-> C.Choice
lt# a b =
- let !(# _, Limb bit #) = sub_b# a b
- in C.from_word_lsb# bit
+ let !(# _, Limb bor #) = sub_b# a b
+ in C.from_word_mask# bor
+{-# INLINE lt# #-}
gt#
:: (# Limb, Limb, Limb, Limb #)
-> (# Limb, Limb, Limb, Limb #)
-> C.Choice
gt# a b =
- let !(# _, Limb bit #) = sub_b# b a
- in C.from_word_lsb# bit
+ let !(# _, Limb bor #) = sub_b# b a
+ in C.from_word_mask# bor
+{-# INLINE gt# #-}
cmp#
:: (# Limb, Limb, Limb, Limb #)
@@ -110,11 +110,18 @@ cmp# (# l0, l1, l2, l3 #) (# r0, r1, r2, r3 #) =
!d2 = L.or# d1 w2
!(# w3, b3 #) = L.sub_b# r3 l3 b2
!d3 = L.or# d2 w3
- !(Limb w) = L.shl# b3 1#
+ !(Limb w) = L.and# b3 (Limb 2##)
!s = word2Int# w -# 1#
in (word2Int# (C.to_word# (L.nonzero# d3))) *# s
{-# INLINE cmp# #-}
+-- | Constant-time comparison between 'Wider' words.
+cmp :: Wider -> Wider -> Ordering
+cmp (Wider a) (Wider b) = case cmp# a b of
+ 1# -> GT
+ 0# -> EQ
+ _ -> LT
+
-- construction / conversion --------------------------------------------------
-- | Construct a 'Wider' word from four 'Words', provided in
@@ -271,11 +278,11 @@ add_mod# a b m =
{-# INLINE add_mod# #-}
-- | Borrowing subtraction, computing 'a - b' and returning the
--- difference with a borrow bit.
+-- difference with a borrow mask.
sub_b#
:: (# Limb, Limb, Limb, Limb #) -- ^ minuend
-> (# Limb, Limb, Limb, Limb #) -- ^ subtrahend
- -> (# (# Limb, Limb, Limb, Limb #), Limb #) -- ^ (# diff, borrow bit #)
+ -> (# (# Limb, Limb, Limb, Limb #), Limb #) -- ^ (# diff, borrow mask #)
sub_b# (# a0, a1, a2, a3 #) (# b0, b1, b2, b3 #) =
let !(# s0, c0 #) = L.sub_b# a0 b0 (Limb 0##)
!(# s1, c1 #) = L.sub_b# a1 b1 c0
@@ -299,8 +306,7 @@ sub_mod#
-> (# Limb, Limb, Limb, Limb #) -- ^ modulus
-> (# Limb, Limb, Limb, Limb #) -- ^ difference
sub_mod# a b (# p0, p1, p2, p3 #) =
- let !(# (# o0, o1, o2, o3 #), bb #) = sub_b# a b
- !m = L.neg# bb
+ let !(# (# o0, o1, o2, o3 #), m #) = sub_b# a b
!ba = (# L.and# p0 m, L.and# p1 m, L.and# p2 m, L.and# p3 m #)
in add_w# (# o0, o1, o2, o3 #) ba
{-# INLINE sub_mod# #-}
@@ -314,7 +320,7 @@ sub_mod_c#
-> (# Limb, Limb, Limb, Limb #) -- ^ difference
sub_mod_c# a c b (# p0, p1, p2, p3 #) =
let !(# (# o0, o1, o2, o3 #), bb #) = sub_b# a b
- !m = L.and# (L.not# (L.neg# c)) (L.neg# bb)
+ !m = L.and# (L.not# (L.neg# c)) bb
!ba = (# L.and# p0 m, L.and# p1 m, L.and# p2 m, L.and# p3 m #)
in add_w# (# o0, o1, o2, o3 #) ba
{-# INLINE sub_mod_c# #-}
diff --git a/test/Wider.hs b/test/Wider.hs
@@ -6,6 +6,7 @@ module Wider (
tests
) where
+import qualified Data.Choice as C
import qualified Data.Word.Wider as W
import Test.Tasty
import qualified Test.Tasty.HUnit as H
@@ -32,12 +33,76 @@ wrapping_add_with_carry = do
let !r = W.add (2 ^ (256 :: Word) - 1) 1
H.assertBool mempty (W.eq_vartime r 0)
+eq :: H.Assertion
+eq = do
+ let !(W.Wider a) = 0
+ !(W.Wider b) = 2 ^ (256 :: Word) - 1
+ H.assertBool mempty (C.decide (W.eq# a a))
+ H.assertBool mempty (not (C.decide (W.eq# a b)))
+ H.assertBool mempty (not (C.decide (W.eq# b a)))
+ H.assertBool mempty (C.decide (W.eq# b b))
+
+gt :: H.Assertion
+gt = do
+ let !(W.Wider a) = 0
+ !(W.Wider b) = 1
+ !(W.Wider c) = 2 ^ (256 :: Word) - 1
+ H.assertBool mempty (C.decide (W.gt# b a))
+ H.assertBool mempty (C.decide (W.gt# c a))
+ H.assertBool mempty (C.decide (W.gt# c b))
+
+ H.assertBool mempty (not (C.decide (W.gt# a a)))
+ H.assertBool mempty (not (C.decide (W.gt# b b)))
+ H.assertBool mempty (not (C.decide (W.gt# c c)))
+
+ H.assertBool mempty (not (C.decide (W.gt# a b)))
+ H.assertBool mempty (not (C.decide (W.gt# a c)))
+ H.assertBool mempty (not (C.decide (W.gt# b c)))
+
+lt :: H.Assertion
+lt = do
+ let !(W.Wider a) = 0
+ !(W.Wider b) = 1
+ !(W.Wider c) = 2 ^ (256 :: Word) - 1
+ H.assertBool mempty (C.decide (W.lt# a b))
+ H.assertBool mempty (C.decide (W.lt# a c))
+ H.assertBool mempty (C.decide (W.lt# b c))
+
+ H.assertBool mempty (not (C.decide (W.lt# a a)))
+ H.assertBool mempty (not (C.decide (W.lt# b b)))
+ H.assertBool mempty (not (C.decide (W.lt# c c)))
+
+ H.assertBool mempty (not (C.decide (W.lt# b a)))
+ H.assertBool mempty (not (C.decide (W.lt# c a)))
+ H.assertBool mempty (not (C.decide (W.lt# c b)))
+
+cmp :: H.Assertion
+cmp = do
+ let !a = 0
+ !b = 1
+ !c = 2 ^ (256 :: Word) - 1
+ H.assertEqual mempty (W.cmp a b) LT
+ H.assertEqual mempty (W.cmp a c) LT
+ H.assertEqual mempty (W.cmp b c) LT
+
+ H.assertEqual mempty (W.cmp a a) EQ
+ H.assertEqual mempty (W.cmp b b) EQ
+ H.assertEqual mempty (W.cmp c c) EQ
+
+ H.assertEqual mempty (W.cmp b a) GT
+ H.assertEqual mempty (W.cmp c a) GT
+ H.assertEqual mempty (W.cmp c b) GT
+
tests :: TestTree
tests = testGroup "wider tests" [
H.testCase "overflowing add, no carry" overflowing_add_no_carry
, H.testCase "overflowing add, carry" overflowing_add_with_carry
, H.testCase "wrapping add, no carry" wrapping_add_no_carry
, H.testCase "wrapping add, carry" wrapping_add_with_carry
+ , H.testCase "eq" eq
+ , H.testCase "gt" gt
+ , H.testCase "lt" lt
+ , H.testCase "cmp" cmp
]