secp256k1

Pure Haskell cryptographic primitives on the secp256k1 elliptic curve.
git clone git://git.ppad.tech/secp256k1.git
Log | Files | Refs | LICENSE

commit a7369c62d5a8ce7dbd3d39a67154ff4cf5655a1a
parent 359b98879bed81b4a4a8f62639ef9b23a6100ddb
Author: Jared Tobin <jared@jtobin.io>
Date:   Tue, 19 Mar 2024 13:17:08 +0400

lib: fix bug in point addition

Diffstat:
Mlib/Crypto/Secp256k1.hs | 89+++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------
1 file changed, 64 insertions(+), 25 deletions(-)

diff --git a/lib/Crypto/Secp256k1.hs b/lib/Crypto/Secp256k1.hs @@ -7,13 +7,26 @@ import Control.Monad.ST import Data.Bits (Bits, (.&.), (.|.)) import qualified Data.Bits as B (shiftL, shiftR) import qualified Data.ByteString as BS +import qualified Data.ByteString.Base16 as B16 import Data.STRef --- XX seems like this should be easy to abstract +-- XX could stand some reorganization; lots of stuff 'baked in', e.g. in mods +-- +-- XX seems Point should have a reference to the curve it's on; probably a lazy +-- one. otherwise hard to implement Eq. +-- +-- XX then again, we're exclusively concerned with a single curve here, so +-- who cares. bake everything in. +-- +-- only counterargument is that we may want to reuse the same skeleton for +-- other libraries after the fact. so we have a single library implementing +-- modular arithmetic, curves, etc., and then implement other stuff on top +-- of that -- modular arithmetic utilities + -- XX be aware of non-constant-timeness in these; i should understand this --- exactly +-- issue *precisely* -- modular division moddiv :: Integral a => a -> a -> a @@ -115,6 +128,9 @@ _GROUP_BYTELENGTH = 32 -- curve point +data Affine a = Affine !a !a + deriving Show + data Point a = Point { px :: !a , py :: !a @@ -162,7 +178,7 @@ add (Point x1 y1 z1) (Point x2 y2 z2) = runST $ do writeSTRef t4 (mods (x1 + z1)) t5 <- newSTRef (mods (x2 + z2)) readSTRef t5 >>= \r5 -> - modifySTRef' t4 (\r4 -> mods (r4 + r5)) + modifySTRef' t4 (\r4 -> mods (r4 * r5)) readSTRef t0 >>= \r0 -> readSTRef t2 >>= \r2 -> writeSTRef t5 (mods (r0 + r2)) @@ -228,32 +244,55 @@ add (Point x1 y1 z1) (Point x2 y2 z2) = runST $ do double :: (Integral a, Num a) => Point a -> Point a double p = add p p --- XX assumes we're dealing with hex; need to decode +-- to affine coordinates +affine :: Integral a => Point a -> Maybe (Affine a) +affine p@(Point x y z) + | p == _ZERO = pure (Affine 0 0) + | z == 1 = pure (Affine x y) + | otherwise = do + iz <- modinv z (curve_p secp256k1) + if mods (z * iz) /= 1 + then Nothing + else pure (Affine (mods (x * iz)) (mods (y * iz))) + +-- point is valid +valid :: Integral a => Point a -> Bool +valid p = case affine p of + Nothing -> False + Just (Affine x y) + | not (fe x) || not (fe y) -> False + | mods (y * y) /= weierstrass x -> False + | otherwise -> True + parse_point :: (Bits a, Integral a) => BS.ByteString -> Maybe (Point a) -parse_point bs = case BS.uncons bs of - Nothing -> Nothing - Just (h, t) -> - let (roll -> x, etc) = BS.splitAt _GROUP_BYTELENGTH t - len = BS.length bs - in if len == 33 && (h == 0x02 || h == 0x03) -- compressed - then if not (fe x) -- x must be converted to num - then Nothing - else do - y <- modsqrt (weierstrass x) - let yodd = y .&. 1 == 1 - hodd = h .&. 1 == 1 - pure $ - if hodd /= yodd - then Point x (mods (negate y)) 1 - else Point x y 1 - else if len == 65 && h == 0x04 -- uncompressed - then let (roll -> y, _) = BS.splitAt _GROUP_BYTELENGTH etc - in Just (Point x y 1) -- XX check validity - else Nothing +parse_point (B16.decode -> ebs) = case ebs of + Left _ -> Nothing + Right bs -> case BS.uncons bs of + Nothing -> Nothing + Just (h, t) -> + let (roll -> x, etc) = BS.splitAt _GROUP_BYTELENGTH t + len = BS.length bs + in if len == 33 && (h == 0x02 || h == 0x03) -- compressed + then if not (fe x) + then Nothing + else do + y <- modsqrt (weierstrass x) + let yodd = y .&. 1 == 1 + hodd = h .&. 1 == 1 + pure $ + if hodd /= yodd + then Point x (mods (negate y)) 1 + else Point x y 1 + else if len == 65 && h == 0x04 -- uncompressed + then let (roll -> y, _) = BS.splitAt _GROUP_BYTELENGTH etc + p = Point x y 1 + in if valid p + then Just p + else Nothing + else Nothing -- big-endian bytestring decoding roll :: (Bits a, Integral a) => BS.ByteString -> a roll = BS.foldl' unstep 0 where unstep a b = a `B.shiftL` 8 .|. fromIntegral b -