diff --git a/src/main/java/com/salesforce/einsteinbot/sdk/util/WebClientUtil.java b/src/main/java/com/salesforce/einsteinbot/sdk/util/WebClientUtil.java index b32b0fa..6e8cae0 100644 --- a/src/main/java/com/salesforce/einsteinbot/sdk/util/WebClientUtil.java +++ b/src/main/java/com/salesforce/einsteinbot/sdk/util/WebClientUtil.java @@ -14,6 +14,8 @@ import com.salesforce.einsteinbot.sdk.model.Error; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; import org.springframework.http.ReactiveHttpInputMessage; import org.springframework.http.client.reactive.ClientHttpResponse; import org.springframework.web.reactive.function.BodyExtractor; @@ -32,7 +34,6 @@ public class WebClientUtil { private static final Logger logger = LoggerFactory.getLogger(WebClientUtil.class); - public static Mono createLoggingRequestProcessor(ClientRequest clientRequest) { logger.info("Making {} Request to URI {} with Headers : {}", clientRequest.method(), clientRequest.url(), maskAuthorizationHeader(clientRequest.headers())); @@ -60,8 +61,9 @@ public static ExchangeFilterFunction createFilter( public static BodyExtractor, ReactiveHttpInputMessage> errorBodyExtractor() { BodyExtractor, ReactiveHttpInputMessage> extractor = (inputMessage, context) -> { - String contentType = inputMessage.getHeaders().getContentType().toString(); - if (contentType.contains("application/json")) { + HttpHeaders headers = inputMessage.getHeaders(); + MediaType contentType = headers != null ? headers.getContentType() : null; + if (contentType != null && contentType.toString().toLowerCase().contains("application/json")) { return BodyExtractors.toMono(Error.class) .extract(inputMessage, context); } else { @@ -72,10 +74,10 @@ public static BodyExtractor, ReactiveHttpInputMessage> errorBodyExtr } private static Mono buildErrorFromClientResponseBodyString(ReactiveHttpInputMessage clientResponse, BodyExtractor.Context context) { - ClientHttpResponse response = (ClientHttpResponse) clientResponse; - Mono bodyString = BodyExtractors.toMono(String.class). + ClientHttpResponse response = (ClientHttpResponse) clientResponse; + Mono bodyString = BodyExtractors.toMono(String.class). extract(clientResponse, context); - return bodyString.map(errorMessage -> new Error() + return bodyString.map(errorMessage -> new Error() .status(response.getRawStatusCode()) .message("This Response content type is not 'application/json', " + "See the 'error' field for actual error returned by the server.") diff --git a/src/test/java/com/salesforce/einsteinbot/sdk/client/ClientApiWireMockTest.java b/src/test/java/com/salesforce/einsteinbot/sdk/client/ClientApiWireMockTest.java index c29b3fc..00f162d 100644 --- a/src/test/java/com/salesforce/einsteinbot/sdk/client/ClientApiWireMockTest.java +++ b/src/test/java/com/salesforce/einsteinbot/sdk/client/ClientApiWireMockTest.java @@ -407,4 +407,41 @@ private void stubVersionsResponse(String responseBodyFile, int statusCode) { ); } -} \ No newline at end of file + @Test + public void test429ResponseWithNullContentType() { + wireMock.stubFor( + get(VERSIONS_URI) + .willReturn( + aResponse() + .withStatus(HttpStatus.TOO_MANY_REQUESTS.value()) + .withBody("Too many requests") + ) + ); + + Throwable exceptionThrown = assertThrows(RuntimeException.class, + () -> client.getSupportedVersions()); + + ChatbotResponseException chatbotResponseException = validateAndGetCause(exceptionThrown, + ChatbotResponseException.class); + assertEquals(HttpStatus.TOO_MANY_REQUESTS.value(), chatbotResponseException.getStatus()); + } + + @Test + public void test200ResponseWithApplicationJsonContentType() throws Exception { + String responseBodyFile = "versionsResponse.json"; + wireMock.stubFor( + get(VERSIONS_URI) + .willReturn( + aResponse() + .withStatus(HttpStatus.OK.value()) + .withHeader("Content-Type", "application/json;charset=UTF-8") + .withBodyFile(TEST_MOCK_DIR + responseBodyFile) + ) + ); + + SupportedVersions versions = client.getSupportedVersions(); + + verifyResponseEnvelope(responseBodyFile, versions); + } + +}