package org.keycloak.crypto.def;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.math.BigInteger;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.security.spec.ECParameterSpec;
import java.security.spec.ECPoint;
import java.security.spec.ECPublicKeySpec;
import java.security.spec.InvalidKeySpecException;
import javax.crypto.KeyAgreement;
import org.bouncycastle.crypto.agreement.kdf.ConcatenationKDFGenerator;
import org.bouncycastle.crypto.engines.AESWrapEngine;
import org.bouncycastle.crypto.params.KDFParameters;
import org.bouncycastle.crypto.params.KeyParameter;
import org.bouncycastle.crypto.util.DigestFactory;
import org.bouncycastle.jce.ECNamedCurveTable;
import org.bouncycastle.jce.spec.ECNamedCurveParameterSpec;
import org.bouncycastle.jce.spec.ECNamedCurveSpec;
import org.keycloak.common.util.Base64Url;
import org.keycloak.jose.jwe.JWEHeader;
import org.keycloak.jose.jwe.JWEKeyStorage;
import org.keycloak.jose.jwe.alg.JWEAlgorithmProvider;
import org.keycloak.jose.jwe.enc.JWEEncryptionProvider;
import org.keycloak.jose.jwk.ECPublicJWK;
import org.keycloak.jose.jwk.JWKUtil;

/**
 * JWE key-management provider for the ECDH-ES family of algorithms (RFC 7518, section 4.6),
 * built on the JCA {@code ECDH} key agreement plus BouncyCastle's Concat KDF (SHA-256) and
 * AES Key Wrap.
 *
 * <p>Supported {@code alg} values:
 * <ul>
 *   <li>{@code ECDH-ES}: direct key agreement; the derived key is used as the CEK and the JWE
 *       Encrypted Key is empty.</li>
 *   <li>{@code ECDH-ES+A128KW}, {@code ECDH-ES+A192KW}, {@code ECDH-ES+A256KW}: the derived key
 *       wraps the CEK with AES Key Wrap.</li>
 * </ul>
 * Supported ephemeral-key curves: P-256, P-384 and P-521.
 */
public class BCEcdhEsAlgorithmProvider implements JWEAlgorithmProvider {

    private static final String ECDH_ES = "ECDH-ES";
    private static final String ECDH_ES_A128KW = "ECDH-ES+A128KW";
    private static final String ECDH_ES_A192KW = "ECDH-ES+A192KW";
    private static final String ECDH_ES_A256KW = "ECDH-ES+A256KW";

    /**
     * Recovers the content encryption key (CEK) on the recipient side.
     *
     * @param encodedCek the JWE Encrypted Key; must be empty for direct {@code ECDH-ES}
     * @param key the recipient's EC private key
     * @param header the JWE header supplying {@code alg}, {@code enc}, {@code epk} and the
     *     optional {@code apu}/{@code apv}
     * @param encryptionProvider the content encryption provider; supplies the CEK length for
     *     direct {@code ECDH-ES}
     * @return the CEK bytes
     * @throws IllegalArgumentException if the algorithm or curve is unsupported, the ephemeral
     *     public key is missing, malformed or not on its curve, or the encrypted key is
     *     inconsistent with the algorithm
     * @throws Exception if key agreement or AES key unwrapping fails
     */
    public byte[] decodeCek(byte[] encodedCek, Key key, JWEHeader header, JWEEncryptionProvider encryptionProvider) throws Exception {
        String alg = header.getAlgorithm();
        int keyDataLength = getKeyDataLength(alg, encryptionProvider);
        String algorithmId = getAlgorithmID(alg, header.getEncryptionAlgorithm());
        PublicKey ephemeralPublicKey = toPublicKey(header.getEphemeralPublicKey());

        byte[] derivedKey = deriveKey(ephemeralPublicKey, key, keyDataLength, algorithmId,
                base64UrlDecode(header.getAgreementPartyUInfo()),
                base64UrlDecode(header.getAgreementPartyVInfo()));

        if (ECDH_ES.equals(alg)) {
            // RFC 7516 section 5.2, step 10: with Direct Key Agreement the JWE Encrypted Key
            // must be an empty octet sequence.
            if (encodedCek != null && encodedCek.length != 0) {
                throw new IllegalArgumentException("JWE Encrypted Key must be empty for " + ECDH_ES);
            }
            return derivedKey;
        }
        if (encodedCek == null) {
            throw new IllegalArgumentException("JWE Encrypted Key must be set for " + alg);
        }
        AESWrapEngine unwrapper = new AESWrapEngine();
        unwrapper.init(false, new KeyParameter(derivedKey));
        return unwrapper.unwrap(encodedCek, 0, encodedCek.length);
    }

