commit 9bfc6b6037ac9e4446a271cd350f8dc56dd94752
parent 4ed3bc557d7f5183374ba224f8ef892ea2dca538
Author: Jared Tobin <jared@jtobin.io>
Date: Fri, 5 Apr 2024 17:07:12 +0400
lib: add RFC6979 conversion utilities
Diffstat:
1 file changed, 70 insertions(+), 51 deletions(-)
diff --git a/lib/Crypto/Secp256k1.hs b/lib/Crypto/Secp256k1.hs
@@ -18,12 +18,18 @@ import GHC.Natural
import qualified GHC.Num.Integer as I
import Prelude hiding (mod)
+-- see https://www.secg.org/sec2-v2.pdf for parameter specs
+
-- secp256k1 field prime
--
--- _CURVE_P == 2^256 - 2^32 - 2^9 - 2^8 - 2^7 - 2^6 - 2^4 - 1
+-- ~ 2^256 - 2^32 - 2^9 - 2^8 - 2^7 - 2^6 - 2^4 - 1
_CURVE_P :: Integer
_CURVE_P = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
+-- | Division modulo secp256k1 field prime.
+modP :: Integer -> Integer
+modP a = I.integerMod a _CURVE_P
+
-- secp256k1 group order
_CURVE_N :: Integer
_CURVE_N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
@@ -32,6 +38,10 @@ _CURVE_N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
_CURVE_N_LEN :: Integer
_CURVE_N_LEN = 256
+-- bytelength of _CURVE_N
+_CURVE_N_BYTES :: Int
+_CURVE_N_BYTES = 32
+
-- secp256k1 short weierstrass form, /a/ coefficient
_CURVE_A :: Integer
_CURVE_A = 0
@@ -40,12 +50,8 @@ _CURVE_A = 0
_CURVE_B :: Integer
_CURVE_B = 7
-_CURVE_G :: BS.ByteString
-_CURVE_G =
- "0279BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798"
-
-- point in affine coordinates
-data Affine = Affine Integer Integer
+data Affine = Affine !Integer !Integer
deriving stock (Show, Generic)
instance Eq Affine where
@@ -68,11 +74,11 @@ instance Eq Projective where
y2z1 = modP (by * az)
in x1z2 == x2z1 && y1z2 == y2z1
--- secp256k1 base point
+-- secp256k1 generator
--
--- Just _BASE == parse _CURVE_G
-_BASE :: Projective
-_BASE = Projective x y 1 where
+-- ~ parse "0279BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798"
+_CURVE_G :: Projective
+_CURVE_G = Projective x y 1 where
x = 0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798
y = 0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8
@@ -80,9 +86,9 @@ _BASE = Projective x y 1 where
_ZERO :: Projective
_ZERO = Projective 0 1 0
--- | Division modulo secp256k1 field prime.
-modP :: Integer -> Integer
-modP a = I.integerMod a _CURVE_P
+-- | Division modulo secp256k1 group order.
+modN :: Integer -> Integer
+modN a = I.integerMod a _CURVE_N
-- | Is field element.
fe :: Integer -> Bool
@@ -92,6 +98,7 @@ fe n = 0 < n && n < _CURVE_P
ge :: Integer -> Bool
ge n = 0 < n && n < _CURVE_N
+-- modular inverse
-- for a, m return x such that ax = 1 mod m
modinv :: Integer -> Natural -> Maybe Integer
modinv a m = case I.integerRecipMod# a m of
@@ -99,6 +106,7 @@ modinv a m = case I.integerRecipMod# a m of
(# | _ #) -> Nothing
-- modular square root (shanks-tonelli)
+-- for a, m return x such that a = xx mod m
modsqrt :: Integer -> Maybe Integer
modsqrt n = runST $ do
r <- newSTRef 1
@@ -304,6 +312,7 @@ double (Projective x y z) = runST $ do
modifySTRef' x3 (\rx3 -> modP (rx3 + rx3))
Projective <$> readSTRef x3 <*> readSTRef y3 <*> readSTRef z3
+-- ec scalar multiplication
mul :: Projective -> Integer -> Projective
mul p n
| n == 0 = _ZERO
@@ -322,7 +331,7 @@ mul p n
mul_safe :: Projective -> Integer -> Projective
mul_safe p n
| not (ge n) = error "ppad-secp256k1 (mul_safe): scalar not in group"
- | otherwise = loop _ZERO _BASE p n
+ | otherwise = loop _ZERO _CURVE_G p n
where
loop !r !f !d m
| m <= 0 = r
@@ -362,35 +371,32 @@ valid p = case affine p of
-- | Parse hex-encoded compressed or uncompressed point.
parse :: BS.ByteString -> Maybe Projective
parse (B16.decode -> ebs) = case ebs of
- Left _ -> Nothing
- Right bs -> case BS.uncons bs of
- Nothing -> Nothing
- Just (fromIntegral -> h, t) ->
- let (roll -> x, etc) = BS.splitAt _CURVE_N_BYTES t
- len = BS.length bs
- in -- compressed
- if len == 33 && (h == 0x02 || h == 0x03)
- then if not (fe x)
- then Nothing
- else do
- y <- modsqrt (weierstrass x)
- let yodd = I.integerTestBit y 0
- hodd = I.integerTestBit h 0
- pure $
- if hodd /= yodd
- then Projective x (modP (negate y)) 1
- else Projective x y 1
- else -- uncompressed
- if len == 65 && h == 0x04
- then let (roll -> y, _) = BS.splitAt _CURVE_N_BYTES etc
- p = Projective x y 1
- in if valid p
- then Just p
- else Nothing
- else Nothing
- where
- _CURVE_N_BYTES :: Int
- _CURVE_N_BYTES = 32
+ Left _ -> Nothing
+ Right bs -> case BS.uncons bs of
+ Nothing -> Nothing
+ Just (fromIntegral -> h, t) ->
+ let (roll -> x, etc) = BS.splitAt _CURVE_N_BYTES t
+ len = BS.length bs
+ in -- compressed
+ if len == 33 && (h == 0x02 || h == 0x03)
+ then if not (fe x)
+ then Nothing
+ else do
+ y <- modsqrt (weierstrass x)
+ let yodd = I.integerTestBit y 0
+ hodd = I.integerTestBit h 0
+ pure $
+ if hodd /= yodd
+ then Projective x (modP (negate y)) 1
+ else Projective x y 1
+ else -- uncompressed
+ if len == 65 && h == 0x04
+ then let (roll -> y, _) = BS.splitAt _CURVE_N_BYTES etc
+ p = Projective x y 1
+ in if valid p
+ then Just p
+ else Nothing
+ else Nothing
-- big-endian bytestring decoding
roll :: BS.ByteString -> Integer
@@ -406,26 +412,39 @@ unroll i = case i of
step 0 = Nothing
step m = Just (fromIntegral m, m `I.integerShiftR` 8)
--- XX not sure how much i need these things; do roll and unroll suffice?
-
-- RFC6979
bits2int :: BS.ByteString -> Integer
bits2int bs =
- let (fromIntegral -> del) = BS.length bs * 8 - 256
- num = roll bs
+ let (fromIntegral -> blen) = BS.length bs * 8
+ (fromIntegral -> qlen) = _CURVE_N_LEN
+ del = blen - qlen
in if del > 0
- then num `I.integerShiftR` del
- else num
+ then roll bs `I.integerShiftR` del
+ else roll bs
-- RFC6979
int2octets :: Integer -> BS.ByteString
-int2octets = unroll
+int2octets i = pad (unroll i) where
+ pad !bs
+ | BS.length bs < _CURVE_N_BYTES = pad (BS.cons 0 bs)
+ | otherwise = bs
-- RFC6979
bits2octets :: BS.ByteString -> BS.ByteString
bits2octets bs =
let z1 = bits2int bs
- z2 = modP z1 -- XX correct modulo?
+ z2 = let d = z1 - _CURVE_N
+ in if d < 0
+ then z1
+ else d
in int2octets z2
+-- XX test
+
+test_h1 :: BS.ByteString
+test_h1 = B16.decodeLenient
+ "AF2BDBE1AA9B6EC1E2ADE1D694F41FC71A831D0268E9891562113D8A62ADD1BF"
+
+test_x :: Integer
+test_x = 0x09A4D6792295A7F730FC3F2B49CBC0F62E862272F