package com.yubico.core;
import module java.base;
import org.apereo.cas.configuration.CasConfigurationProperties;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.yubico.core.SessionManager;
import com.yubico.core.WebAuthnCache;
import com.yubico.data.AssertionRequestWrapper;
import com.yubico.data.AssertionResponse;
import com.yubico.data.CredentialRegistration;
import com.yubico.data.RegistrationRequest;
import com.yubico.data.RegistrationResponse;
import com.yubico.internal.util.CertificateParser;
import com.yubico.internal.util.JacksonCodecs;
import com.yubico.util.Either;
import com.yubico.webauthn.FinishAssertionOptions;
import com.yubico.webauthn.FinishRegistrationOptions;
import com.yubico.webauthn.RegisteredCredential;
import com.yubico.webauthn.RegistrationResult;
import com.yubico.webauthn.RelyingParty;
import com.yubico.webauthn.StartAssertionOptions;
import com.yubico.webauthn.StartRegistrationOptions;
import com.yubico.webauthn.attestation.Attestation;
import com.yubico.webauthn.attestation.AttestationMetadataSource;
import com.yubico.webauthn.data.AuthenticatorData;
import com.yubico.webauthn.data.AuthenticatorSelectionCriteria;
import com.yubico.webauthn.data.AuthenticatorTransport;
import com.yubico.webauthn.data.ByteArray;
import com.yubico.webauthn.data.ResidentKeyRequirement;
import com.yubico.webauthn.data.UserIdentity;
import com.yubico.webauthn.exception.AssertionFailedException;
import com.yubico.webauthn.exception.RegistrationFailedException;
import lombok.AllArgsConstructor;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.Value;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.jooq.lambda.Unchecked;
import org.jspecify.annotations.NonNull;
import jakarta.servlet.http.HttpServletRequest;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;

/**
 * Drives the WebAuthn registration and authentication ceremonies on top of a {@link RelyingParty}.
 * <p>
 * Pending ceremony requests are kept in {@link WebAuthnCache} instances keyed by a randomly generated
 * request id. Completed registrations are written to the {@link RegistrationStorage}. Expected failures
 * (undecodable responses, unknown request ids, rejected ceremonies) are reported as
 * {@link Either#left} values carrying human-readable messages.
 */
@Setter
@Slf4j
@RequiredArgsConstructor
public class WebAuthnServer {
    /**
     * Size argument passed to {@link SessionManager#generateRandom} for user handles and request ids.
     */
    private static final int IDENTIFIER_LENGTH = 32;

    private static final ObjectMapper OBJECT_MAPPER = JacksonCodecs.json();

    private final RegistrationStorage userStorage;
    private final WebAuthnCache<RegistrationRequest> registerRequestStorage;
    private final WebAuthnCache<AssertionRequestWrapper> assertRequestStorage;
    private final RelyingParty relyingParty;
    private final SessionManager sessionManager;
    private final CasConfigurationProperties casProperties;

