Skip to content

pass HttpRequest to ServerBaseUrlCustomizer #2589

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -24,6 +24,8 @@

package org.springdoc.core.customizers;

import org.springframework.http.HttpRequest;

/**
* The interface Server Base URL customiser.
* @author skylar -stark
@@ -35,7 +37,8 @@ public interface ServerBaseUrlCustomizer {
* Customise.
*
* @param serverBaseUrl the serverBaseUrl.
* @param request the request.
* @return the customised serverBaseUrl
*/
String customize(String serverBaseUrl);
String customize(String serverBaseUrl, HttpRequest request);
}
Original file line number Diff line number Diff line change
@@ -81,6 +81,7 @@
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.http.HttpRequest;
import org.springframework.stereotype.Controller;
import org.springframework.util.CollectionUtils;
import org.springframework.web.bind.annotation.ControllerAdvice;
@@ -490,12 +491,12 @@ public Schema resolveProperties(Schema schema, Locale locale) {
*
* @param serverBaseUrl the server base url
*/
public void setServerBaseUrl(String serverBaseUrl) {
public void setServerBaseUrl(String serverBaseUrl, HttpRequest httpRequest) {
String customServerBaseUrl = serverBaseUrl;

if (serverBaseUrlCustomizers.isPresent()) {
for (ServerBaseUrlCustomizer customizer : serverBaseUrlCustomizers.get()) {
customServerBaseUrl = customizer.customize(customServerBaseUrl);
customServerBaseUrl = customizer.customize(customServerBaseUrl, httpRequest);
}
}

Original file line number Diff line number Diff line change
@@ -62,6 +62,7 @@
import org.springframework.context.ApplicationContext;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.mock.http.client.MockClientHttpRequest;

import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
@@ -190,7 +191,7 @@ void preLoadingModeShouldNotOverwriteServers() throws InterruptedException {
doCallRealMethod().when(openAPIService).updateServers(any());
when(openAPIService.getCachedOpenAPI(any())).thenCallRealMethod();
doAnswer(new CallsRealMethods()).when(openAPIService).setServersPresent(true);
doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any());
doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any(), any());
doAnswer(new CallsRealMethods()).when(openAPIService).setCachedOpenAPI(any(), any());

String customUrl = "https://custom.com";
@@ -212,7 +213,7 @@ properties, springDocProviders, new SpringDocCustomizers(Optional.of(singletonLi
Thread.sleep(1_000);

// emulate generating base url
openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest());
openAPIService.updateServers(openAPI);
Locale locale = Locale.US;
OpenAPI after = resource.getOpenApi(locale);
@@ -224,7 +225,7 @@ properties, springDocProviders, new SpringDocCustomizers(Optional.of(singletonLi
void serverBaseUrlCustomisersTest() throws InterruptedException {
doCallRealMethod().when(openAPIService).updateServers(any());
when(openAPIService.getCachedOpenAPI(any())).thenCallRealMethod();
doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any());
doAnswer(new CallsRealMethods()).when(openAPIService).setServerBaseUrl(any(), any());
doAnswer(new CallsRealMethods()).when(openAPIService).setCachedOpenAPI(any(), any());

SpringDocConfigProperties properties = new SpringDocConfigProperties();
@@ -247,37 +248,37 @@ springDocProviders, new SpringDocCustomizers(Optional.empty(),Optional.empty(),O

// Test that setting generated URL works fine with no customizers present
String generatedUrl = "https://generated-url.com/context-path";
openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest());
openAPIService.updateServers(openAPI);
OpenAPI after = resource.getOpenApi(locale);
assertThat(after.getServers().get(0).getUrl(), is(generatedUrl));

// Test that adding a serverBaseUrlCustomizer has the desired effect
ServerBaseUrlCustomizer serverBaseUrlCustomizer = serverBaseUrl -> serverBaseUrl.replace("/context-path", "");
ServerBaseUrlCustomizer serverBaseUrlCustomizer = (serverBaseUrl, request) -> serverBaseUrl.replace("/context-path", "");
List<ServerBaseUrlCustomizer> serverBaseUrlCustomizerList = new ArrayList<>();
serverBaseUrlCustomizerList.add(serverBaseUrlCustomizer);

ReflectionTestUtils.setField(openAPIService, "serverBaseUrlCustomizers", Optional.of(serverBaseUrlCustomizerList));
openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest());
openAPIService.updateServers(openAPI);
after = resource.getOpenApi(locale);
assertThat(after.getServers().get(0).getUrl(), is("https://generated-url.com"));

// Test that serverBaseUrlCustomisers are performed in order
generatedUrl = "https://generated-url.com/context-path/second-path";
ServerBaseUrlCustomizer serverBaseUrlCustomiser2 = serverBaseUrl -> serverBaseUrl.replace("/context-path/second-path", "");
ServerBaseUrlCustomizer serverBaseUrlCustomiser2 = (serverBaseUrl, request) -> serverBaseUrl.replace("/context-path/second-path", "");
serverBaseUrlCustomizerList.add(serverBaseUrlCustomiser2);

openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest());
openAPIService.updateServers(openAPI);
after = resource.getOpenApi(locale);
assertThat(after.getServers().get(0).getUrl(), is("https://generated-url.com/second-path"));

// Test that all serverBaseUrlCustomisers in the List are performed
ServerBaseUrlCustomizer serverBaseUrlCustomiser3 = serverBaseUrl -> serverBaseUrl.replace("/second-path", "");
ServerBaseUrlCustomizer serverBaseUrlCustomiser3 = (serverBaseUrl, request) -> serverBaseUrl.replace("/second-path", "");
serverBaseUrlCustomizerList.add(serverBaseUrlCustomiser3);

openAPIService.setServerBaseUrl(generatedUrl);
openAPIService.setServerBaseUrl(generatedUrl, new MockClientHttpRequest());
openAPIService.updateServers(openAPI);
after = resource.getOpenApi(locale);
assertThat(after.getServers().get(0).getUrl(), is("https://generated-url.com"));
Original file line number Diff line number Diff line change
@@ -131,7 +131,7 @@ public Mono<byte[]> openapiYaml(ServerHttpRequest serverHttpRequest, Locale loca
protected void calculateServerUrl(ServerHttpRequest serverHttpRequest, String apiDocsUrl, Locale locale) {
super.initOpenAPIBuilder(locale);
URI uri = getActuatorURI(serverHttpRequest.getURI().getScheme(), serverHttpRequest.getURI().getHost());
openAPIService.setServerBaseUrl(uri.toString());
openAPIService.setServerBaseUrl(uri.toString(), serverHttpRequest);
}

@Override
Original file line number Diff line number Diff line change
@@ -229,7 +229,7 @@ protected void getWebFluxRouterFunctionPaths(Locale locale, OpenAPI openAPI) {
protected void calculateServerUrl(ServerHttpRequest serverHttpRequest, String apiDocsUrl, Locale locale) {
initOpenAPIBuilder(locale);
String serverUrl = getServerUrl(serverHttpRequest, apiDocsUrl);
openAPIService.setServerBaseUrl(serverUrl);
openAPIService.setServerBaseUrl(serverUrl, serverHttpRequest);
}

/**
Original file line number Diff line number Diff line change
@@ -55,6 +55,7 @@

import org.springframework.aop.support.AopUtils;
import org.springframework.beans.factory.ObjectFactory;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
import org.springframework.web.bind.annotation.RequestMethod;
@@ -244,7 +245,8 @@ private Comparator<RequestMappingInfo> byReversedRequestMappingInfos() {
protected void calculateServerUrl(HttpServletRequest request, String apiDocsUrl, Locale locale) {
super.initOpenAPIBuilder(locale);
String calculatedUrl = getServerUrl(request, apiDocsUrl);
openAPIService.setServerBaseUrl(calculatedUrl);
ServletServerHttpRequest serverRequest = request != null ? new ServletServerHttpRequest(request) : null;
openAPIService.setServerBaseUrl(calculatedUrl, serverRequest);
}

/**