Skip to content

Commit

Permalink
Avoid parse cookies when mutating request or response
Browse files Browse the repository at this point in the history
When mutating a ServerHttpRequest or ClientResponse, the respective
builders no longer access cookies automatically which causes them to
be parsed and does so only if necessary. Likewise re-applying the
read-only HttpHeaders wrapper is avoided.

See gh-24680
  • Loading branch information
rstoyanchev committed May 11, 2020
1 parent 0e9ecb6 commit 94824e3
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 73 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,7 +20,6 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.function.Consumer;

import reactor.core.publisher.Flux;
Expand All @@ -31,7 +30,6 @@
import org.springframework.http.HttpMethod;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;

Expand All @@ -46,12 +44,10 @@ class DefaultServerHttpRequestBuilder implements ServerHttpRequest.Builder {

private URI uri;

private HttpHeaders httpHeaders;
private HttpHeaders headers;

private String httpMethodValue;

private final MultiValueMap<String, HttpCookie> cookies;

@Nullable
private String uriPath;

Expand All @@ -70,21 +66,12 @@ public DefaultServerHttpRequestBuilder(ServerHttpRequest original) {
Assert.notNull(original, "ServerHttpRequest is required");

this.uri = original.getURI();
this.headers = HttpHeaders.writableHttpHeaders(original.getHeaders());
this.httpMethodValue = original.getMethodValue();
this.body = original.getBody();

this.httpHeaders = HttpHeaders.writableHttpHeaders(original.getHeaders());

this.cookies = new LinkedMultiValueMap<>(original.getCookies().size());
copyMultiValueMap(original.getCookies(), this.cookies);

this.originalRequest = original;
}

private static <K, V> void copyMultiValueMap(MultiValueMap<K,V> source, MultiValueMap<K,V> target) {
source.forEach((key, value) -> target.put(key, new LinkedList<>(value)));
}


@Override
public ServerHttpRequest.Builder method(HttpMethod httpMethod) {
Expand Down Expand Up @@ -113,14 +100,14 @@ public ServerHttpRequest.Builder contextPath(String contextPath) {

@Override
public ServerHttpRequest.Builder header(String headerName, String... headerValues) {
this.httpHeaders.put(headerName, Arrays.asList(headerValues));
this.headers.put(headerName, Arrays.asList(headerValues));
return this;
}

@Override
public ServerHttpRequest.Builder headers(Consumer<HttpHeaders> headersConsumer) {
Assert.notNull(headersConsumer, "'headersConsumer' must not be null");
headersConsumer.accept(this.httpHeaders);
headersConsumer.accept(this.headers);
return this;
}

Expand All @@ -132,8 +119,8 @@ public ServerHttpRequest.Builder sslInfo(SslInfo sslInfo) {

@Override
public ServerHttpRequest build() {
return new MutatedServerHttpRequest(getUriToUse(), this.contextPath, this.httpHeaders,
this.httpMethodValue, this.cookies, this.sslInfo, this.body, this.originalRequest);
return new MutatedServerHttpRequest(getUriToUse(), this.contextPath,
this.httpMethodValue, this.sslInfo, this.body, this.originalRequest);
}

private URI getUriToUse() {
Expand Down Expand Up @@ -179,8 +166,6 @@ private static class MutatedServerHttpRequest extends AbstractServerHttpRequest

private final String methodValue;

private final MultiValueMap<String, HttpCookie> cookies;

@Nullable
private final SslInfo sslInfo;

Expand All @@ -190,12 +175,11 @@ private static class MutatedServerHttpRequest extends AbstractServerHttpRequest


public MutatedServerHttpRequest(URI uri, @Nullable String contextPath,
HttpHeaders headers, String methodValue, MultiValueMap<String, HttpCookie> cookies,
@Nullable SslInfo sslInfo, Flux<DataBuffer> body, ServerHttpRequest originalRequest) {
String methodValue, @Nullable SslInfo sslInfo,
Flux<DataBuffer> body, ServerHttpRequest originalRequest) {

super(uri, contextPath, headers);
super(uri, contextPath, originalRequest.getHeaders());
this.methodValue = methodValue;
this.cookies = cookies;
this.sslInfo = sslInfo != null ? sslInfo : originalRequest.getSslInfo();
this.body = body;
this.originalRequest = originalRequest;
Expand All @@ -208,7 +192,7 @@ public String getMethodValue() {

@Override
protected MultiValueMap<String, HttpCookie> initCookies() {
return this.cookies;
return this.originalRequest.getCookies();
}

@Nullable
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -49,98 +49,101 @@ public class ServerHttpRequestTests {

@Test
public void queryParamsNone() throws Exception {
MultiValueMap<String, String> params = createHttpRequest("/path").getQueryParams();
MultiValueMap<String, String> params = createRequest("/path").getQueryParams();
assertThat(params.size()).isEqualTo(0);
}

@Test
public void queryParams() throws Exception {
MultiValueMap<String, String> params = createHttpRequest("/path?a=A&b=B").getQueryParams();
MultiValueMap<String, String> params = createRequest("/path?a=A&b=B").getQueryParams();
assertThat(params.size()).isEqualTo(2);
assertThat(params.get("a")).isEqualTo(Collections.singletonList("A"));
assertThat(params.get("b")).isEqualTo(Collections.singletonList("B"));
}

@Test
public void queryParamsWithMultipleValues() throws Exception {
MultiValueMap<String, String> params = createHttpRequest("/path?a=1&a=2").getQueryParams();
MultiValueMap<String, String> params = createRequest("/path?a=1&a=2").getQueryParams();
assertThat(params.size()).isEqualTo(1);
assertThat(params.get("a")).isEqualTo(Arrays.asList("1", "2"));
}

@Test // SPR-15140
public void queryParamsWithEncodedValue() throws Exception {
MultiValueMap<String, String> params = createHttpRequest("/path?a=%20%2B+%C3%A0").getQueryParams();
MultiValueMap<String, String> params = createRequest("/path?a=%20%2B+%C3%A0").getQueryParams();
assertThat(params.size()).isEqualTo(1);
assertThat(params.get("a")).isEqualTo(Collections.singletonList(" + \u00e0"));
}

@Test
public void queryParamsWithEmptyValue() throws Exception {
MultiValueMap<String, String> params = createHttpRequest("/path?a=").getQueryParams();
MultiValueMap<String, String> params = createRequest("/path?a=").getQueryParams();
assertThat(params.size()).isEqualTo(1);
assertThat(params.get("a")).isEqualTo(Collections.singletonList(""));
}

@Test
public void queryParamsWithNoValue() throws Exception {
MultiValueMap<String, String> params = createHttpRequest("/path?a").getQueryParams();
MultiValueMap<String, String> params = createRequest("/path?a").getQueryParams();
assertThat(params.size()).isEqualTo(1);
assertThat(params.get("a")).isEqualTo(Collections.singletonList(null));
}

@Test
public void mutateRequest() throws Exception {
public void mutateRequestMethod() throws Exception {
ServerHttpRequest request = createRequest("/").mutate().method(HttpMethod.DELETE).build();
assertThat(request.getMethod()).isEqualTo(HttpMethod.DELETE);
}

@Test
public void mutateSslInfo() throws Exception {
SslInfo sslInfo = mock(SslInfo.class);
ServerHttpRequest request = createHttpRequest("/").mutate().sslInfo(sslInfo).build();
ServerHttpRequest request = createRequest("/").mutate().sslInfo(sslInfo).build();
assertThat(request.getSslInfo()).isSameAs(sslInfo);
}

request = createHttpRequest("/").mutate().method(HttpMethod.DELETE).build();
assertThat(request.getMethod()).isEqualTo(HttpMethod.DELETE);

@Test
public void mutateUriAndPath() throws Exception {
String baseUri = "https://aaa.org:8080/a";

request = createHttpRequest(baseUri).mutate().uri(URI.create("https://bbb.org:9090/b")).build();
ServerHttpRequest request = createRequest(baseUri).mutate().uri(URI.create("https://bbb.org:9090/b")).build();
assertThat(request.getURI().toString()).isEqualTo("https://bbb.org:9090/b");

request = createHttpRequest(baseUri).mutate().path("/b/c/d").build();
request = createRequest(baseUri).mutate().path("/b/c/d").build();
assertThat(request.getURI().toString()).isEqualTo("https://aaa.org:8080/b/c/d");

request = createHttpRequest(baseUri).mutate().path("/app/b/c/d").contextPath("/app").build();
request = createRequest(baseUri).mutate().path("/app/b/c/d").contextPath("/app").build();
assertThat(request.getURI().toString()).isEqualTo("https://aaa.org:8080/app/b/c/d");
assertThat(request.getPath().contextPath().value()).isEqualTo("/app");
}

@Test
public void mutateWithInvalidPath() throws Exception {
assertThatIllegalArgumentException().isThrownBy(() ->
createHttpRequest("/").mutate().path("foo-bar"));
}

@Test // SPR-16434
public void mutatePathWithEncodedQueryParams() throws Exception {
ServerHttpRequest request = createHttpRequest("/path?name=%E6%89%8E%E6%A0%B9");
ServerHttpRequest request = createRequest("/path?name=%E6%89%8E%E6%A0%B9");
request = request.mutate().path("/mutatedPath").build();

assertThat(request.getURI().getRawPath()).isEqualTo("/mutatedPath");
assertThat(request.getURI().getRawQuery()).isEqualTo("name=%E6%89%8E%E6%A0%B9");
}

@Test
public void mutateWithInvalidPath() {
assertThatIllegalArgumentException().isThrownBy(() -> createRequest("/").mutate().path("foo-bar"));
}

@Test
public void mutateHeadersViaConsumer() throws Exception {
String headerName = "key";
String headerValue1 = "value1";
String headerValue2 = "value2";

ServerHttpRequest request = createHttpRequest("/path");
ServerHttpRequest request = createRequest("/path");
assertThat(request.getHeaders().get(headerName)).isNull();

request = request.mutate().headers(headers -> headers.add(headerName, headerValue1)).build();

assertThat(request.getHeaders().get(headerName)).containsExactly(headerValue1);

request = request.mutate().headers(headers -> headers.add(headerName, headerValue2)).build();

assertThat(request.getHeaders().get(headerName)).containsExactly(headerValue1, headerValue2);
}

Expand All @@ -151,19 +154,17 @@ public void mutateHeaderBySettingHeaderValues() throws Exception {
String headerValue2 = "value2";
String headerValue3 = "value3";

ServerHttpRequest request = createHttpRequest("/path");
ServerHttpRequest request = createRequest("/path");
assertThat(request.getHeaders().get(headerName)).isNull();

request = request.mutate().header(headerName, headerValue1, headerValue2).build();

assertThat(request.getHeaders().get(headerName)).containsExactly(headerValue1, headerValue2);

request = request.mutate().header(headerName, headerValue3).build();

assertThat(request.getHeaders().get(headerName)).containsExactly(headerValue3);
}

private ServerHttpRequest createHttpRequest(String uriString) throws Exception {
private ServerHttpRequest createRequest(String uriString) throws Exception {
URI uri = URI.create(uriString);
MockHttpServletRequest request = new TestHttpServletRequest(uri);
AsyncContext asyncContext = new MockAsyncContext(request, new MockHttpServletResponse());
Expand Down
Loading

0 comments on commit 94824e3

Please sign in to comment.