commit 12e1863607af0f8ace4e26f766eadb64a01de79e
parent a1f7f23d728804d317b693f1aa9a460753847a8d
Author: Jared Tobin <jared@jtobin.io>
Date:   Mon, 10 Mar 2025 15:49:16 +0400
lib: encrypt passing
Diffstat:
2 files changed, 83 insertions(+), 2 deletions(-)
diff --git a/lib/Crypto/AEAD/ChaCha20Poly1305.hs b/lib/Crypto/AEAD/ChaCha20Poly1305.hs
@@ -1,11 +1,38 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE LambdaCase #-}
 {-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE ViewPatterns #-}
 
 module Crypto.AEAD.ChaCha20Poly1305 where
 
-import qualified Data.ByteString as BS
-import qualified Data.ByteString.Internal as BI
 import qualified Crypto.Cipher.ChaCha20 as ChaCha20
 import qualified Crypto.MAC.Poly1305 as Poly1305
+import Data.Bits ((.>>.))
+import qualified Data.ByteString as BS
+import qualified Data.ByteString.Internal as BI
+import Data.Word (Word64)
+
+fi :: (Integral a, Num b) => a -> b
+fi = fromIntegral
+{-# INLINE fi #-}
+
+-- little-endian bytestring encoding
+unroll :: Word64 -> BS.ByteString
+unroll i = case i of
+    0 -> BS.singleton 0
+    _ -> BS.unfoldr coalg i
+  where
+    coalg = \case
+      0 -> Nothing
+      m -> Just $! (fi m, m .>>. 8)
+{-# INLINE unroll #-}
+
+-- little-endian bytestring encoding for 64-bit ints, right-padding with zeros
+unroll8 :: Word64 -> BS.ByteString
+unroll8 (unroll -> u@(BI.PS _ _ l))
+  | l < 8 = u <> BS.replicate (8 - l) 0
+  | otherwise = u
+{-# INLINE unroll8 #-}
 
 poly1305_key_gen
   :: BS.ByteString -- ^ 256-bit initial keying material
@@ -14,3 +41,30 @@ poly1305_key_gen
 poly1305_key_gen key@(BI.PS _ _ l) nonce
   | l /= 32   = error "ppad-aead (poly1305_key_gen): invalid key"
   | otherwise = BS.take 32 (ChaCha20.block key 0 nonce)
+{-# INLINEABLE poly1305_key_gen #-}
+
+pad16 :: BS.ByteString -> BS.ByteString
+pad16 (BI.PS _ _ l) = BS.replicate (16 - l `rem` 16) 0
+
+encrypt
+  :: BS.ByteString -- ^ arbitrary-length additional authenticated data
+  -> BS.ByteString -- ^ 256-bit key
+  -> BS.ByteString -- ^ 64-bit initial value (IV)
+  -> BS.ByteString -- ^ 32-bit salt
+  -> BS.ByteString -- ^ arbitrary-length plaintext
+  -> (BS.ByteString, BS.ByteString) -- ^ (ciphertext, MAC)
+encrypt aad key iv salt plaintext
+  | BS.length key  /= 32 = error "ppad-aead (encrypt): invalid key"
+  | BS.length iv   /= 8  = error "ppad-aead (encrypt): invalid IV"
+  | BS.length salt /= 4  = error "ppad-aead (encrypt): invalid salt"
+  | otherwise =
+      let nonce = salt <> iv
+          otk   = poly1305_key_gen key nonce
+          ciphertext = ChaCha20.cipher key 1 nonce plaintext
+          md0 = aad <> pad16 aad
+          md1 = md0 <> ciphertext <> pad16 ciphertext
+          md2 = md1 <> unroll8 (fi (BS.length aad))
+          md3 = md2 <> unroll8 (fi (BS.length ciphertext))
+          tag = Poly1305.mac otk md3
+      in  (ciphertext, tag)
+
diff --git a/test/Main.hs b/test/Main.hs
@@ -5,13 +5,16 @@
 module Main where
 
 import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
+import Data.ByteString as BS
 import qualified Data.ByteString.Base16 as B16
+import Data.Maybe (fromJust)
 import Test.Tasty
 import qualified Test.Tasty.HUnit as H
 
 main :: IO ()
 main = defaultMain $ testGroup "ppad-aead" [
     poly1305_key_gen
+  , encrypt
   ]
 
 poly1305_key_gen :: TestTree
@@ -27,3 +30,27 @@ poly1305_key_gen = H.testCase "poly1305_key_gen" $ do
       o = AEAD.poly1305_key_gen key non
   H.assertEqual mempty e o
 
+encrypt :: TestTree
+encrypt = H.testCase "encrypt" $ do
+    let (o_cip, o_tag) = AEAD.encrypt aad key iv salt sunscreen
+
+        e_cip = fromJust . B16.decode $
+          "d31a8d34648e60db7b86afbc53ef7ec2a4aded51296e08fea9e2b5a736ee62d63dbea45e8ca9671282fafb69da92728b1a71de0a9e060b2905d6a5b67ecd3b3692ddbd7f2d778b8c9803aee328091b58fab324e4fad675945585808b4831d7bc3ff4def08e4b7a9de576d26586cec64b6116"
+
+        e_tag = fromJust . B16.decode $
+          "1ae10b594f09e26a7e902ecbd0600691"
+
+    H.assertEqual mempty (e_cip, e_tag) (o_cip, o_tag)
+  where
+    sunscreen :: BS.ByteString
+    sunscreen = fromJust . B16.decode $
+      "4c616469657320616e642047656e746c656d656e206f662074686520636c617373206f66202739393a204966204920636f756c64206f6666657220796f75206f6e6c79206f6e652074697020666f7220746865206675747572652c2073756e73637265656e20776f756c642062652069742e"
+
+    aad = fromJust . B16.decode $ "50515253c0c1c2c3c4c5c6c7"
+    key = fromJust . B16.decode $
+      "808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f"
+
+    iv = fromJust . B16.decode $ "4041424344454647"
+    salt = fromJust . B16.decode $ "07000000"
+
+