{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeSynonymInstances #-}

module Crypto.HPKE.AEAD (
    Aead (..),
) where

import Crypto.Cipher.AES (AES128, AES256)
import qualified Crypto.Cipher.ChaChaPoly1305 as CCP
import Crypto.Cipher.Types (AEAD (..), AuthTag (..), BlockCipher)
import qualified Crypto.Cipher.Types as Cipher
import Data.ByteArray (ByteArray, ByteArrayAccess)
import qualified Data.ByteString as BS
import Data.Tuple (swap)

import Crypto.HPKE.Types

-- $setup
-- >>> :set -XOverloadedStrings
-- >>> import Data.ByteString

----------------------------------------------------------------

-- | An AEAD algorithm usable by HPKE, selected via a 'Proxy' of its
-- cipher type.
class Aead a where
    -- | Given a key, produce a seal (encrypt and append tag) function.
    sealA :: Proxy a -> Key -> Seal
    -- | Given a key, produce an open (verify tag and decrypt) function.
    openA :: Proxy a -> Key -> Open
    -- | Key, nonce and tag lengths in bytes (RFC 9180 @Nk@, @Nn@, @Nt@).
    nK, nN, nT :: Proxy a -> Int

-- | Turn a raw AEAD encryptor into a 'Seal' function.
--
-- The encryptor yields the ciphertext and the authentication tag
-- separately; the sealed output is the ciphertext with the tag bytes
-- appended.  The second argument (a proxy) is ignored.
mkSealA :: AeadEncrypt -> p -> Key -> Seal
mkSealA enc _ key nonce aad plain =
    appendTag <$> enc key nonce aad plain
  where
    appendTag (body, AuthTag tagBytes) = body <> convert tagBytes

-- | Turn a raw AEAD decryptor into an 'Open' function.
--
-- The input is expected to be the ciphertext with a @len@-byte tag
-- appended (the layout 'mkSealA' produces).  The body is decrypted, and
-- the tag computed during decryption is compared with the received one.
-- Only on a match is the plaintext returned.  Otherwise the result is
-- @OpenError "tag mismatch"@, which is also returned for inputs shorter
-- than the tag.  The third argument (a proxy) is ignored.
mkOpenA :: AeadDecrypt -> Int -> p -> Key -> Open
mkOpenA dec len _ key nonce aad cipher
    | BS.length cipher < len = Left $ OpenError "tag mismatch"
    | otherwise = do
        (plain, expected) <- dec key nonce aad body
        -- Compare as 'AuthTag' values: that 'Eq' instance is a
        -- constant-time comparison, whereas '==' on the raw bytes returns
        -- at the first differing byte and leaks timing information.
        if expected == AuthTag (convert received)
            then Right plain
            else Left $ OpenError "tag mismatch"
  where
    (body, received) = BS.splitAt (BS.length cipher - len) cipher

----------------------------------------------------------------

-- | Raw AEAD encryption primitive.
--
-- Arguments, in order: key, nonce, additional authenticated data, and
-- plaintext.  On success the result pairs the ciphertext with the
-- authentication tag.  The quantifier is written explicitly because a
-- @type@ synonym cannot leave type variables free on its right-hand side.
type AeadEncrypt =
    forall k n a t
     . (ByteArray k, ByteArrayAccess n, ByteArrayAccess a, ByteArray t)
    => k -> n -> a -> t -> Either HPKEError (t, AuthTag)

-- | Raw AEAD decryption primitive.
--
-- Arguments, in order: key, nonce, additional authenticated data, and
-- ciphertext (without its tag).  On success the result pairs the
-- decrypted text with the tag computed during decryption.  Checking that
-- tag against the received one is left to the caller.
type AeadDecrypt =
    forall k n a t
     . (ByteArray k, ByteArrayAccess n, ByteArrayAccess a, ByteArray t)
    => k -> n -> a -> t -> Either HPKEError (t, AuthTag)

----------------------------------------------------------------

-- | Initialise a GCM context for block cipher @c@ from a key and nonce.
--
-- Yields 'Nothing' if either the cipher key setup or the GCM
-- initialisation fails.  The underlying 'CryptoError' is discarded.
initAES
    :: (ByteArray k, ByteArrayAccess n, BlockCipher c)
    => k -> n -> Maybe (AEAD c)
initAES key nonce =
    case Cipher.cipherInit key >>= \blk -> Cipher.aeadInit Cipher.AEAD_GCM blk nonce of
        CryptoFailed _ -> Nothing
        CryptoPassed ctx -> Just ctx

----------------------------------------------------------------

-- | From RFC 9180 A.1
--
-- >>> let key = "\x45\x31\x68\x5d\x41\xd6\x5f\x03\xdc\x48\xf6\xb8\x30\x2c\x05\xb0" :: ByteString
-- >>> let nonce = "\x56\xd8\x90\xe5\xac\xca\xaf\x01\x1c\xff\x4b\x7d" :: ByteString
-- >>> let aad = "\x43\x6f\x75\x6e\x74\x2d\x30" :: ByteString
-- >>> let plain = "The quick brown fox jumps over the very lazy dog." :: ByteString
-- >>> let proxy = Proxy :: Proxy AES128
-- >>> sealA proxy key nonce aad plain >>= openA proxy key nonce aad
-- Right "The quick brown fox jumps over the very lazy dog."
instance Aead AES128 where
    -- AES-128-GCM: 16-byte key, 12-byte nonce, 16-byte tag.
    nK _ = 16
    nN _ = 12
    nT _ = aes128tagLength
    sealA = mkSealA encryptAes128gcm
    openA = mkOpenA decryptAes128gcm aes128tagLength

-- | AES-128-GCM encryption primitive (see 'AeadEncrypt').
--
-- Returns the ciphertext and the authentication tag separately.  Fails
-- with 'SealError' if the GCM context cannot be initialised from the key
-- and nonce.
encryptAes128gcm :: AeadEncrypt
encryptAes128gcm key nonce aad plain =
    maybe (Left (SealError "encryptAes128gcm")) seal (initAES key nonce)
  where
    seal ctx = Right (simpleEncrypt (ctx :: AEAD AES128) aad plain aes128tagLength)

-- | AES-128-GCM decryption primitive (see 'AeadDecrypt').
--
-- Returns the decrypted text together with the authentication tag
-- computed over it.  The tag is not verified here.  Fails with
-- 'OpenError' if the GCM context cannot be initialised from the key and
-- nonce.
decryptAes128gcm :: AeadDecrypt
decryptAes128gcm key nonce aad cipher = case initAES key nonce of
    -- Error label fixed: it previously read "decrypttAes128gcm".
    Nothing -> Left $ OpenError "decryptAes128gcm"
    Just st ->
        Right $ simpleDecrypt (st :: AEAD AES128) aad cipher aes128tagLength

-- | AES-128-GCM tag length in bytes (a 128-bit tag).
aes128tagLength :: Int
aes128tagLength = 128 `div` 8

----------------------------------------------------------------

-- | From RFC 9180 A.6
--
-- >>> let key = "\x75\x1e\x34\x6c\xe8\xf0\xdd\xb2\x30\x5c\x8a\x2a\x85\xc7\x0d\x5c\xf5\x59\xc5\x30\x93\x65\x6b\xe6\x36\xb9\x40\x6d\x4d\x7d\x1b\x70" :: ByteString
-- >>> let nonce = "\x55\xff\x7a\x7d\x73\x9c\x69\xf4\x4b\x25\x44\x7b" :: ByteString
-- >>> let aad = "\x43\x6f\x75\x6e\x74\x2d\x30" :: ByteString
-- >>> let plain = "The quick brown fox jumps over the very lazy dog." :: ByteString
-- >>> let proxy = Proxy :: Proxy AES256
-- >>> sealA proxy key nonce aad plain >>= openA proxy key nonce aad
-- Right "The quick brown fox jumps over the very lazy dog."
instance Aead AES256 where
    -- AES-256-GCM: 32-byte key, 12-byte nonce, 16-byte tag.
    nK _ = 32
    nN _ = 12
    nT _ = aes256tagLength
    sealA = mkSealA encryptAes256gcm
    openA = mkOpenA decryptAes256gcm aes256tagLength

-- | AES-256-GCM encryption primitive (see 'AeadEncrypt').
--
-- Returns the ciphertext and the authentication tag separately.  Fails
-- with 'SealError' if the GCM context cannot be initialised from the key
-- and nonce.
encryptAes256gcm :: AeadEncrypt
encryptAes256gcm key nonce aad plain =
    case initAES key nonce :: Maybe (AEAD AES256) of
        Just ctx -> Right $ simpleEncrypt ctx aad plain aes256tagLength
        Nothing -> Left $ SealError "encryptAes256gcm"

-- | AES-256-GCM decryption primitive (see 'AeadDecrypt').
--
-- Returns the decrypted text together with the authentication tag
-- computed over it.  The tag is not verified here.  Fails with
-- 'OpenError' if the GCM context cannot be initialised from the key and
-- nonce.
decryptAes256gcm :: AeadDecrypt
decryptAes256gcm key nonce aad cipher =
    maybe (Left (OpenError "decryptAes256gcm")) run (initAES key nonce)
  where
    run ctx = Right (simpleDecrypt (ctx :: AEAD AES256) aad cipher aes256tagLength)

-- | AES-256-GCM tag length in bytes (a 128-bit tag).
aes256tagLength :: Int
aes256tagLength = 128 `div` 8

----------------------------------------------------------------

-- | From RFC 9180 A.5
--
-- >>> let key = "\xa8\xf4\x54\x90\xa9\x2a\x3b\x04\xd1\xdb\xf6\xcf\x2c\x39\x39\xad\x8b\xfc\x9b\xfc\xb9\x7c\x04\xbf\xfe\x11\x67\x30\xc9\xdf\xe3\xfc" :: ByteString
-- >>> let nonce = "\x72\x6b\x43\x90\xed\x22\x09\x80\x9f\x58\xc6\x93" :: ByteString
-- >>> let aad = "\x43\x6f\x75\x6e\x74\x2d\x30" :: ByteString
-- >>> let plain = "The quick brown fox jumps over the very lazy dog." :: ByteString
-- >>> let proxy = Proxy :: Proxy CCP.ChaCha20Poly1305
-- >>> sealA proxy key nonce aad plain >>= openA proxy key nonce aad
-- Right "The quick brown fox jumps over the very lazy dog."
instance Aead CCP.ChaCha20Poly1305 where
    -- ChaCha20-Poly1305: 32-byte key, 12-byte nonce, 16-byte tag.
    nK _ = 32
    nN _ = 12
    nT _ = chacha20poly1305tagLength
    sealA = mkSealA encryptChacha20poly1305
    openA = mkOpenA decryptChacha20poly1305 chacha20poly1305tagLength

-- | ChaCha20-Poly1305 encryption primitive (see 'AeadEncrypt').
--
-- Returns the ciphertext and the authentication tag separately.  Fails
-- with 'SealError' if the cipher state cannot be initialised from the key
-- and nonce.
encryptChacha20poly1305 :: AeadEncrypt
encryptChacha20poly1305 key nonce aad plain =
    case CCP.aeadChacha20poly1305Init key nonce of
        CryptoFailed _ -> Left $ SealError "encryptChacha20poly1305"
        CryptoPassed ctx ->
            Right $ simpleEncrypt ctx aad plain chacha20poly1305tagLength

-- | ChaCha20-Poly1305 decryption primitive (see 'AeadDecrypt').
--
-- Returns the decrypted text together with the authentication tag
-- computed over it.  The tag is not verified here.  Fails with
-- 'OpenError' if the cipher state cannot be initialised from the key and
-- nonce.
decryptChacha20poly1305 :: AeadDecrypt
decryptChacha20poly1305 key nonce aad cipher =
    case CCP.aeadChacha20poly1305Init key nonce of
        CryptoPassed st ->
            Right $ simpleDecrypt st aad cipher chacha20poly1305tagLength
        -- A decrypt-side failure is an 'OpenError' (it was 'SealError'),
        -- matching decryptAes128gcm / decryptAes256gcm.
        CryptoFailed _ -> Left $ OpenError "decryptChacha20poly1305"

-- | ChaCha20-Poly1305 tag length in bytes (a 128-bit Poly1305 tag).
chacha20poly1305tagLength :: Int
chacha20poly1305tagLength = 128 `div` 8

----------------------------------------------------------------

-- | One-shot AEAD encryption with an explicit tag length.
--
-- This is 'Cipher.aeadSimpleEncrypt' with its result pair flipped to
-- (ciphertext, tag), the order used by 'AeadEncrypt'.
simpleEncrypt
    :: (ByteArrayAccess a, ByteArray t)
    => AEAD cipher -> a -> t -> Int -> (t, AuthTag)
simpleEncrypt ctx aad plain = swap . Cipher.aeadSimpleEncrypt ctx aad plain

-- | One-shot AEAD decryption with an explicit tag length.
--
-- Feeds @aad@ as the associated header, decrypts the ciphertext, and
-- finalises the context to obtain a tag of the requested length.  The tag
-- is only computed here, not compared.  The caller must check it against
-- the received tag before trusting the plaintext.
simpleDecrypt
    :: (ByteArrayAccess a, ByteArray t)
    => AEAD cipher -> a -> t -> Int -> (t, AuthTag)
simpleDecrypt ctx aad ct taglen =
    let withHeader = Cipher.aeadAppendHeader ctx aad
        (plain, finalCtx) = Cipher.aeadDecrypt withHeader ct
     in (plain, Cipher.aeadFinalize finalCtx taglen)