    /**
     * Starts a registration ceremony for {@code username}.
     * <p>
     * Registration is permitted when any of the following holds:
     * <ul>
     *   <li>multiple-device registration is enabled in the CAS WebAuthn settings;</li>
     *   <li>the username has no registered credentials yet;</li>
     *   <li>{@code sessionToken} is a valid session for the existing user.</li>
     * </ul>
     * On success, the pending request is cached under a freshly generated request id, and a session is
     * created for the user handle being registered.
     *
     * @param request                the current HTTP request, used for request caching and session handling
     * @param username               the account name to register a credential for
     * @param displayName            display name used when creating a brand-new user identity; falls back to
     *                               {@code username} when absent
     * @param credentialNickname     optional label for the new credential
     * @param residentKeyRequirement resident-key requirement placed in the authenticator selection criteria
     * @param sessionToken           session token presented by the caller, if any
     * @return the pending registration request, or an error message when registration is not permitted
     */
    public Either<String, RegistrationRequest> startRegistration(
        final HttpServletRequest request,
        @NonNull final String username,
        final Optional<String> displayName,
        final Optional<String> credentialNickname,
        final ResidentKeyRequirement residentKeyRequirement,
        final Optional<ByteArray> sessionToken) {

        LOGGER.trace("Starting registration operation for username: [{}], credentialNickname: [{}]", username, credentialNickname);
        val registrations = userStorage.getRegistrationsByUsername(username);
        val existingUser = registrations.stream().findAny().map(CredentialRegistration::getUserIdentity);
        val permissionGranted = casProperties.getAuthn().getMfa().getWebAuthn().getCore().isMultipleDeviceRegistrationEnabled()
            || existingUser.map(userIdentity -> sessionManager.isSessionForUser(request, userIdentity.getId(), sessionToken)).orElse(true);

        if (!permissionGranted) {
            return Either.left("The username %s is already registered and/or has an active session.".formatted(username));
        }

        // A brand-new user gets a random user handle. A missing display name falls back to the username
        // instead of throwing out of a method that otherwise reports failures through Either.
        val registrationUserId = existingUser.orElseGet(() ->
            UserIdentity.builder()
                .name(username)
                .displayName(displayName.orElse(username))
                .id(SessionManager.generateRandom(IDENTIFIER_LENGTH))
                .build()
        );

        val registrationRequest = new RegistrationRequest(
            username,
            credentialNickname,
            SessionManager.generateRandom(IDENTIFIER_LENGTH),
            relyingParty.startRegistration(
                StartRegistrationOptions.builder()
                    .user(registrationUserId)
                    .authenticatorSelection(AuthenticatorSelectionCriteria.builder()
                        .residentKey(residentKeyRequirement)
                        .build()
                    )
                    .build()
            ),
            Optional.of(sessionManager.createSession(request, registrationUserId.getId()))
        );
        registerRequestStorage.put(request, registrationRequest.requestId(), registrationRequest);
        return Either.right(registrationRequest);
    }

    /**
     * Completes a registration ceremony previously started with {@link #startRegistration}.
     * <p>
     * The pending request is looked up by the request id carried in the response. It is evicted from the
     * cache before the response is verified, so a request id can be used at most once. When the username
     * already exists, the pending request must carry a session token that is valid for the user being
     * registered.
     *
     * @param request      the current HTTP request, used for request caching and session handling
     * @param responseJson the client's JSON-encoded {@link RegistrationResponse}
     * @return the stored registration plus a new session token, or a list of error messages
     */
    public Either<List<String>, SuccessfulRegistrationResult> finishRegistration(final HttpServletRequest request, final String responseJson) {
        LOGGER.trace("Finishing registration with response: [{}]", responseJson);
        final RegistrationResponse registrationResponse;
        try {
            registrationResponse = OBJECT_MAPPER.readValue(responseJson, RegistrationResponse.class);
        } catch (final Exception e) {
            LOGGER.error("Registration failed; response: [{}]", responseJson, e);
            return Either.left(List.of("Registration failed", "Failed to decode response object.", failureMessage(e)));
        }

        val registrationRequest = registerRequestStorage.getIfPresent(request, registrationResponse.requestId());
        registerRequestStorage.invalidate(request, registrationResponse.requestId());

        if (registrationRequest == null) {
            LOGGER.debug("Finishing registration failed with: [{}]", responseJson);
            return Either.left(List.of("Registration failed", "No such registration in progress."));
        }

        try {
            val registration = relyingParty.finishRegistration(
                FinishRegistrationOptions.builder()
                    .request(registrationRequest.publicKeyCredentialCreationOptions())
                    .response(registrationResponse.credential())
                    .build()
            );

            val user = registrationRequest.publicKeyCredentialCreationOptions().getUser();
            if (userStorage.userExists(registrationRequest.username())) {
                val isValidSession = registrationRequest.sessionToken()
                    .map(token -> sessionManager.isSessionForUser(request, user.getId(), token))
                    .orElse(false);

                // Only log whether a token was supplied; the token value itself grants a session.
                LOGGER.debug("Session token present: [{}], valid session [{}]", registrationRequest.sessionToken().isPresent(), isValidSession);
                if (isValidSession) {
                    LOGGER.info("Session token accepted for user [{}]", user.getId());
                }
                LOGGER.debug("Permission granted to finish registration: [{}]", isValidSession);

                if (!isValidSession) {
                    throw new RegistrationFailedException(new IllegalArgumentException("User %s already exists".formatted(registrationRequest.username())));
                }
            }

            return Either.right(
                new SuccessfulRegistrationResult(
                    registrationRequest,
                    registrationResponse,
                    addRegistration(user, registrationRequest.credentialNickname(), registration),
                    registration.isAttestationTrusted() || relyingParty.isAllowUntrustedAttestation(),
                    sessionManager.createSession(request, user.getId())
                )
            );
        } catch (final RegistrationFailedException e) {
            LOGGER.debug("Finishing registration failed with: [{}]", responseJson, e);
            return Either.left(List.of("Registration failed", failureMessage(e)));
        } catch (final Exception e) {
            LOGGER.error("Finishing registration failed with: [{}]", responseJson, e);
            return Either.left(List.of("Registration failed unexpectedly; this is likely a bug.", failureMessage(e)));
        }
    }