    /**
     * Establishes the CEK on the sender side.
     *
     * <p>Generates an ephemeral key pair on the recipient's curve and adds its public half to
     * {@code headerBuilder} as {@code epk}. It then derives the agreement key. For direct
     * {@code ECDH-ES}, the derived key becomes the CEK: it is stored in {@code keyStorage},
     * handed to {@code encryptionProvider.deserializeCEK}, and an empty array is returned. For
     * the key-wrap variants, the CEK already held in {@code keyStorage} is AES-key-wrapped
     * and returned.
     *
     * @param encryptionProvider the content encryption provider
     * @param keyStorage holder of the CEK bytes
     * @param key the recipient's EC public key
     * @param headerBuilder builder of the JWE header; receives the {@code epk} value
     * @return the JWE Encrypted Key (empty for direct {@code ECDH-ES})
     * @throws IllegalArgumentException if {@code key} is not an EC public key or the algorithm
     *     is unsupported
     * @throws Exception if key generation, agreement or wrapping fails
     */
    public byte[] encodeCek(JWEEncryptionProvider encryptionProvider, JWEKeyStorage keyStorage, Key key, JWEHeader.JWEHeaderBuilder headerBuilder) throws Exception {
        if (!(key instanceof ECPublicKey)) {
            throw new IllegalArgumentException("ECDH-ES requires an EC public key");
        }
        JWEHeader header = headerBuilder.build();
        String alg = header.getAlgorithm();
        int keyDataLength = getKeyDataLength(alg, encryptionProvider);
        String algorithmId = getAlgorithmID(alg, header.getEncryptionAlgorithm());

        // Fresh ephemeral key pair on the recipient's curve; its public half travels as "epk".
        KeyPair ephemeral = generateEcKeyPair(((ECPublicKey) key).getParams());
        headerBuilder.ephemeralPublicKey(toECPublicJWK((ECPublicKey) ephemeral.getPublic()));

        byte[] derivedKey = deriveKey(key, ephemeral.getPrivate(), keyDataLength, algorithmId,
                base64UrlDecode(header.getAgreementPartyUInfo()),
                base64UrlDecode(header.getAgreementPartyVInfo()));

        if (ECDH_ES.equals(alg)) {
            keyStorage.setCEKBytes(derivedKey);
            encryptionProvider.deserializeCEK(keyStorage);
            return new byte[0];
        }
        AESWrapEngine wrapper = new AESWrapEngine();
        wrapper.init(true, new KeyParameter(derivedKey));
        byte[] cek = keyStorage.getCekBytes();
        return wrapper.wrap(cek, 0, cek.length);
    }

    /** Decodes a base64url header value, treating {@code null} or "" as an empty byte array. */
    private static byte[] base64UrlDecode(String value) {
        if (value == null || value.isEmpty()) {
            return new byte[0];
        }
        return Base64Url.decode(value);
    }

