package org.apereo.cas.support.claims;

import module java.base;
import org.apereo.cas.ws.idp.WSFederationConstants;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.StringUtils;
import org.apache.cxf.rt.security.claims.Claim;
import org.apache.cxf.sts.claims.ClaimsParameters;

/**
 * This is {@link CustomNamespaceWSFederationClaimsClaimsHandler}.
 * <p>
 * A claims handler that reports a requested claim type as supported when the claim type
 * starts with one of a configured set of custom namespace prefixes. When the requested
 * token type is SAML2, the processed claim type is shortened to the segment that follows
 * the last {@code '/'} of the requested claim type.
 *
 * @author Misagh Moayyed
 * @since 5.3.0
 */
@Slf4j
@Getter
public class CustomNamespaceWSFederationClaimsClaimsHandler extends NonWSFederationClaimsClaimsHandler {
    /**
     * Claim types this handler reports as supported. Membership is decided by namespace
     * prefix matching, not by elements stored in the list.
     */
    private final List<String> supportedClaimTypes;

    /**
     * Create a handler for the given realm and issuer.
     *
     * @param handlerRealm the realm passed to the parent handler
     * @param issuer       the issuer passed to the parent handler
     * @param namespaces   namespace prefixes; a claim type is supported if it starts with any of them
     */
    public CustomNamespaceWSFederationClaimsClaimsHandler(final String handlerRealm, final String issuer,
                                                          final List<String> namespaces) {
        super(handlerRealm, issuer);
        this.supportedClaimTypes = new CustomNamespaceWSFederationClaimsList(namespaces);
    }

    /**
     * A list whose {@link #contains(Object)} and {@link #isEmpty()} answers come from the
     * configured namespace prefixes, not from stored elements.
     * <p>
     * NOTE: no elements are ever added to this list. {@code size()} and iteration therefore
     * behave as for an empty list, even when {@link #isEmpty()} returns {@code false}.
     * Only prefix membership through {@link #contains(Object)} is meaningful.
     */
    @RequiredArgsConstructor
    private static final class CustomNamespaceWSFederationClaimsList extends ArrayList<String> {
        @Serial
        private static final long serialVersionUID = 8368878016992806802L;

        private final List<String> namespaces;

        /**
         * Check whether the given claim type falls under one of the configured namespaces.
         *
         * @param o the claim type: a {@link URI}, or any object whose {@code toString()} is the claim type
         * @return {@code true} if the claim type is non-blank and starts with a configured
         *         (non-null) namespace; {@code false} for {@code null} or blank input
         */
        @Override
        public boolean contains(final Object o) {
            if (o == null) {
                // ArrayList permits null lookups; answer false instead of failing with an NPE.
                return false;
            }
            val uri = o instanceof final URI instance ? instance.toASCIIString() : o.toString();
            return StringUtils.isNotBlank(uri)
                && namespaces.stream()
                    // guard against null entries in the configured namespaces, which would make startsWith throw
                    .filter(Objects::nonNull)
                    .anyMatch(uri::startsWith);
        }

        /**
         * Report emptiness based on whether any namespaces are configured.
         *
         * @return {@code true} if no namespaces were configured
         */
        @Override
        public boolean isEmpty() {
            return namespaces.isEmpty();
        }
    }

    /**
     * Compute the claim type to emit for a requested claim.
     * <p>
     * For SAML2 token requests, the claim type is shortened to the trimmed text after its
     * last {@code '/'}, or to the whole trimmed value if it has no {@code '/'}. For any other
     * token type, or when the requested claim has no claim type, the requested claim type
     * is returned unchanged.
     *
     * @param requestClaim the requested claim
     * @param parameters   the claims parameters carrying the token requirements
     * @return the processed claim type; {@code null} if the requested claim has no claim type
     */
    @Override
    protected String createProcessedClaimType(final Claim requestClaim, final ClaimsParameters parameters) {
        val tokenType = parameters.getTokenRequirements().getTokenType();
        val claimType = requestClaim.getClaimType();
        if (claimType == null || !WSFederationConstants.WSS_SAML2_TOKEN_TYPE.equalsIgnoreCase(tokenType)) {
            return claimType;
        }
        val claimName = claimType.substring(claimType.lastIndexOf('/') + 1).trim();
        LOGGER.debug("Converted full claim type from [{}] to [{}]", claimType, claimName);
        return claimName;
    }
}
