{-# LANGUAGE RecordWildCards #-}

module Crypto.HPKE.KeyPair where

import qualified Control.Exception as E
import Crypto.ECC (
    EllipticCurve (..),
 )
import Crypto.HPKE.ID
import Crypto.HPKE.KEM (genKeyPairP)
import Crypto.HPKE.Types

----------------------------------------------------------------

-- | Generate a fresh pair of public key and secret key for the KEM
-- identified by the given 'KEM_ID'.
--
-- The 'KEM_ID' is looked up in the 'kemMap' field of the supplied
-- 'HPKEMap'. The curve proxy carried by the matching 'KEMGroup' drives key
-- generation. The associated 'KDFHash' is not needed for key generation
-- and is ignored. The public key is serialized with 'encodePoint' and the
-- secret key with 'encodeScalar'.
--
-- Throws 'Unsupported' (an 'HPKEError', raised with 'E.throwIO') carrying
-- the shown 'KEM_ID' when that identifier is not present in the map.
genKeyPair
    :: HPKEMap -> KEM_ID -> IO (EncodedPublicKey, EncodedSecretKey)
genKeyPair HPKEMap{..} kem_id = case lookup kem_id kemMap of
    Nothing -> E.throwIO $ Unsupported $ show kem_id
    Just (KEMGroup proxy, _) -> do
        (pk, sk) <- genKeyPairP proxy
        -- The encoded results are plain ByteStrings wrapped in newtypes, so
        -- the curve type hidden inside 'KEMGroup' does not escape this branch.
        let pkm = EncodedPublicKey $ encodePoint proxy pk
            skm = EncodedSecretKey $ encodeScalar proxy sk
        pure (pkm, skm)