package org.apache.nifi.web.security.saml2.service.web;

import jakarta.servlet.http.HttpServletRequest;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.nifi.web.security.saml2.registration.Saml2RegistrationProperty;
import org.apache.nifi.web.servlet.shared.RequestUriBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.web.util.UriComponentsBuilder;

/* loaded from: input_file:org/apache/nifi/web/security/saml2/service/web/StandardRelyingPartyRegistrationResolver.class */
public class StandardRelyingPartyRegistrationResolver implements Converter<HttpServletRequest, RelyingPartyRegistration>, RelyingPartyRegistrationResolver {
    private static final String BASE_URL_KEY = "baseUrl";
    private static final String REGISTRATION_ID_KEY = "registrationId";
    private static final Logger logger = LoggerFactory.getLogger(StandardRelyingPartyRegistrationResolver.class);
    private final RelyingPartyRegistrationRepository repository;
    private final List<String> allowedContextPaths;

    public StandardRelyingPartyRegistrationResolver(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, List<String> list) {
        this.repository = (RelyingPartyRegistrationRepository) Objects.requireNonNull(relyingPartyRegistrationRepository, "Repository required");
        this.allowedContextPaths = (List) Objects.requireNonNull(list, "Allowed Context Paths required");
    }

    public RelyingPartyRegistration convert(HttpServletRequest httpServletRequest) {
        return resolve(httpServletRequest, Saml2RegistrationProperty.REGISTRATION_ID.getProperty());
    }

    public RelyingPartyRegistration resolve(HttpServletRequest httpServletRequest, String str) {
        RelyingPartyRegistration build;
        Objects.requireNonNull(httpServletRequest, "Request required");
        RelyingPartyRegistration findByRegistrationId = this.repository.findByRegistrationId(str);
        if (findByRegistrationId == null) {
            build = null;
            logger.warn("Relying Party Registration [{}] not found", str);
        } else {
            String baseUrl = getBaseUrl(httpServletRequest);
            String resolveUrl = resolveUrl(findByRegistrationId.getAssertionConsumerServiceLocation(), baseUrl, findByRegistrationId);
            build = findByRegistrationId.mutate().assertionConsumerServiceLocation(resolveUrl).singleLogoutServiceLocation(resolveUrl(findByRegistrationId.getSingleLogoutServiceLocation(), baseUrl, findByRegistrationId)).singleLogoutServiceResponseLocation(resolveUrl(findByRegistrationId.getSingleLogoutServiceResponseLocation(), baseUrl, findByRegistrationId)).build();
        }
        return build;
    }

    private String resolveUrl(String str, String str2, RelyingPartyRegistration relyingPartyRegistration) {
        String uriString;
        if (str == null) {
            uriString = null;
        } else {
            HashMap hashMap = new HashMap();
            hashMap.put(BASE_URL_KEY, str2);
            hashMap.put(REGISTRATION_ID_KEY, relyingPartyRegistration.getRegistrationId());
            uriString = UriComponentsBuilder.fromUriString(str).buildAndExpand(hashMap).toUriString();
        }
        return uriString;
    }

    private String getBaseUrl(HttpServletRequest httpServletRequest) {
        String uri = RequestUriBuilder.fromHttpServletRequest(httpServletRequest, this.allowedContextPaths).build().toString();
        return UriComponentsBuilder.fromUriString(uri).path(httpServletRequest.getContextPath()).replaceQuery((String) null).fragment((String) null).build().toString();
    }
}
