BIP340.hs (3281B)
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}

-- | BIP340 (Schnorr signature) test vectors: a parser for the test-vector
--   CSV and a tasty test builder for each parsed row.
module BIP340 (
    cases
  , execute
  ) where

import Control.Applicative
import Crypto.Curve.Secp256k1
import qualified Data.Attoparsec.ByteString.Char8 as AT
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as B16
import qualified GHC.Num.Integer as I
import Test.Tasty
import Test.Tasty.HUnit

-- XX make a test prelude instead of copying/pasting these things everywhere

-- | Decode a hex string, failing loudly on malformed input. A failure here
--   indicates a malformed fixture, so the offending input is included in
--   the error message.
decodeLenient :: BS.ByteString -> BS.ByteString
decodeLenient bs = case B16.decode bs of
  Nothing -> error ("BIP340.decodeLenient: invalid hex: " <> show bs)
  Just b  -> b

-- | Shorthand for 'fromIntegral'.
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}

-- | Interpret a bytestring as a big-endian unsigned integer (the first
--   byte ends up most significant).
roll :: BS.ByteString -> Integer
roll = BS.foldl' unstep 0 where
  unstep a (fi -> b) = (a `I.integerShiftL` 8) `I.integerOr` b

-- | One row of the test-vector CSV.
data Case = Case {
    c_index   :: !Int            -- ^ vector index (also used in the test name)
  , c_sk      :: !BS.ByteString  -- ^ decoded secret key; empty for
                                 --   verification-only vectors
  , c_pk      :: !BS.ByteString  -- ^ public key, still hex-encoded
                                 --   (decoded in 'execute')
  , c_aux     :: !BS.ByteString  -- ^ decoded aux_rand column
  , c_msg     :: !BS.ByteString  -- ^ decoded message
  , c_sig     :: !BS.ByteString  -- ^ decoded expected signature
  , c_res     :: !Bool           -- ^ expected verification result
  , c_comment :: !BS.ByteString  -- ^ free-form comment column
  } deriving Show

-- | Build a test for one vector.
--
--   * If the public key fails to parse, the vector must be one that is
--     expected to fail.
--   * With no secret key, the supplied signature is checked with both
--     'verify_schnorr' and 'verify_schnorr'', and each result must equal
--     the expected result.
--   * With a secret key, 'sign_schnorr' and 'sign_schnorr'' must each
--     produce the expected signature, and verifying it must yield the
--     expected result.
execute :: Context -> Case -> TestTree
execute tex Case {..} = testCase ("bip0340 " <> show c_index) $
  case parse_point (decodeLenient c_pk) of
    Nothing ->
      assertBool "unparseable public key for a vector expected to verify"
        (not c_res)
    Just pk
      | c_sk == mempty -> do
          -- no signature to produce; test verification only
          let ver  = verify_schnorr c_msg pk c_sig
              ver' = verify_schnorr' tex c_msg pk c_sig
          assertEqual "verify_schnorr" c_res ver
          assertEqual "verify_schnorr'" c_res ver'
      | otherwise -> do
          -- secret key present; test signing too
          -- XX test pubkey derivation from sk
          let sk = roll c_sk
          sig <- expectSig "sign_schnorr" (sign_schnorr sk c_msg c_aux)
          assertEqual "sign_schnorr" c_sig sig
          sig' <- expectSig "sign_schnorr'" (sign_schnorr' tex sk c_msg c_aux)
          assertEqual "sign_schnorr'" c_sig sig'
          assertEqual "verify_schnorr" c_res (verify_schnorr c_msg pk sig)
          assertEqual "verify_schnorr'" c_res
            (verify_schnorr' tex c_msg pk sig)
  where
    -- turn a missing signature into a labelled test failure rather than
    -- an irrefutable-pattern crash
    expectSig :: String -> Maybe a -> IO a
    expectSig label = maybe (assertFailure (label <> " returned Nothing")) pure

-- | Parse the CSV header line (LF or CRLF terminated).
header :: AT.Parser ()
header = do
  _ <- AT.string "index,secret key,public key,aux_rand,message,signature,verification result,comment"
  AT.endOfLine

-- | Parse one CSV row. The hex columns other than the public key are
--   decoded here; 'c_pk' stays hex-encoded and is decoded in 'execute'.
test_case :: AT.Parser Case
test_case = do
  c_index <- AT.decimal AT.<?> "index"
  _ <- AT.char ','
  c_sk <- fmap decodeLenient (AT.takeWhile (/= ',') AT.<?> "sk")
  _ <- AT.char ','
  c_pk <- AT.takeWhile1 (/= ',') AT.<?> "pk"
  _ <- AT.char ','
  c_aux <- fmap decodeLenient (AT.takeWhile (/= ',') AT.<?> "aux")
  _ <- AT.char ','
  c_msg <- fmap decodeLenient (AT.takeWhile (/= ',') AT.<?> "msg")
  _ <- AT.char ','
  c_sig <- fmap decodeLenient (AT.takeWhile1 (/= ',') AT.<?> "sig")
  _ <- AT.char ','
  c_res <- (AT.string "TRUE" *> pure True) <|> (AT.string "FALSE" *> pure False)
    AT.<?> "res"
  _ <- AT.char ','
  -- stop at '\r' as well as '\n' so a CRLF file does not leak a carriage
  -- return into the comment; 'AT.endOfLine' accepts either terminator
  c_comment <- AT.takeWhile (\c -> c /= '\n' && c /= '\r') AT.<?> "comment"
  -- the final row may lack a newline; without the 'AT.endOfInput'
  -- alternative it would fail here and 'AT.many1' in 'cases' would
  -- silently drop it
  AT.endOfLine <|> AT.endOfInput
  pure Case {..}

-- | Parse the header followed by one or more rows.
--
--   NOTE(review): this does not require end of input, so a malformed row
--   ends the list early instead of failing the parse; callers may want to
--   check for unconsumed input.
cases :: AT.Parser [Case]
cases = header *> AT.many1 test_case