    /**
     * Starts an authentication (assertion) ceremony.
     *
     * @param request  the current HTTP request, used for request caching
     * @param username the username to authenticate, or empty to let the relying party start an assertion
     *                 without a username
     * @return the cached assertion request, or an error when {@code username} is given but not registered
     */
    public Either<List<String>, AssertionRequestWrapper> startAuthentication(final HttpServletRequest request, final Optional<String> username) {
        if (username.isPresent() && !userStorage.userExists(username.get())) {
            return Either.left(List.of("The username %s is not registered.".formatted(username.get())));
        }
        val assertionRequest = new AssertionRequestWrapper(
            SessionManager.generateRandom(IDENTIFIER_LENGTH),
            relyingParty.startAssertion(StartAssertionOptions.builder().username(username).build())
        );
        assertRequestStorage.put(request, assertionRequest.getRequestId(), assertionRequest);
        return Either.right(assertionRequest);
    }

    /**
     * Completes an authentication ceremony previously started with {@link #startAuthentication}.
     * <p>
     * The pending request is evicted from the cache before the assertion is verified, so each request id can
     * be used at most once. A failure to persist the updated signature count is logged, and authentication
     * still succeeds.
     *
     * @param request      the current HTTP request, used for request caching and session handling
     * @param responseJson the client's JSON-encoded {@link AssertionResponse}
     * @return the authenticated user's registrations plus a new session token, or a list of error messages
     */
    public Either<List<String>, SuccessfulAuthenticationResult> finishAuthentication(final HttpServletRequest request, final String responseJson) {
        final AssertionResponse assertionResponse;
        try {
            assertionResponse = OBJECT_MAPPER.readValue(responseJson, AssertionResponse.class);
        } catch (final Exception e) {
            LOGGER.debug("Failed to decode response object", e);
            return Either.left(List.of("Assertion failed!", "Failed to decode response object.", failureMessage(e)));
        }

        val assertionRequestWrapper = assertRequestStorage.getIfPresent(request, assertionResponse.requestId());
        assertRequestStorage.invalidate(request, assertionResponse.requestId());

        if (assertionRequestWrapper == null) {
            return Either.left(List.of("Assertion failed!", "No such assertion in progress."));
        }

        try {
            val assertionResult = relyingParty.finishAssertion(
                FinishAssertionOptions.builder()
                    .request(assertionRequestWrapper.getRequest())
                    .response(assertionResponse.credential())
                    .build()
            );

            if (!assertionResult.isSuccess()) {
                return Either.left(List.of("Assertion failed: Invalid assertion."));
            }

            try {
                userStorage.updateSignatureCount(assertionResult);
            } catch (final Exception e) {
                // Best effort: a storage failure here must not fail an otherwise valid authentication.
                LOGGER.error(
                    "Failed to update signature count for user \"{}\", credential \"{}\"",
                    assertionResult.getUsername(),
                    assertionResponse.credential().getId(),
                    e
                );
            }

            val session = sessionManager.createSession(request, assertionResult.getCredential().getUserHandle());
            return Either.right(
                new SuccessfulAuthenticationResult(
                    assertionRequestWrapper,
                    assertionResponse,
                    userStorage.getRegistrationsByUsername(assertionResult.getUsername()),
                    assertionResult.getUsername(),
                    session
                )
            );
        } catch (final AssertionFailedException e) {
            LOGGER.warn("Assertion failed", e);
            return Either.left(List.of("Assertion failed", failureMessage(e)));
        } catch (final Exception e) {
            LOGGER.error("Assertion failed", e);
            return Either.left(List.of("Assertion failed unexpectedly; this is likely a bug.", failureMessage(e)));
        }
    }

