package org.apereo.cas.support.saml;

import module java.base;
import module java.xml;
import org.apereo.cas.support.saml.util.credential.BasicResourceCredentialFactoryBean;
import org.apereo.cas.support.saml.util.credential.BasicX509CredentialFactoryBean;
import org.apereo.cas.util.CollectionUtils;
import org.apereo.cas.util.EncodingUtils;
import org.apereo.cas.util.LoggingUtils;
import org.apereo.cas.util.ResourceUtils;
import org.apereo.cas.util.function.FunctionUtils;
import lombok.experimental.UtilityClass;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import net.shibboleth.shared.codec.Base64Support;
import net.shibboleth.shared.resolver.CriteriaSet;
import org.apache.commons.lang3.StringUtils;
import org.cryptacular.util.CertUtil;
import org.jooq.lambda.Unchecked;
import org.jspecify.annotations.Nullable;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.core.xml.util.XMLObjectSupport;
import org.opensaml.saml.common.SAMLObject;
import org.opensaml.saml.common.SAMLObjectBuilder;
import org.opensaml.saml.metadata.resolver.filter.impl.SignatureValidationFilter;
import org.opensaml.saml.saml2.core.RequestAbstractType;
import org.opensaml.security.credential.BasicCredential;
import org.opensaml.security.credential.impl.StaticCredentialResolver;
import org.opensaml.soap.common.SOAPObject;
import org.opensaml.soap.common.SOAPObjectBuilder;
import org.opensaml.xmlsec.SecurityConfigurationSupport;
import org.opensaml.xmlsec.SignatureValidationConfiguration;
import org.opensaml.xmlsec.criterion.SignatureValidationConfigurationCriterion;
import org.opensaml.xmlsec.impl.BasicSignatureValidationParametersResolver;
import org.opensaml.xmlsec.keyinfo.impl.BasicProviderKeyInfoCredentialResolver;
import org.opensaml.xmlsec.keyinfo.impl.KeyInfoProvider;
import org.opensaml.xmlsec.keyinfo.impl.provider.DEREncodedKeyValueProvider;
import org.opensaml.xmlsec.keyinfo.impl.provider.DSAKeyValueProvider;
import org.opensaml.xmlsec.keyinfo.impl.provider.InlineX509DataProvider;
import org.opensaml.xmlsec.keyinfo.impl.provider.RSAKeyValueProvider;
import org.opensaml.xmlsec.signature.support.SignatureValidationParametersCriterion;
import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import java.security.cert.X509Certificate;

/**
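 * Static helpers for {@link SamlUtils}: building, parsing, marshalling and logging
 * OpenSAML objects, and constructing metadata signature validation filters.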
 *
 * @author Misagh Moayyed
 * @since 5.0.0
 */
@Slf4j
@UtilityClass
public class SamlUtils {
    /**
     * Per-thread {@link TransformerFactory}, created the first time each thread asks for it
     * with {@link XMLConstants#FEATURE_SECURE_PROCESSING} switched on.
     * Used to obtain the transformer that serializes marshalled SAML objects to text.
     */
    private static final ThreadLocal<TransformerFactory> TRANSFORMER_FACTORY_INSTANCE = ThreadLocal.withInitial(
        Unchecked.supplier(() -> {
            val tf = TransformerFactory.newInstance();
            tf.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);
            return tf;
        }));

    /**
     * Name of the public static field that {@link #getSamlObjectQName(Class)} reads
     * reflectively to obtain the {@link QName} of an object type.
     */
    private static final String DEFAULT_ELEMENT_NAME_FIELD = "DEFAULT_ELEMENT_NAME";

    /**
     * Number of asterisks in the separator line written before and after
     * a SAML object when it is logged at debug level.
     */
    private static final int SAML_OBJECT_LOG_ASTERIXLINE_LENGTH = 80;

    /**
     * Gets the QName of a SAML/SOAP object type by reflectively reading the public static
     * field named by {@link #DEFAULT_ELEMENT_NAME_FIELD} on the given class.
     *
     * @param objectType the object type that declares (or inherits) the field
     * @return the QName stored in that field
     * @throws IllegalStateException if the field is missing, inaccessible, not static,
     *                               or does not hold a {@link QName}
     */
    public QName getSamlObjectQName(final Class<?> objectType) {
        try {
            // The field is static, so no receiver instance is needed for the read.
            return (QName) objectType.getField(DEFAULT_ELEMENT_NAME_FIELD).get(null);
        } catch (final Exception e) {
            throw new IllegalStateException("Cannot find/access field " + objectType.getName() + '.' + DEFAULT_ELEMENT_NAME_FIELD, e);
        }
    }

    /**
     * Creates a new SOAP object of the requested type, using the builder registered
     * for the type's default element QName.
     *
     * @param <T>        the SOAP object type
     * @param objectType the class of the SOAP object to build
     * @return the newly built object, cast to {@code objectType}
     */
    public <T extends SOAPObject> T newSoapObject(final Class<T> objectType) {
        val elementName = getSamlObjectQName(objectType);
        LOGGER.trace("Attempting to create SOAPObject for type: [{}] and QName: [{}]", objectType, elementName);
        val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory();
        val soapBuilder = (SOAPObjectBuilder<T>) builderFactory.getBuilder(elementName);
        // Fail fast when no builder is registered for this element name.
        Objects.requireNonNull(soapBuilder);
        val soapObject = soapBuilder.buildObject(elementName);
        return objectType.cast(soapObject);
    }

    /**
     * Creates a new SAML object of the requested type, using the type's
     * default element QName.
     *
     * @param <T>        the SAML object type
     * @param objectType the class of the SAML object to build
     * @return the newly built object
     */
    public static <T extends SAMLObject> T newSamlObject(final Class<T> objectType) {
        return newSamlObject(objectType, getSamlObjectQName(objectType));
    }

    /**
     * Creates a new SAML object of the requested type for an explicitly supplied QName.
     *
     * @param <T>        the SAML object type
     * @param objectType the class of the SAML object to build
     * @param qName      the element QName used both to look up the builder and to build the object
     * @return the newly built object, cast to {@code objectType}
     */
    public static <T extends SAMLObject> T newSamlObject(final Class<T> objectType, final QName qName) {
        LOGGER.trace("Attempting to create SAMLObject for type: [{}] and QName: [{}]", objectType, qName);
        val builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory();
        val samlBuilder = (SAMLObjectBuilder<T>) builderFactory.getBuilder(qName);
        // Fail fast when no builder is registered for this element name.
        Objects.requireNonNull(samlBuilder);
        val samlObject = samlBuilder.buildObject(qName);
        return objectType.cast(samlObject);
    }

    /**
     * Parses the given stream with the configured parser pool and returns the
     * document's root element. The stream is closed once parsing finishes.
     *
     * @param metadataResource the stream to parse
     * @param configBean       the config bean supplying the parser pool
     * @return the root element of the parsed document
     */
    public static Element getRootElementFrom(final InputStream metadataResource, final OpenSamlConfigBean configBean) {
        return FunctionUtils.doUnchecked(() -> {
            try (metadataResource) {
                return configBean.getParserPool().parse(metadataResource).getDocumentElement();
            }
        });
    }

    /**
     * Reads an X.509 certificate from the given resource.
     *
     * @param resource the resource holding the certificate
     * @return the certificate that was read
     * @throws IllegalArgumentException if the resource cannot be opened or the certificate cannot be read
     */
    public static X509Certificate readCertificate(final Resource resource) {
        try (val certificateStream = resource.getInputStream()) {
            val certificate = CertUtil.readCertificate(certificateStream);
            return certificate;
        } catch (final Exception e) {
            throw new IllegalArgumentException("Error reading certificate " + resource, e);
        }
    }

    /**
     * Serializes the SAML object to XML text without indentation.
     *
     * @param configBean the config bean
     * @param samlObject the saml object to serialize
     * @return the writer holding the serialized object
     * @throws SamlException if the object cannot be marshalled or transformed
     */
    public static StringWriter transformSamlObject(final OpenSamlConfigBean configBean, final XMLObject samlObject) throws SamlException {
        val indentOutput = false;
        return transformSamlObject(configBean, samlObject, indentOutput);
    }

    /**
     * Parses the XML text, encoded as UTF-8, into a SAML object of the requested type.
     *
     * @param <T>        the expected object type
     * @param configBean the config bean
     * @param xml        the xml text to parse
     * @param clazz      the expected class of the result
     * @return the unmarshalled object, or {@code null} when nothing could be produced
     */
    public static <T extends XMLObject> @Nullable T transformSamlObject(final OpenSamlConfigBean configBean, final String xml,
                                                              final Class<T> clazz) {
        val payload = xml.getBytes(StandardCharsets.UTF_8);
        return transformSamlObject(configBean, payload, clazz);
    }

    /**
     * Parses the given bytes as XML and unmarshalls the root element into an
     * object of the requested type.
     *
     * @param <T>        the expected object type
     * @param configBean the config bean supplying the parser pool and unmarshaller factory
     * @param data       the XML document bytes; may be {@code null} or empty
     * @param clazz      the expected class of the result
     * @return the unmarshalled object, or {@code null} when {@code data} is {@code null}/empty
     *         or no unmarshaller is registered for the root element
     * @throws SamlException if parsing or unmarshalling fails, or if the unmarshalled
     *                       object is not an instance of {@code clazz}
     */
    public static <T extends XMLObject> @Nullable T transformSamlObject(final OpenSamlConfigBean configBean,
                                                              final byte[] data,
                                                              final Class<T> clazz) {
        if (data == null || data.length == 0) {
            return null;
        }
        try (val in = new ByteArrayInputStream(data)) {
            val root = configBean.getParserPool().parse(in).getDocumentElement();
            val unmarshaller = configBean.getUnmarshallerFactory().getUnmarshaller(root);
            if (unmarshaller == null) {
                return null;
            }
            val result = unmarshaller.unmarshall(root);
            if (!clazz.isInstance(result)) {
                // Thrown inside the try on purpose so callers see it wrapped as a SamlException.
                throw new ClassCastException("Result [" + result + "] is of type "
                    + result.getClass() + " when we were expecting " + clazz);
            }
            return clazz.cast(result);
        } catch (final Exception e) {
            throw new SamlException(e.getMessage(), e);
        }
    }

    /**
     * Marshalls the SAML object and serializes the resulting DOM element to XML text.
     * When no marshaller is registered for the object's element QName, the returned
     * writer is empty.
     *
     * @param configBean the config bean supplying the marshaller factory
     * @param samlObject the saml object to serialize
     * @param indent     whether the output should be indented (four spaces per level)
     * @return the writer holding the serialized object
     * @throws SamlException if marshalling or transformation fails
     */
    public static StringWriter transformSamlObject(final OpenSamlConfigBean configBean, final XMLObject samlObject,
                                                   final boolean indent) throws SamlException {
        val output = new StringWriter();
        try {
            val objectMarshaller = configBean.getMarshallerFactory().getMarshaller(samlObject.getElementQName());
            if (objectMarshaller == null) {
                return output;
            }
            val source = new DOMSource(objectMarshaller.marshall(samlObject));
            val target = new StreamResult(output);
            val serializer = TRANSFORMER_FACTORY_INSTANCE.get().newTransformer();
            if (indent) {
                serializer.setOutputProperty(OutputKeys.INDENT, "yes");
                serializer.setOutputProperty("{http://xml.apache.org/xslt}indent-amount", "4");
            }
            serializer.transform(source, target);
            return output;
        } catch (final Exception e) {
            throw new SamlException(e.getMessage(), e);
        }
    }

    /**
     * Builds a signature validation filter from a resource location given as text.
     *
     * @param signatureResourceLocation the location of the signing certificate or public key
     * @return the filter, or {@code null} when the resource does not exist
     * @throws Exception if the resource or its credential cannot be resolved
     */
    public static @Nullable SignatureValidationFilter buildSignatureValidationFilter(final String signatureResourceLocation) throws Exception {
        return buildSignatureValidationFilter(ResourceUtils.getResourceFrom(signatureResourceLocation));
    }

    /**
     * Builds a signature validation filter from a location resolved through the
     * given resource loader. Failures are logged at debug level rather than propagated.
     *
     * @param resourceLoader            the resource loader used to resolve the location
     * @param signatureResourceLocation the location of the signing certificate or public key
     * @return the filter, or {@code null} when it could not be built
     */
    public static @Nullable SignatureValidationFilter buildSignatureValidationFilter(final ResourceLoader resourceLoader,
                                                                                     final String signatureResourceLocation) {
        try {
            return buildSignatureValidationFilter(resourceLoader.getResource(signatureResourceLocation));
        } catch (final Exception e) {
            LOGGER.debug(e.getMessage(), e);
            return null;
        }
    }

    /**
     * Builds a metadata signature validation filter that trusts the credential
     * found in the given resource.
     *
     * @param signatureResourceLocation the resource holding the signing certificate or public key
     * @return the filter, or {@code null} when the resource does not exist
     * @throws Exception if the credential cannot be built from the resource
     * @throws NullPointerException if no credential is produced from the resource
     */
    public static @Nullable SignatureValidationFilter buildSignatureValidationFilter(final Resource signatureResourceLocation) throws Exception {
        if (!ResourceUtils.doesResourceExist(signatureResourceLocation)) {
            LOGGER.warn("Resource [{}] cannot be located", signatureResourceLocation);
            return null;
        }

        val keyInfoProviderList = new ArrayList<KeyInfoProvider>(4);
        keyInfoProviderList.add(new RSAKeyValueProvider());
        keyInfoProviderList.add(new DSAKeyValueProvider());
        keyInfoProviderList.add(new DEREncodedKeyValueProvider());
        keyInfoProviderList.add(new InlineX509DataProvider());

        LOGGER.debug("Attempting to resolve credentials from [{}]", signatureResourceLocation);
        val credential = buildCredentialForMetadataSignatureValidation(signatureResourceLocation);
        // Verify the credential before reporting success, so the log never claims a resolution that failed.
        Objects.requireNonNull(credential, "No credential found");
        LOGGER.info("Successfully resolved credentials from [{}]", signatureResourceLocation);

        LOGGER.debug("Configuring credential resolver for key signature trust engine @ [{}]",
            credential.getCredentialType().getSimpleName());
        val resolver = new StaticCredentialResolver(credential);
        val keyInfoResolver = new BasicProviderKeyInfoCredentialResolver(keyInfoProviderList);
        val trustEngine = new ExplicitKeySignatureTrustEngine(resolver, keyInfoResolver);

        LOGGER.debug("Adding signature validation filter based on the configured trust engine");
        val signatureValidationFilter = new SignatureValidationFilter(trustEngine);
        signatureValidationFilter.setDefaultCriteria(buildSignatureValidationFilterCriteria());
        LOGGER.debug("Added metadata SignatureValidationFilter with signature from [{}]", signatureResourceLocation);
        return signatureValidationFilter;
    }

    /**
     * Writes an indented rendering of the SAML object to the debug log, framed by
     * separator lines, and forwards it to the protocol message logger. Does nothing
     * when neither debug logging nor the protocol message logger is enabled.
     *
     * @param configBean the config bean
     * @param samlObject the saml object to log
     * @throws SamlException if the object cannot be serialized
     */
    public static void logSamlObject(final OpenSamlConfigBean configBean, final XMLObject samlObject) throws SamlException {
        if (!LOGGER.isDebugEnabled() && !LoggingUtils.isProtocolMessageLoggerEnabled()) {
            return;
        }
        val separator = "*".repeat(SAML_OBJECT_LOG_ASTERIXLINE_LENGTH);
        LOGGER.debug(separator);
        try (val rendered = transformSamlObject(configBean, samlObject, true)) {
            val typeName = samlObject.getClass().getName();
            LOGGER.debug("Logging [{}]\n\n[{}]\n\n", typeName, rendered);
            LOGGER.debug(separator);
            LoggingUtils.protocolMessage("SAML " + typeName, Map.of(), rendered.toString());
        } catch (final Exception e) {
            throw new SamlException(e.getMessage(), e);
        }
    }

    /**
     * Tells whether the metadata location denotes a dynamic metadata query, i.e. whether
     * it is non-blank and, once trimmed, ends with {@code /entities/{0}}.
     *
     * @param metadataLocation the location of the metadata to resolve
     * @return true/false
     */
    public static boolean isDynamicMetadataQueryConfigured(final String metadataLocation) {
        if (StringUtils.isBlank(metadataLocation)) {
            return false;
        }
        val trimmedLocation = metadataLocation.trim();
        return trimmedLocation.endsWith("/entities/{0}");
    }

    /**
     * Builds the credential used to validate metadata signatures. The resource is first
     * treated as an X.509 certificate; if that fails for any reason, it is treated as a
     * public key instead.
     *
     * @param resource the resource holding the certificate or public key
     * @return the credential produced by whichever factory bean handled the resource
     * @throws Exception if the public-key fallback fails as well
     */
    private static @Nullable BasicCredential buildCredentialForMetadataSignatureValidation(final Resource resource) throws Exception {
        try {
            val certificateFactory = new BasicX509CredentialFactoryBean();
            certificateFactory.setCertificateResources(CollectionUtils.wrap(resource));
            return certificateFactory.getObject();
        } catch (final Exception e) {
            LOGGER.trace(e.getMessage(), e);
            LOGGER.debug("Credential cannot be extracted from [{}] via X.509. Treating it as a public key to locate credential...", resource);
        }
        // The X.509 attempt failed; retry with the resource interpreted as a public key.
        val publicKeyFactory = new BasicResourceCredentialFactoryBean();
        publicKeyFactory.setPublicKeyInfo(resource);
        return publicKeyFactory.getObject();
    }

    /**
     * Builds the default criteria for the signature validation filter: the signature
     * validation parameters resolved from the global signature validation configuration,
     * when any are resolved. Otherwise the returned set is empty.
     *
     * @return the criteria set
     */
    private static CriteriaSet buildSignatureValidationFilterCriteria() {
        val configurations = new ArrayList<SignatureValidationConfiguration>();
        configurations.add(SecurityConfigurationSupport.getGlobalSignatureValidationConfiguration());
        val lookupCriteria = new CriteriaSet(new SignatureValidationConfigurationCriterion(configurations));

        val parametersResolver = new BasicSignatureValidationParametersResolver();
        val validationParameters = FunctionUtils.doUnchecked(() -> parametersResolver.resolveSingle(lookupCriteria));

        val filterCriteria = new CriteriaSet();
        if (validationParameters != null) {
            filterCriteria.add(new SignatureValidationParametersCriterion(validationParameters), true);
        }
        return filterCriteria;
    }

    /**
     * Converts an encoded SAML request into a SAML object.
     * <p>
     * The value is first base64-decoded, inflated as raw DEFLATE data and unmarshalled.
     * If any step of that fails, the value is base64-decoded again and unmarshalled
     * directly, without inflating.
     *
     * @param <T>                the expected request type
     * @param openSamlConfigBean the open saml config bean supplying the parser pool
     * @param requestValue       the encoded request value
     * @param clazz              the expected class of the result
     * @return the unmarshalled request, cast to {@code clazz}
     */
    public static <T extends RequestAbstractType> T convertToSamlObject(final OpenSamlConfigBean openSamlConfigBean,
                                                                        final String requestValue, final Class<T> clazz) {
        try {
            LOGGER.trace("Retrieving SAML request from [{}]", requestValue);
            val decodedBytes = Base64Support.decode(requestValue);
            // InflaterInputStream.close() does not end an Inflater it was handed,
            // so it must be ended here to release its native resources promptly.
            val inflater = new Inflater(true);
            try (val is = new InflaterInputStream(new ByteArrayInputStream(decodedBytes), inflater)) {
                return clazz.cast(XMLObjectSupport.unmarshallFromInputStream(openSamlConfigBean.getParserPool(), is));
            } finally {
                inflater.end();
            }
        } catch (final Throwable e) {
            // Best-effort fallback; keep a trace of why the inflated path was abandoned.
            LOGGER.trace(e.getMessage(), e);
            return FunctionUtils.doUnchecked(() -> {
                val encodedRequest = EncodingUtils.decodeBase64(requestValue.getBytes(StandardCharsets.UTF_8));
                try (val is = new ByteArrayInputStream(encodedRequest)) {
                    return clazz.cast(XMLObjectSupport.unmarshallFromInputStream(openSamlConfigBean.getParserPool(), is));
                }
            });
        }
    }
}
