commit 88859fcc42aa6187cb7e3c028e5da807c5a3df2e
parent de6534d42a9a8c60901a0dc0a310257a5d1e4be7
Author: Jared Tobin <jared@jtobin.io>
Date: Wed, 11 Jun 2025 13:00:12 +0400
lib: total functions
Diffstat:
6 files changed, 91 insertions(+), 74 deletions(-)
diff --git a/README.md b/README.md
@@ -27,7 +27,7 @@ A sample GHCi session:
> let aad = "and i approve it"
>
> -- encryption produces a 128-bit MAC
- > let (cip, mac) = AEAD.encrypt aad key non msg
+ > let Right (cip, mac) = AEAD.encrypt aad key non msg
> B16.encode cip
"d6377eab18cad56e8c6176968460e6a548c524b9498c9b993e"
> B16.encode mac
@@ -35,11 +35,11 @@ A sample GHCi session:
>
> -- supply both to decrypt
> AEAD.decrypt aad key non (cip, tag)
- Just "this is my secret message"
+ Right "this is my secret message"
>
> -- bogus MACs will cause decryption to fail
> AEAD.decrypt aad key non (cip, "really i swear!!")
- Nothing
+ Left InvalidMAC
```
## Documentation
diff --git a/bench/Main.hs b/bench/Main.hs
@@ -1,13 +1,22 @@
+{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE StandaloneDeriving #-}
module Main where
+import Control.DeepSeq
import Criterion.Main
import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as B16
import Data.Maybe (fromJust)
+import GHC.Generics
+
+deriving instance Generic AEAD.Error
+
+instance NFData AEAD.Error
main :: IO ()
main = defaultMain [
diff --git a/flake.lock b/flake.lock
@@ -51,11 +51,11 @@
]
},
"locked": {
- "lastModified": 1740802922,
- "narHash": "sha256-j+sxOWxnsMgX3GHyf7Z44lvAiBkrl/p0lD0eYli6Lgo=",
+ "lastModified": 1741625558,
+ "narHash": "sha256-ZBDXRD5fsVqA5bGrAlcnhiu67Eo50q0M9614nR3NBwY=",
"ref": "master",
- "rev": "043c845ae7f280ddbfdb5568ea453c9943e49cf2",
- "revCount": 21,
+ "rev": "fb63457f2e894eda28250dfe65d0fcd1d195ac2f",
+ "revCount": 24,
"type": "git",
"url": "git://git.ppad.tech/base16.git"
},
@@ -85,11 +85,11 @@
]
},
"locked": {
- "lastModified": 1741521082,
- "narHash": "sha256-Y8PA5Y4d7Rj6qgwnrgTyg5BvHGLLA/FAOWC+dxRnbDA=",
+ "lastModified": 1749626971,
+ "narHash": "sha256-+jWE4Kq7Q2KB5jCu2S++TQYHujVVkqsamnf0BXKDTIA=",
"ref": "master",
- "rev": "3b25b1dd3f346b2374dba198817749e61f8dcf91",
- "revCount": 15,
+ "rev": "efad15250e776b00cf37535c18ffc01c520fc679",
+ "revCount": 18,
"type": "git",
"url": "git://git.ppad.tech/chacha.git"
},
@@ -139,11 +139,11 @@
]
},
"locked": {
- "lastModified": 1741611492,
- "narHash": "sha256-f8cxxMiEoCFbM9zJeS1B45nzVIwzS27o7zN+cZAfQJQ=",
+ "lastModified": 1749628515,
+ "narHash": "sha256-6idwz4Wh+MHGgNorwWyGEUKxovBEtlEYx0JdHNYBE0I=",
"ref": "master",
- "rev": "30ec265e9cec86d0ac1a50c42ead189d8e2c6821",
- "revCount": 12,
+ "rev": "21be7b8655da65e8da88fa89ad155b5d91bf5885",
+ "revCount": 15,
"type": "git",
"url": "git://git.ppad.tech/poly1305.git"
},
diff --git a/lib/Crypto/AEAD/ChaCha20Poly1305.hs b/lib/Crypto/AEAD/ChaCha20Poly1305.hs
@@ -18,6 +18,9 @@ module Crypto.AEAD.ChaCha20Poly1305 (
encrypt
, decrypt
+ -- * Error information
+ , Error(..)
+
-- testing
, _poly1305_key_gen
) where
@@ -56,10 +59,11 @@ unroll8 (unroll -> u@(BI.PS _ _ l))
_poly1305_key_gen
:: BS.ByteString -- ^ 256-bit initial keying material
-> BS.ByteString -- ^ 96-bit nonce
- -> BS.ByteString -- ^ 256-bit key (suitable for poly1305)
-_poly1305_key_gen key@(BI.PS _ _ l) nonce
- | l /= 32 = error "ppad-aead (poly1305_key_gen): invalid key"
- | otherwise = BS.take 32 (ChaCha20.block key 0 nonce)
+ -> Either Error BS.ByteString -- ^ 256-bit key (suitable for poly1305)
+_poly1305_key_gen key nonce = case ChaCha20.block key 0 nonce of
+ Left ChaCha20.InvalidKey -> Left InvalidKey
+ Left ChaCha20.InvalidNonce -> Left InvalidNonce
+ Right k -> pure (BS.take 32 k)
{-# INLINEABLE _poly1305_key_gen #-}
pad16 :: BS.ByteString -> BS.ByteString
@@ -68,6 +72,12 @@ pad16 (BI.PS _ _ l)
| otherwise = BS.replicate (16 - l `rem` 16) 0
{-# INLINE pad16 #-}
+data Error =
+ InvalidKey
+ | InvalidNonce
+ | InvalidMAC
+ deriving (Eq, Show)
+
-- RFC8439 2.8
-- | Perform authenticated encryption on a plaintext and some additional
@@ -76,14 +86,11 @@ pad16 (BI.PS _ _ l)
--
-- Produces a ciphertext and 128-bit message authentication code pair.
--
--- Providing an invalid key or nonce will result in an 'ErrorCall'
--- exception being thrown.
---
-- >>> let key = "don't tell anyone my secret key!"
-- >>> let non = "or my nonce!"
-- >>> let pan = "and here's my plaintext"
-- >>> let aad = "i approve this message"
--- >>> let (cip, mac) = encrypt aad key nonce pan
+-- >>> let Right (cip, mac) = encrypt aad key nonce pan
-- >>> (cip, mac)
-- <(ciphertext, 128-bit MAC)>
encrypt
@@ -91,51 +98,55 @@ encrypt
-> BS.ByteString -- ^ 256-bit key
-> BS.ByteString -- ^ 96-bit nonce
-> BS.ByteString -- ^ arbitrary-length plaintext
- -> (BS.ByteString, BS.ByteString) -- ^ (ciphertext, 128-bit MAC)
+ -> Either Error (BS.ByteString, BS.ByteString) -- ^ (ciphertext, 128-bit MAC)
encrypt aad key nonce plaintext
- | BS.length key /= 32 = error "ppad-aead (encrypt): invalid key"
- | BS.length nonce /= 12 = error "ppad-aead (encrypt): invalid nonce"
- | otherwise =
- let otk = _poly1305_key_gen key nonce
- cip = ChaCha20.cipher key 1 nonce plaintext
- md0 = aad <> pad16 aad
- md1 = md0 <> cip <> pad16 cip
- md2 = md1 <> unroll8 (fi (BS.length aad))
- md3 = md2 <> unroll8 (fi (BS.length cip))
- tag = Poly1305.mac otk md3
- in (cip, tag)
+ | BS.length key /= 32 = Left InvalidKey
+ | BS.length nonce /= 12 = Left InvalidNonce
+ | otherwise = do
+ otk <- _poly1305_key_gen key nonce
+ case ChaCha20.cipher key 1 nonce plaintext of
+ Left ChaCha20.InvalidKey -> Left InvalidKey -- impossible, but..
+ Left ChaCha20.InvalidNonce -> Left InvalidNonce -- ditto
+ Right cip -> do
+ let md0 = aad <> pad16 aad
+ md1 = md0 <> cip <> pad16 cip
+ md2 = md1 <> unroll8 (fi (BS.length aad))
+ md3 = md2 <> unroll8 (fi (BS.length cip))
+ case Poly1305.mac otk md3 of
+ Nothing -> Left InvalidKey
+ Just tag -> pure (cip, tag)
-- | Decrypt an authenticated ciphertext, given a message authentication
-- code and some additional authenticated data, via a 256-bit key and
-- 96-bit nonce.
--
--- Returns 'Nothing' if the MAC fails to validate.
---
--- Providing an invalid key or nonce will result in an 'ErrorCall'
--- exception being thrown.
---
-- >>> decrypt aad key non (cip, mac)
--- Just "and here's my plaintext"
+-- Right "and here's my plaintext"
-- >>> decrypt aad key non (cip, "it's a valid mac")
--- Nothing
+-- Left InvalidMAC
decrypt
:: BS.ByteString -- ^ arbitrary-length AAD
-> BS.ByteString -- ^ 256-bit key
-> BS.ByteString -- ^ 96-bit nonce
-> (BS.ByteString, BS.ByteString) -- ^ (arbitrary-length ciphertext, 128-bit MAC)
- -> Maybe BS.ByteString
+ -> Either Error BS.ByteString
decrypt aad key nonce (cip, mac)
- | BS.length key /= 32 = error "ppad-aead (decrypt): invalid key"
- | BS.length nonce /= 12 = error "ppad-aead (decrypt): invalid nonce"
- | BS.length mac /= 16 = Nothing
- | otherwise =
- let otk = _poly1305_key_gen key nonce
- md0 = aad <> pad16 aad
+ | BS.length key /= 32 = Left InvalidKey
+ | BS.length nonce /= 12 = Left InvalidNonce
+ | BS.length mac /= 16 = Left InvalidMAC
+ | otherwise = do
+ otk <- _poly1305_key_gen key nonce
+ let md0 = aad <> pad16 aad
md1 = md0 <> cip <> pad16 cip
md2 = md1 <> unroll8 (fi (BS.length aad))
md3 = md2 <> unroll8 (fi (BS.length cip))
- tag = Poly1305.mac otk md3
- in if mac == tag
- then pure (ChaCha20.cipher key 1 nonce cip)
- else Nothing
+ case Poly1305.mac otk md3 of
+ Nothing -> Left InvalidKey
+ Just tag
+ | mac == tag -> case ChaCha20.cipher key 1 nonce cip of
+ Left ChaCha20.InvalidKey -> Left InvalidKey
+ Left ChaCha20.InvalidNonce -> Left InvalidNonce
+ Right v -> pure v
+ | otherwise ->
+ Left InvalidMAC
diff --git a/ppad-aead.cabal b/ppad-aead.cabal
@@ -67,6 +67,7 @@ benchmark aead-bench
base
, bytestring
, criterion
+ , deepseq
, ppad-base16
, ppad-aead
diff --git a/test/Main.hs b/test/Main.hs
@@ -5,7 +5,6 @@
module Main where
-import Control.Exception
import qualified Crypto.AEAD.ChaCha20Poly1305 as AEAD
import Data.ByteString as BS
import qualified Data.Aeson as A
@@ -43,14 +42,14 @@ poly1305_key_gen = H.testCase "poly1305_key_gen" $ do
Just e = B16.decode
"8ad5a08b905f81cc815040274ab29471a833b637e3fd0da508dbb8e2fdd1a646"
- o = AEAD._poly1305_key_gen key non
+ Right o = AEAD._poly1305_key_gen key non
H.assertEqual mempty e o
crypt :: TestTree
crypt = H.testCase "encrypt/decrypt" $ do
let nonce = salt <> iv
- (o_cip, o_tag) = AEAD.encrypt aad key nonce sunscreen
+ Right (o_cip, o_tag) = AEAD.encrypt aad key nonce sunscreen
e_cip = fromJust . B16.decode $
"d31a8d34648e60db7b86afbc53ef7ec2a4aded51296e08fea9e2b5a736ee62d63dbea45e8ca9671282fafb69da92728b1a71de0a9e060b2905d6a5b67ecd3b3692ddbd7f2d778b8c9803aee328091b58fab324e4fad675945585808b4831d7bc3ff4def08e4b7a9de576d26586cec64b6116"
@@ -59,10 +58,10 @@ crypt = H.testCase "encrypt/decrypt" $ do
"1ae10b594f09e26a7e902ecbd0600691"
- o_dec = AEAD.decrypt aad key nonce (o_cip, o_tag)
+ Right o_dec = AEAD.decrypt aad key nonce (o_cip, o_tag)
H.assertEqual mempty (e_cip, e_tag) (o_cip, o_tag)
- H.assertEqual mempty (Just sunscreen) o_dec
+ H.assertEqual mempty sunscreen o_dec
where
sunscreen :: BS.ByteString
sunscreen = fromJust . B16.decode $
@@ -94,7 +93,7 @@ crypt0 = H.testCase "decrypt (A.5)" $ do
tag = fromJust . B16.decode $
"eead9d67890cbb22392336fea1851f38"
- Just pan = AEAD.decrypt aad key non (cip, tag)
+ Right pan = AEAD.decrypt aad key non (cip, tag)
H.assertEqual mempty e_pan pan
@@ -115,21 +114,18 @@ execute W.AEADTest {..} = H.testCase t_msg $ do
msg = aeadt_msg
ct = aeadt_ct
tag = aeadt_tag
- if aeadt_result == "invalid"
- then do
- out <- try (pure $! AEAD.decrypt aad key iv (ct, tag))
- :: IO (Either ErrorCall (Maybe BS.ByteString))
- case out of
- Left _ -> H.assertBool "invalid (bogus key/nonce)" True
- Right Nothing -> H.assertBool "invalid (bogus MAC)" True
- Right (Just o) -> H.assertBool "invalid" (msg /= o)
- else do
- let (out_cip, out_mac) = AEAD.encrypt aad key iv msg
- out_pan = AEAD.decrypt aad key iv (ct, tag)
- H.assertEqual mempty ct out_cip
- H.assertEqual mempty tag out_mac
- H.assertEqual mempty (Just msg) out_pan
+ case AEAD.decrypt aad key iv (ct, tag) of
+ Left _
+ | aeadt_result == "invalid" -> H.assertBool "invalid" True
+ | otherwise -> H.assertFailure mempty
+ Right out
+ | aeadt_result == "invalid" -> H.assertFailure mempty
+ | otherwise -> case AEAD.encrypt aad key iv msg of
+ Left _ -> H.assertFailure mempty
+ Right (out_cip, out_mac) -> do
+ H.assertEqual mempty ct out_cip
+ H.assertEqual mempty tag out_mac
+ H.assertEqual mempty msg out
where
t_msg = "test " <> show aeadt_tcId
-