commit 309e8d49acea74e9a5c9e8516bf019103e80e6cd
parent 3f815531d3702e7bd3a3925f52f1353328799782
Author: Jared Tobin <jared@jtobin.io>
Date: Sat, 14 Sep 2024 18:19:01 +0400
lib: mark unsafe internals, add safety commentary
Diffstat:
1 file changed, 68 insertions(+), 35 deletions(-)
diff --git a/lib/Crypto/Hash/SHA256.hs b/lib/Crypto/Hash/SHA256.hs
@@ -38,21 +38,25 @@ import Foreign.ForeignPtr (plusForeignPtr)
-- preliminary utils
--- keystroke savers
+-- keystroke saver
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}
--- unsafe parse, strict ByteString to Word32 (verbatim from Data.Binary)
-word32be :: BS.ByteString -> Word32
-word32be s =
+-- parse strict ByteString in BE order to Word32 (verbatim from
+-- Data.Binary)
+--
+-- invariant:
+-- the input bytestring is at least 32 bits in length
+unsafe_word32be :: BS.ByteString -> Word32
+unsafe_word32be s =
(fromIntegral (s `BU.unsafeIndex` 0) `B.unsafeShiftL` 24) .|.
(fromIntegral (s `BU.unsafeIndex` 1) `B.unsafeShiftL` 16) .|.
(fromIntegral (s `BU.unsafeIndex` 2) `B.unsafeShiftL` 8) .|.
(fromIntegral (s `BU.unsafeIndex` 3))
-{-# INLINE word32be #-}
+{-# INLINE unsafe_word32be #-}
--- following are utility types for more efficient ByteString management
+-- utility types for more efficient ByteString management
data SLPair = SLPair {-# UNPACK #-} !BS.ByteString !BL.ByteString
@@ -72,10 +76,13 @@ splitAt64 = splitAt' (64 :: Int) where
-- variant of Data.ByteString.splitAt that behaves like an incremental
-- Word32 parser
-parseWord32 :: BS.ByteString -> WSPair
-parseWord32 (BI.BS x l) =
- WSPair (word32be (BI.BS x 4)) (BI.BS (plusForeignPtr x 4) (l - 4))
-{-# INLINE parseWord32 #-}
+--
+-- invariant:
+-- the input bytestring is at least 32 bits in length
+unsafe_parseWsPair :: BS.ByteString -> WSPair
+unsafe_parseWsPair (BI.BS x l) =
+ WSPair (unsafe_word32be (BI.BS x 4)) (BI.BS (plusForeignPtr x 4) (l - 4))
+{-# INLINE unsafe_parseWsPair #-}
-- message padding and parsing
-- https://datatracker.ietf.org/doc/html/rfc6234#section-4.1
@@ -90,7 +97,6 @@ sol l =
pad :: BS.ByteString -> BS.ByteString
pad m = BL.toStrict . BSB.toLazyByteString $ padded where
l = fi (BS.length m)
-
padded = BSB.byteString m <> fill (sol l) (BSB.word8 0x80)
fill j !acc
@@ -100,12 +106,10 @@ pad m = BL.toStrict . BSB.toLazyByteString $ padded where
-- RFC 6234 4.1 (lazy)
pad_lazy :: BL.ByteString -> BL.ByteString
pad_lazy (BL.toChunks -> m) = BL.fromChunks (walk 0 m) where
- -- walk chunks, calculating length and appending padding
walk !l bs = case bs of
(c:cs) -> c : walk (l + fi (BS.length c)) cs
[] -> padding l (sol l) (BSB.word8 0x80)
- -- construct padding
padding l k bs
| k == 0 =
pure
@@ -198,24 +202,28 @@ data Block = Block {
, m12 :: !Word32, m13 :: !Word32, m14 :: !Word32, m15 :: !Word32
}
-parse :: BS.ByteString -> Block
-parse bs =
- let !(WSPair m00 t00) = parseWord32 bs
- !(WSPair m01 t01) = parseWord32 t00
- !(WSPair m02 t02) = parseWord32 t01
- !(WSPair m03 t03) = parseWord32 t02
- !(WSPair m04 t04) = parseWord32 t03
- !(WSPair m05 t05) = parseWord32 t04
- !(WSPair m06 t06) = parseWord32 t05
- !(WSPair m07 t07) = parseWord32 t06
- !(WSPair m08 t08) = parseWord32 t07
- !(WSPair m09 t09) = parseWord32 t08
- !(WSPair m10 t10) = parseWord32 t09
- !(WSPair m11 t11) = parseWord32 t10
- !(WSPair m12 t12) = parseWord32 t11
- !(WSPair m13 t13) = parseWord32 t12
- !(WSPair m14 t14) = parseWord32 t13
- !(WSPair m15 t15) = parseWord32 t14
+-- parse strict bytestring to block
+--
+-- invariant:
+-- the input bytestring is exactly 512 bits long
+unsafe_parse :: BS.ByteString -> Block
+unsafe_parse bs =
+ let !(WSPair m00 t00) = unsafe_parseWsPair bs
+ !(WSPair m01 t01) = unsafe_parseWsPair t00
+ !(WSPair m02 t02) = unsafe_parseWsPair t01
+ !(WSPair m03 t03) = unsafe_parseWsPair t02
+ !(WSPair m04 t04) = unsafe_parseWsPair t03
+ !(WSPair m05 t05) = unsafe_parseWsPair t04
+ !(WSPair m06 t06) = unsafe_parseWsPair t05
+ !(WSPair m07 t07) = unsafe_parseWsPair t06
+ !(WSPair m08 t08) = unsafe_parseWsPair t07
+ !(WSPair m09 t09) = unsafe_parseWsPair t08
+ !(WSPair m10 t10) = unsafe_parseWsPair t09
+ !(WSPair m11 t11) = unsafe_parseWsPair t10
+ !(WSPair m12 t12) = unsafe_parseWsPair t11
+ !(WSPair m13 t13) = unsafe_parseWsPair t12
+ !(WSPair m14 t14) = unsafe_parseWsPair t13
+ !(WSPair m15 t15) = unsafe_parseWsPair t14
in if BS.null t15
then Block {..}
else error "ppad-sha256: internal error (bytes remaining)"
@@ -357,8 +365,11 @@ step (Registers a b c d e f g h) k w =
in Registers (t1 + t2) a b c (d + t1) e f g
-- RFC 6234 6.2 block pipeline
-hash_alg :: Registers -> BS.ByteString -> Registers
-hash_alg rs bs = block_hash rs (prepare_schedule (parse bs))
+--
+-- invariant:
+-- the input bytestring is exactly 512 bits in length
+unsafe_hash_alg :: Registers -> BS.ByteString -> Registers
+unsafe_hash_alg rs bs = block_hash rs (prepare_schedule (unsafe_parse bs))
-- register concatenation
cat :: Registers -> BS.ByteString
@@ -380,11 +391,32 @@ cat Registers {..} =
-- "<strict 256-bit message digest>"
hash :: BS.ByteString -> BS.ByteString
hash bs = cat (go iv (pad bs)) where
+ -- proof that 'go' always terminates safely:
+ --
+ -- let b = pad bs
+ -- then length(b) = n * 512 bits for some n >= 0 (1)
go :: Registers -> BS.ByteString -> Registers
go !acc b
+ -- if n == 0, then 'go' terminates safely (2)
| BS.null b = acc
+ -- if n > 0, then
+ --
+ -- let (c, r) = BS.splitAt 64 b
+ -- then length(c) == 512 bits by (1)
+ -- length(r) == m * 512 bits for some m >= 0 by (1)
+ --
+ -- note 'unsafe_hash_alg' terminates safely for bytestring (3)
+ -- input of exactly 512 bits in length
+ --
+ -- length(c) == 512
+ -- => 'unsafe_hash_alg' terminates safely by (3)
+ -- => 'go' terminates safely (4)
+ -- length(r) == m * 512 bits for m >= 0
+ -- => next invocation of 'go' terminates safely by (2), (4)
+ --
+ -- then by induction, 'go' always terminates safely (QED)
| otherwise = case BS.splitAt 64 b of
- (c, r) -> go (hash_alg acc c) r
+ (c, r) -> go (unsafe_hash_alg acc c) r
-- | Compute a condensed representation of a lazy bytestring via
-- SHA-256.
@@ -395,11 +427,12 @@ hash bs = cat (go iv (pad bs)) where
-- "<strict 256-bit message digest>"
hash_lazy :: BL.ByteString -> BS.ByteString
hash_lazy bl = cat (go iv (pad_lazy bl)) where
+ -- proof of safety proceeds analogously
go :: Registers -> BL.ByteString -> Registers
go !acc bs
| BL.null bs = acc
| otherwise = case splitAt64 bs of
- SLPair c r -> go (hash_alg acc c) r
+ SLPair c r -> go (unsafe_hash_alg acc c) r
-- HMAC
-- https://datatracker.ietf.org/doc/html/rfc2104#section-2