module Network.TLS.Crypto.ECDH
(
ECDHParams(..)
, ECDHPublic
, ECDHPrivate(..)
, ecdhPublic
, ecdhPrivate
, ecdhParams
, ecdhGenerateKeyPair
, ecdhGetShared
, ecdhUnwrap
, ecdhUnwrapPublic
) where
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Data.Word (Word16)

import qualified Crypto.PubKey.ECC.DH as ECDH
import qualified Crypto.PubKey.ECC.Prim as ECC (isPointValid)
import qualified Crypto.PubKey.ECC.Types as ECDH

import Network.TLS.Extension.EC
import Network.TLS.RNG
import Network.TLS.Util.Serialization (i2osp, lengthBytes)
-- | An ECDH public key: a curve point paired with an 'Int' size.
--
-- For keys produced by 'ecdhGenerateKeyPair' the size is the byte length of
-- the curve's prime modulus (see 'pointSize').
data ECDHPublic
    = ECDHPublic ECDH.PublicPoint Int
    deriving (Show, Eq)
-- | An ECDH private key, wrapping the underlying private number.
newtype ECDHPrivate
    = ECDHPrivate ECDH.PrivateNumber
    deriving (Show, Eq)
-- | ECDH domain parameters: the curve definition together with its name.
data ECDHParams
    = ECDHParams ECDH.Curve ECDH.CurveName
    deriving (Show, Eq)
-- | Serialized shared secret, as returned by 'ecdhGetShared'.
type ECDHKey = ByteString
-- | Assemble a public key from affine coordinates @x@ and @y@ and a size.
--
-- The coordinates are wrapped in an 'ECDH.Point' as given; no validation is
-- performed here.
ecdhPublic :: Integer -> Integer -> Int -> ECDHPublic
ecdhPublic x y = ECDHPublic (ECDH.Point x y)
-- | Wrap a raw private scalar as an 'ECDHPrivate'.
ecdhPrivate :: Integer -> ECDHPrivate
ecdhPrivate secret = ECDHPrivate secret
-- | Build the ECDH parameters for a numeric curve identifier.
--
-- The identifier is resolved with 'toCurveName' and the curve definition is
-- then looked up with 'ECDH.getCurveByName'.
--
-- This function is partial: when 'toCurveName' does not recognise the
-- identifier, forcing either field of the result raises an 'error' that
-- names this function and the offending identifier. Callers are expected to
-- pass only identifiers they know to be supported.
ecdhParams :: Word16 -> ECDHParams
ecdhParams w16 = ECDHParams curve name
  where
    -- Explicit case instead of an irrefutable @Just name = ...@ binding, so
    -- that a failed lookup reports what went wrong.
    name = case toCurveName w16 of
        Just n  -> n
        Nothing -> error ("ecdhParams: unsupported curve identifier " ++ show w16)
    curve = ECDH.getCurveByName name
-- | Generate a fresh private key on the given curve and derive the matching
-- public key.
--
-- The public key carries the size computed by 'pointSize' for the curve.
ecdhGenerateKeyPair :: MonadRandom r => ECDHParams -> r (ECDHPrivate, ECDHPublic)
ecdhGenerateKeyPair (ECDHParams curve _) =
    ECDH.generatePrivate curve >>= \secret ->
        return (ECDHPrivate secret, derivePublic secret)
  where
    -- Public point for a given private number, tagged with the curve size.
    derivePublic secret =
        ECDHPublic (ECDH.calculatePublic curve secret) (pointSize curve)
-- | Compute the ECDH shared secret from our private key and the peer's
-- public key.
--
-- Returns 'Nothing' when the peer's point is rejected by 'ECC.isPointValid'
-- for the curve; otherwise returns the serialized shared secret.
--
-- On prime-field curves the secret is left-padded with zero bytes up to the
-- byte length of the field prime. 'i2osp' alone produces a minimal-length
-- encoding, which drops leading zero bytes; RFC 4492 section 5.10 requires
-- the octet string to have constant length for a given field, with leading
-- zeros preserved. For other curve kinds the encoding is left exactly as
-- 'i2osp' produces it.
ecdhGetShared :: ECDHParams -> ECDHPrivate -> ECDHPublic -> Maybe ECDHKey
ecdhGetShared (ECDHParams curve _) (ECDHPrivate priv) (ECDHPublic point _)
    | ECC.isPointValid curve point = Just (encodeShared sk)
    | otherwise                    = Nothing
  where
    -- Lazy binding: only evaluated once the point has been validated.
    ECDH.SharedKey sk = ECDH.getShared curve priv point

    -- Fixed-length encoding on prime curves, unchanged encoding otherwise.
    encodeShared = case curve of
        ECDH.CurveFP prime -> leftPad (lengthBytes (ECDH.ecc_p prime)) . i2osp
        _                  -> i2osp

    -- Prepend zero bytes until the string is @n@ bytes long; strings that
    -- are already long enough are returned untouched ('B.replicate' yields
    -- the empty string for a non-positive count).
    leftPad n bs = B.replicate (n - B.length bs) 0 `B.append` bs
-- | Break parameters and a public key down into primitive values: the
-- numeric curve identifier, the point's @x@ and @y@ coordinates, and the
-- size stored in the public key.
--
-- Partial: the identifier component raises an 'error' when forced if
-- 'fromCurveName' has no identifier for the curve name, and the remaining
-- components inherit the partiality of 'ecdhUnwrapPublic'.
ecdhUnwrap :: ECDHParams -> ECDHPublic -> (Word16,Integer,Integer,Int)
ecdhUnwrap (ECDHParams _ curveName) pub = (curveId, px, py, len)
  where
    (px, py, len) = ecdhUnwrapPublic pub
    curveId = maybe (error "ecdhUnwrap") id (fromCurveName curveName)
-- | Extract the @x@ and @y@ coordinates and the stored size from a public
-- key.
--
-- Partial: raises an 'error' when the wrapped point is not an affine
-- 'ECDH.Point'.
ecdhUnwrapPublic :: ECDHPublic -> (Integer,Integer,Int)
ecdhUnwrapPublic (ECDHPublic pt len) =
    case pt of
        ECDH.Point px py -> (px, py, len)
        _                -> error "ecdhUnwrapPublic"
-- | Byte length of the prime modulus of a prime-field curve, as computed by
-- 'lengthBytes'.
--
-- Partial: raises an 'error' for any curve that is not an 'ECDH.CurveFP'.
pointSize :: ECDH.Curve -> Int
pointSize curve =
    case curve of
        ECDH.CurveFP prime -> lengthBytes (ECDH.ecc_p prime)
        _                  -> error "pointSize"