    /**
     * Returns a non-null description of a failure for inclusion in an error list.
     * <p>
     * {@link List#of} rejects null elements, and {@link Throwable#getMessage()} may be null (e.g. a bare
     * {@link NullPointerException}). Using the raw message could therefore turn a handled failure into an
     * uncaught exception.
     */
    private static String failureMessage(final Throwable e) {
        return Objects.requireNonNullElse(e.getMessage(), e.getClass().getName());
    }

    /**
     * Outcome of a successful {@link #finishRegistration} call.
     */
    @Value
    public static class SuccessfulRegistrationResult {
        boolean success;

        RegistrationRequest request;

        RegistrationResponse response;

        CredentialRegistration registration;

        boolean attestationTrusted;

        /** First certificate of the attestation statement's {@code x5c} chain, when one is present. */
        Optional<AttestationCertInfo> attestationCert;

        @JsonSerialize(using = AuthDataSerializer.class)
        AuthenticatorData authData;

        String username;

        ByteArray sessionToken;

        public SuccessfulRegistrationResult(final RegistrationRequest request,
                                            final RegistrationResponse response,
                                            final CredentialRegistration registration,
                                            final boolean attestationTrusted,
                                            final ByteArray sessionToken) {
            this.request = request;
            this.response = response;
            this.registration = registration;
            this.attestationTrusted = attestationTrusted;
            this.attestationCert = Optional.ofNullable(
                    response.credential().getResponse().getAttestation().getAttestationStatement().get("x5c")
                ).map(certs -> certs.get(0))
                .flatMap(Unchecked.function(certDer -> Optional.of(new ByteArray(certDer.binaryValue()))))
                .map(AttestationCertInfo::new);
            this.authData = response.credential().getResponse().getParsedAuthenticatorData();
            this.username = request.username();
            this.sessionToken = sessionToken;
            this.success = true;
        }
    }

    /**
     * DER bytes of an attestation certificate and, when it parses, its {@link X509Certificate#toString()}
     * rendering. {@code text} is {@code null} when parsing fails.
     */
    @Value
    public static class AttestationCertInfo {
        ByteArray der;

        String text;

        public AttestationCertInfo(final ByteArray certDer) {
            this.der = certDer;
            this.text = describeCertificate(certDer);
        }

        /** Parses {@code certDer} and returns its text form, or {@code null} if it cannot be parsed. */
        private static String describeCertificate(final ByteArray certDer) {
            try {
                final X509Certificate cert = CertificateParser.parseDer(certDer.getBytes());
                return cert == null ? null : cert.toString();
            } catch (final CertificateException e) {
                LOGGER.error("Failed to parse attestation certificate", e);
                return null;
            }
        }
    }

    /**
     * Outcome of a successful {@link #finishAuthentication} call.
     */
    @Value
    @AllArgsConstructor
    public static class SuccessfulAuthenticationResult {
        boolean success = true;

        AssertionRequestWrapper request;

        AssertionResponse response;

        Collection<CredentialRegistration> registrations;

        @JsonSerialize(using = AuthDataSerializer.class)
        AuthenticatorData authData;

        String username;

        ByteArray sessionToken;