    /**
     * Generates an EC key pair on the given curve.
     *
     * @throws IllegalArgumentException wrapping any JCA failure
     */
    private static KeyPair generateEcKeyPair(ECParameterSpec params) {
        try {
            KeyPairGenerator generator = KeyPairGenerator.getInstance("EC");
            // Use the platform-default CSPRNG; "SHA1PRNG" is not a required SecureRandom
            // algorithm and may be missing in some provider configurations (e.g. FIPS).
            generator.initialize(params, new SecureRandom());
            return generator.generateKeyPair();
        } catch (InvalidAlgorithmParameterException | NoSuchAlgorithmException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Builds the Concat KDF OtherInfo (RFC 7518, section 4.6.2):
     * AlgorithmID || PartyUInfo || PartyVInfo || SuppPubInfo || SuppPrivInfo.
     */
    private static byte[] deriveOtherInfo(int keyDataLength, String algorithmId, byte[] partyUInfo, byte[] partyVInfo) {
        return concat(
                encodeDataLengthData(algorithmId.getBytes(StandardCharsets.US_ASCII)),
                encodeDataLengthData(partyUInfo),
                encodeDataLengthData(partyVInfo),
                toByteArray(keyDataLength), // SuppPubInfo: keydatalen in bits, 32-bit big-endian
                emptyBytes());              // SuppPrivInfo: empty octet sequence
    }

    /**
     * Derives {@code keyDataLength} bits of key material from an ECDH shared secret using the
     * Concat KDF with SHA-256.
     *
     * @param publicKey the other party's EC public key
     * @param privateKey this party's EC private key
     * @param keyDataLength number of bits to derive (a multiple of 8)
     * @param algorithmId the AlgorithmID value ({@code alg}, or {@code enc} for direct ECDH-ES)
     * @param partyUInfo decoded {@code apu} (may be empty)
     * @param partyVInfo decoded {@code apv} (may be empty)
     * @return the derived key, {@code keyDataLength / 8} bytes long
     */
    public static byte[] deriveKey(Key publicKey, Key privateKey, int keyDataLength, String algorithmId, byte[] partyUInfo, byte[] partyVInfo) throws InvalidKeyException, NoSuchAlgorithmException, IllegalStateException {
        KDFParameters kdfParameters = new KDFParameters(
                deriveSharedSecret(publicKey, privateKey),
                deriveOtherInfo(keyDataLength, algorithmId, partyUInfo, partyVInfo));
        ConcatenationKDFGenerator kdf = new ConcatenationKDFGenerator(DigestFactory.createSHA256());
        kdf.init(kdfParameters);
        int keyLengthBytes = keyDataLength / 8;
        byte[] derived = new byte[keyLengthBytes];
        kdf.generateBytes(derived, 0, keyLengthBytes);
        return derived;
    }

    /** Converts an EC public key to a JWK whose {@code crv} is "P-" followed by the field size. */
    private static ECPublicJWK toECPublicJWK(ECPublicKey publicKey) {
        ECPublicJWK jwk = new ECPublicJWK();
        int fieldSize = publicKey.getParams().getCurve().getField().getFieldSize();
        jwk.setCrv("P-" + fieldSize);
        jwk.setKeyType("EC");
        jwk.setX(Base64Url.encode(JWKUtil.toIntegerBytes(publicKey.getW().getAffineX(), fieldSize)));
        jwk.setY(Base64Url.encode(JWKUtil.toIntegerBytes(publicKey.getW().getAffineY(), fieldSize)));
        return jwk;
    }

    /**
     * Converts an (untrusted) EC public JWK into a JCA public key, rejecting points that do not
     * lie on the named curve.
     *
     * @throws IllegalArgumentException if the JWK is missing, incomplete, on an unsupported
     *     curve, or its point is invalid
     */
    private static PublicKey toPublicKey(ECPublicJWK jwk) {
        if (jwk == null) {
            throw new IllegalArgumentException("JWE header epk must be set");
        }
        String crv = jwk.getCrv();
        String encodedX = jwk.getX();
        String encodedY = jwk.getY();
        if (crv == null) {
            throw new IllegalArgumentException("JWK crv must be set");
        }
        if (encodedX == null) {
            throw new IllegalArgumentException("JWK x must be set");
        }
        if (encodedY == null) {
            throw new IllegalArgumentException("JWK y must be set");
        }
        BigInteger x = new BigInteger(1, Base64Url.decode(encodedX));
        BigInteger y = new BigInteger(1, Base64Url.decode(encodedY));
        String curveName = nistToSecCurveName(crv);
        ECNamedCurveParameterSpec curveSpec = ECNamedCurveTable.getParameterSpec(curveName);

        // The epk comes from the message header, so it is attacker-controlled: reject
        // off-curve points to prevent invalid-curve attacks on the recipient's private key.
        try {
            curveSpec.getCurve().validatePoint(x, y);
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("JWK point is not on curve " + crv, e);
        }

        try {
            ECPublicKeySpec keySpec = new ECPublicKeySpec(new ECPoint(x, y),
                    new ECNamedCurveSpec(curveName, curveSpec.getCurve(), curveSpec.getG(), curveSpec.getN()));
            return KeyFactory.getInstance("EC").generatePublic(keySpec);
        } catch (NoSuchAlgorithmException | InvalidKeySpecException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /** Computes the raw ECDH shared secret Z between {@code privateKey} and {@code publicKey}. */
    private static byte[] deriveSharedSecret(Key publicKey, Key privateKey) throws NoSuchAlgorithmException, InvalidKeyException, IllegalStateException {
        KeyAgreement agreement = KeyAgreement.getInstance("ECDH");
        agreement.init(privateKey);
        agreement.doPhase(publicKey, true);
        return agreement.generateSecret();
    }

    /**
     * Returns the Concat KDF AlgorithmID: the {@code alg} value for the key-wrap variants, or
     * the {@code enc} value for direct {@code ECDH-ES}.
     */
    private static String getAlgorithmID(String alg, String enc) {
        if (ECDH_ES_A128KW.equals(alg) || ECDH_ES_A192KW.equals(alg) || ECDH_ES_A256KW.equals(alg)) {
            return alg;
        }
        if (ECDH_ES.equals(alg)) {
            if (enc == null) {
                throw new IllegalArgumentException("JWE header enc must be set for " + ECDH_ES);
            }
            return enc;
        }
        throw new IllegalArgumentException("Unsupported algorithm");
    }

    /** Maps a JWK/NIST curve name to the SEC name understood by {@link ECNamedCurveTable}. */
    private static String nistToSecCurveName(String crv) {
        switch (crv) {
            case "P-256":
                return "secp256r1";
            case "P-384":
                return "secp384r1";
            case "P-521":
                return "secp521r1";
            default:
                throw new IllegalArgumentException("Unsupported curve");
        }
    }

    /**
     * Returns the number of key bits to derive: the KEK size for the key-wrap variants, or the
     * CEK size expected by {@code encryptionProvider} for direct {@code ECDH-ES}.
     */
    private static int getKeyDataLength(String alg, JWEEncryptionProvider encryptionProvider) {
        if (ECDH_ES_A128KW.equals(alg)) {
            return 128;
        }
        if (ECDH_ES_A192KW.equals(alg)) {
            return 192;
        }
        if (ECDH_ES_A256KW.equals(alg)) {
            return 256;
        }
        if (ECDH_ES.equals(alg)) {
            return encryptionProvider.getExpectedCEKLength() * 8;
        }
        throw new IllegalArgumentException("Unsupported algorithm");
    }

    /** Prefixes {@code data} (null treated as empty) with its length as a 32-bit big-endian int. */
    private static byte[] encodeDataLengthData(byte[] data) {
        byte[] payload = data != null ? data : new byte[0];
        return concat(toByteArray(payload.length), payload);
    }

    private static byte[] emptyBytes() {
        return new byte[0];
    }

    /** Encodes {@code value} as 4 big-endian bytes. */
    private static byte[] toByteArray(int value) {
        return new byte[]{(byte) (value >> 24), (byte) (value >> 16), (byte) (value >> 8), (byte) value};
    }

    /** Concatenates the given arrays in order, skipping {@code null} entries. */
    private static byte[] concat(byte[]... parts) {
        int totalLength = 0;
        for (byte[] part : parts) {
            if (part != null) {
                totalLength += part.length;
            }
        }
        byte[] result = new byte[totalLength];
        int offset = 0;
        for (byte[] part : parts) {
            if (part != null) {
                System.arraycopy(part, 0, result, offset, part.length);
                offset += part.length;
            }
        }
        return result;
    }
}
