{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE PatternSynonyms #-}

module Crypto.HPKE.ID (
    AEAD_ID (AES_128_GCM, AES_256_GCM, ChaCha20Poly1305, ..),
    AEADCipher (..),
    defaultAEADMap,
    --
    KDF_ID (HKDF_SHA256, HKDF_SHA384, HKDF_SHA512, ..),
    KDFHash (..),
    defaultKDFMap,
    --
    KEM_ID (
        DHKEM_P256_HKDF_SHA256,
        DHKEM_P384_HKDF_SHA384,
        DHKEM_P521_HKDF_SHA512,
        DHKEM_X25519_HKDF_SHA256,
        DHKEM_X448_HKDF_SHA512,
        ..
    ),
    defaultKEMMap,
    KEMGroup (..),
    --
    HPKEMap (..),
    defaultHPKEMap,
) where

import Crypto.Cipher.AES (AES128, AES256)
import Crypto.Cipher.ChaChaPoly1305 (ChaCha20Poly1305)
import Crypto.ECC (
    Curve_P256R1,
    Curve_P384R1,
    Curve_P521R1,
    Curve_X25519,
    Curve_X448,
    EllipticCurve (..),
    EllipticCurveDH (..),
 )
import Data.Proxy (Proxy (..))
import Data.Word (Word16)
import Text.Printf (printf)

import Crypto.HPKE.AEAD
import Crypto.HPKE.KDF

----------------------------------------------------------------

-- | Identifier of an HPKE AEAD (authenticated encryption with
-- associated data) algorithm.
--
-- The known values are exported as pattern synonyms; any other 16-bit
-- value can still be wrapped directly with the 'AEAD_ID' constructor.
newtype AEAD_ID = AEAD_ID
    { fromAEAD_ID :: Word16
    -- ^ The raw 16-bit identifier value.
    }
    deriving (Eq)

{- FOURMOLU_DISABLE -}
-- | Registered AEAD identifiers, with numeric values 0x0001, 0x0002
-- and 0x0003 respectively.
pattern AES_128_GCM, AES_256_GCM, ChaCha20Poly1305 :: AEAD_ID
pattern AES_128_GCM      = AEAD_ID 0x0001
pattern AES_256_GCM      = AEAD_ID 0x0002
pattern ChaCha20Poly1305 = AEAD_ID 0x0003

-- | Known identifiers render as their pattern-synonym name. Any other
-- value renders as @AEAD_ID 0x@ followed by its value in (at least
-- four, zero-padded) lowercase hex digits.
instance Show AEAD_ID where
    show :: AEAD_ID -> String
    show aeadId = case aeadId of
        AES_128_GCM      -> "AES_128_GCM"
        AES_256_GCM      -> "AES_256_GCM"
        ChaCha20Poly1305 -> "ChaCha20Poly1305"
        AEAD_ID raw      -> "AEAD_ID 0x" ++ printf "%04x" raw
{- FOURMOLU_ENABLE -}

{- FOURMOLU_DISABLE -}
-- Type-level tags selecting the concrete cipher implementations. Each
-- signature fixes the tagged type, so a bare 'Proxy' is enough.
aes128 :: Proxy AES128
aes128 = Proxy

aes256 :: Proxy AES256
aes256 = Proxy

chacha :: Proxy ChaCha20Poly1305
chacha = Proxy
{- FOURMOLU_ENABLE -}

-- | An AEAD implementation whose concrete cipher type is hidden. The
-- 'Proxy' carries that type, so pattern matching on 'AEADCipher'
-- brings its 'Aead' instance back into scope.
data AEADCipher
    = forall cipher. Aead cipher => AEADCipher (Proxy cipher)

{- FOURMOLU_DISABLE -}
-- | Default association list from AEAD identifier to implementation,
-- covering 'AES_128_GCM', 'AES_256_GCM' and 'ChaCha20Poly1305' in that
-- order.
defaultAEADMap :: [(AEAD_ID, AEADCipher)]
defaultAEADMap = [aes128Entry, aes256Entry, chachaEntry]
  where
    aes128Entry = (AES_128_GCM,      AEADCipher aes128)
    aes256Entry = (AES_256_GCM,      AEADCipher aes256)
    chachaEntry = (ChaCha20Poly1305, AEADCipher chacha)
{- FOURMOLU_ENABLE -}

----------------------------------------------------------------

-- | Identifier of an HPKE key derivation function.
--
-- The known values are exported as pattern synonyms; any other 16-bit
-- value can still be wrapped directly with the 'KDF_ID' constructor.
newtype KDF_ID = KDF_ID
    { fromKDF_ID :: Word16
    -- ^ The raw 16-bit identifier value.
    }
    deriving (Eq)

{- FOURMOLU_DISABLE -}
-- | Registered KDF identifiers, with numeric values 0x0001, 0x0002
-- and 0x0003 respectively.
pattern HKDF_SHA256, HKDF_SHA384, HKDF_SHA512 :: KDF_ID
pattern HKDF_SHA256  = KDF_ID 0x0001
pattern HKDF_SHA384  = KDF_ID 0x0002
pattern HKDF_SHA512  = KDF_ID 0x0003

-- | Known identifiers render as their pattern-synonym name. Any other
-- value renders as @KDF_ID 0x@ followed by its value in (at least
-- four, zero-padded) lowercase hex digits.
instance Show KDF_ID where
    show :: KDF_ID -> String
    show HKDF_SHA256 = "HKDF_SHA256"
    show HKDF_SHA384 = "HKDF_SHA384"
    show HKDF_SHA512 = "HKDF_SHA512"
    -- Use the real constructor name (as the AEAD_ID instance does) so the
    -- output reads as a valid expression and does not imply that an
    -- unregistered KDF is an HKDF variant.
    show (KDF_ID n)  = "KDF_ID 0x" ++ printf "%04x" n
{- FOURMOLU_ENABLE -}

-- | A hash function usable for key derivation, with its concrete type
-- hidden behind the 'HashAlgorithm' and 'KDF' constraints.
data KDFHash
    = forall hash. (HashAlgorithm hash, KDF hash) => KDFHash hash

-- | Default association list from KDF identifier to hash function,
-- covering 'HKDF_SHA256', 'HKDF_SHA384' and 'HKDF_SHA512' in that order.
defaultKDFMap :: [(KDF_ID, KDFHash)]
defaultKDFMap = [sha256Entry, sha384Entry, sha512Entry]
  where
    sha256Entry = (HKDF_SHA256, KDFHash SHA256)
    sha384Entry = (HKDF_SHA384, KDFHash SHA384)
    sha512Entry = (HKDF_SHA512, KDFHash SHA512)

----------------------------------------------------------------
----------------------------------------------------------------

-- | Identifier of an HPKE key encapsulation mechanism.
--
-- The known values are exported as pattern synonyms; any other 16-bit
-- value can still be wrapped directly with the 'KEM_ID' constructor.
newtype KEM_ID = KEM_ID
    { fromKEM_ID :: Word16
    -- ^ The raw 16-bit identifier value.
    }
    deriving (Eq)

{- FOURMOLU_DISABLE -}
-- | Registered DHKEM identifiers: P-256 (0x0010), P-384 (0x0011),
-- P-521 (0x0012), X25519 (0x0020) and X448 (0x0021).
pattern DHKEM_P256_HKDF_SHA256
      , DHKEM_P384_HKDF_SHA384
      , DHKEM_P521_HKDF_SHA512
      , DHKEM_X25519_HKDF_SHA256
      , DHKEM_X448_HKDF_SHA512 :: KEM_ID
pattern DHKEM_P256_HKDF_SHA256    = KEM_ID 0x0010
pattern DHKEM_P384_HKDF_SHA384    = KEM_ID 0x0011
pattern DHKEM_P521_HKDF_SHA512    = KEM_ID 0x0012
pattern DHKEM_X25519_HKDF_SHA256  = KEM_ID 0x0020
pattern DHKEM_X448_HKDF_SHA512    = KEM_ID 0x0021

-- | Known identifiers render in the @DHKEM(curve, KDF)@ naming style.
-- Any other value renders as @KEM_ID 0x@ followed by its value in (at
-- least four, zero-padded) lowercase hex digits.
instance Show KEM_ID where
    show :: KEM_ID -> String
    show DHKEM_P256_HKDF_SHA256   = "DHKEM(P-256, HKDF-SHA256)"
    show DHKEM_P384_HKDF_SHA384   = "DHKEM(P-384, HKDF-SHA384)"
    show DHKEM_P521_HKDF_SHA512   = "DHKEM(P-521, HKDF-SHA512)"
    show DHKEM_X25519_HKDF_SHA256 = "DHKEM(X25519, HKDF-SHA256)"
    show DHKEM_X448_HKDF_SHA512   = "DHKEM(X448, HKDF-SHA512)"
    -- Use the real constructor name (as the AEAD_ID instance does): an
    -- unrecognised identifier is not necessarily a Diffie-Hellman KEM,
    -- so labelling it "DHKEM_ID" would be misleading.
    show (KEM_ID n)               = "KEM_ID 0x" ++ printf "%04x" n
{- FOURMOLU_ENABLE -}

----------------------------------------------------------------

{- FOURMOLU_DISABLE -}
-- Type-level tags selecting the elliptic curve for each KEM. Each
-- signature fixes the tagged type, so a bare 'Proxy' is enough.
p256 :: Proxy Curve_P256R1
p256 = Proxy

p384 :: Proxy Curve_P384R1
p384 = Proxy

p521 :: Proxy Curve_P521R1
p521 = Proxy

x25519 :: Proxy Curve_X25519
x25519 = Proxy

x448 :: Proxy Curve_X448
x448 = Proxy

-- | An elliptic curve with its concrete type hidden. The 'Proxy'
-- carries that type, so pattern matching on 'KEMGroup' brings its
-- 'EllipticCurve' and 'EllipticCurveDH' instances back into scope.
data KEMGroup
    = forall curve.
      (EllipticCurve curve, EllipticCurveDH curve) =>
      KEMGroup (Proxy curve)

-- | Default association list from KEM identifier to the curve it uses,
-- paired with the hash named in that identifier. Rows appear in the
-- order P-256, P-384, P-521, X25519, X448.
defaultKEMMap :: [(KEM_ID, (KEMGroup, KDFHash))]
defaultKEMMap =
    [ entry DHKEM_P256_HKDF_SHA256   p256   SHA256
    , entry DHKEM_P384_HKDF_SHA384   p384   SHA384
    , entry DHKEM_P521_HKDF_SHA512   p521   SHA512
    , entry DHKEM_X25519_HKDF_SHA256 x25519 SHA256
    , entry DHKEM_X448_HKDF_SHA512   x448   SHA512
    ]
  where
    -- Build one row, hiding the curve and hash types behind their
    -- existential wrappers.
    entry kemId curve hash = (kemId, (KEMGroup curve, KDFHash hash))
{- FOURMOLU_ENABLE -}

----------------------------------------------------------------

-- | A bundle of the three algorithm registries: one association list
-- each for KEMs, KDFs and AEAD ciphers.
data HPKEMap = HPKEMap
    { kemMap :: [(KEM_ID, (KEMGroup, KDFHash))]
    -- ^ KEM identifier to its curve and accompanying hash.
    , kdfMap :: [(KDF_ID, KDFHash)]
    -- ^ KDF identifier to its hash function.
    , cipherMap :: [(AEAD_ID, AEADCipher)]
    -- ^ AEAD identifier to its cipher implementation.
    }

-- | Registry bundle assembled from 'defaultKEMMap', 'defaultKDFMap'
-- and 'defaultAEADMap' (arguments given in field order: KEM, KDF, AEAD).
defaultHPKEMap :: HPKEMap
defaultHPKEMap = HPKEMap defaultKEMMap defaultKDFMap defaultAEADMap