        /**
         * Creates a result whose {@code authData} is taken from the parsed authenticator data of
         * {@code response}.
         */
        public SuccessfulAuthenticationResult(final AssertionRequestWrapper request, final AssertionResponse response,
                                              final Collection<CredentialRegistration> registrations,
                                              final String username, final ByteArray sessionToken) {
            this(
                request,
                response,
                registrations,
                response.credential().getResponse().getParsedAuthenticatorData(),
                username,
                sessionToken
            );
        }
    }

    /**
     * Result of removing a credential registration.
     */
    @Value
    public static class DeregisterCredentialResult {
        boolean success = true;

        CredentialRegistration droppedRegistration;

        boolean accountDeleted;
    }

    /**
     * Serializes {@link AuthenticatorData} as a JSON object, with byte fields rendered as hex strings.
     */
    private static class AuthDataSerializer extends JsonSerializer<AuthenticatorData> {

        @Override
        public void serialize(final AuthenticatorData value, final JsonGenerator gen,
                              final SerializerProvider serializers) throws IOException {
            gen.writeStartObject();
            gen.writeStringField("rpIdHash", value.getRpIdHash().getHex());
            gen.writeObjectField("flags", value.getFlags());
            gen.writeNumberField("signatureCounter", value.getSignatureCounter());
            // A plain presence check (rather than ifPresent with a lambda) lets IOException propagate
            // through this method's declared throws clause instead of being wrapped.
            val attestedCredentialData = value.getAttestedCredentialData();
            if (attestedCredentialData.isPresent()) {
                val acd = attestedCredentialData.get();
                gen.writeObjectFieldStart("attestedCredentialData");
                gen.writeStringField("aaguid", acd.getAaguid().getHex());
                gen.writeStringField("credentialId", acd.getCredentialId().getHex());
                gen.writeStringField("publicKey", acd.getCredentialPublicKey().getHex());
                gen.writeEndObject();
            }
            gen.writeObjectField("extensions", value.getExtensions());
            gen.writeEndObject();
        }
    }

    /**
     * Builds a {@link RegisteredCredential} from a verified registration result and stores it.
     * <p>
     * Attestation metadata is looked up only when the relying party's attestation trust source is an
     * {@link AttestationMetadataSource}. The lookup uses the first certificate of the attestation trust path.
     */
    private CredentialRegistration addRegistration(
        final UserIdentity userIdentity,
        final Optional<String> nickname,
        final RegistrationResult result) {

        val attestationMetadata = result
            .getAttestationTrustPath()
            .flatMap(x5c -> x5c.stream().findFirst())
            .flatMap(cert -> relyingParty.getAttestationTrustSource()
                .filter(AttestationMetadataSource.class::isInstance)
                .map(AttestationMetadataSource.class::cast)
                .flatMap(source -> source.findMetadata(cert)));

        return addRegistration(
            userIdentity,
            nickname,
            RegisteredCredential.builder()
                .credentialId(result.getKeyId().getId())
                .userHandle(userIdentity.getId())
                .publicKeyCose(result.getPublicKeyCose())
                .signatureCount(result.getSignatureCount())
                .build(),
            result.getKeyId().getTransports().orElseGet(TreeSet::new),
            attestationMetadata);
    }

    /**
     * Creates a {@link CredentialRegistration} timestamped with the current UTC instant and stores it
     * under the user's name.
     */
    private CredentialRegistration addRegistration(
        final UserIdentity userIdentity,
        final Optional<String> nickname,
        final RegisteredCredential credential,
        final SortedSet<AuthenticatorTransport> transports,
        final Optional<Attestation> attestationMetadata) {
        val reg = CredentialRegistration.builder()
            .userIdentity(userIdentity)
            .credentialNickname(nickname.orElse(null))
            .registrationTime(Clock.systemUTC().instant())
            .credential(credential)
            .transports(transports)
            .attestationMetadata(attestationMetadata.orElse(null))
            .build();
        LOGGER.debug("Adding registration: user: [{}], nickname: [{}], credential: [{}]", userIdentity, nickname, credential);
        userStorage.addRegistrationByUsername(userIdentity.getName(), reg);
        return reg;
    }

}
