commit 3709c4147b02303bb380caee5786b5c9c7b5e335
parent 44bb38cf9de6b30cdc984915aa04bd14aa6f34af
Author: Jared Tobin <jared@jtobin.io>
Date: Sat, 12 Jul 2025 09:51:42 -0230
lib: fix multi-word shift bug
Diffstat:
2 files changed, 23 insertions(+), 2 deletions(-)
diff --git a/lib/Data/Choice.hs b/lib/Data/Choice.hs
@@ -19,6 +19,7 @@ module Data.Choice (
, just_wide#
, none_wide#
, expect_wide#
+ , expect_wide_or#
-- * Construction
, from_word_lsb#
@@ -151,6 +152,15 @@ expect_wide# (MaybeWide# (# w, Choice c #)) msg
| otherwise = error $ "ppad-fixed (expect_wide#): " <> msg
where
!(Choice t#) = true# ()
+{-# INLINE expect_wide# #-}
+
+expect_wide_or# :: MaybeWide# -> (# Word#, Word# #) -> (# Word#, Word# #)
+expect_wide_or# (MaybeWide# (# w, Choice c #)) alt
+ | isTrue# (eqWord# c t#) = w
+ | otherwise = alt
+ where
+ !(Choice t#) = true# ()
+{-# INLINE expect_wide_or# #-}
-- construction ---------------------------------------------------------------
diff --git a/lib/Data/Word/Wide.hs b/lib/Data/Word/Wide.hs
@@ -25,6 +25,7 @@ module Data.Word.Wide (
, xor
, not
, shr
+ , unchecked_shr
-- * Arithmetic
, add
@@ -157,8 +158,8 @@ shr_of_vartime# (# l, h #) s
!l_0 = or# shf car
in (# l_0, h_0 #)
1# ->
- let !l_0 = uncheckedShiftRL# l rem
- in (# l_0, h #)
+ let !l_0 = uncheckedShiftRL# h rem
+ in (# l_0, 0## #)
2# ->
(# l, h #)
_ -> error "ppad-fixed (shr_of_vartime#): internal error"
@@ -201,9 +202,19 @@ shr# :: (# Word#, Word# #) -> Int# -> (# Word#, Word# #)
shr# w s = C.expect_wide# (shr_of# w s) "invalid shift"
{-# INLINE shr# #-}
+-- wrapping
+unchecked_shr# :: (# Word#, Word# #) -> Int# -> (# Word#, Word# #)
+unchecked_shr# w s = C.expect_wide_or# (shr_of# w s) (# 0##, 0## #)
+{-# INLINE unchecked_shr# #-}
+
+-- constant-time shr, ErrorCall on invalid shift
shr :: Wide -> Int -> Wide
shr (Wide w) (I# s) = Wide (shr# w s)
+-- constant-time shr, saturating
+unchecked_shr :: Wide -> Int -> Wide
+unchecked_shr (Wide w) (I# s) = Wide (unchecked_shr# w s)
+
-- addition, subtraction ------------------------------------------------------
-- wide-add-with-carry, i.e. (# sum, carry bit #)