From ce7278aaf4f20348862267c2081c20dc5bd77128 Mon Sep 17 00:00:00 2001 From: Brian Clozel Date: Wed, 3 Oct 2018 21:18:24 +0200 Subject: [PATCH] Optimize HTTP headers management Several benchmarks underlined a few hotspots for CPU and GC pressure in the Spring Framework codebase: 1. `org.springframework.util.MimeType.(String, String, Map)` 2. `org.springframework.util.LinkedCaseInsensitiveMap.convertKey(String)` Both are linked with HTTP request headers parsing and response headers writin during the exchange processing phase. 1) is linked to repeated calls to `HttpHeaders.getContentType` within a single request handling. The media type parsing operation is expensive and the result doesn't change between calls, since the request headers are immutable at that point. This commit improves this by caching the parsed `MediaType` for the `"Content-Type"` request header in the `ReadOnlyHttpHeaders` class. This change is available for both Spring MVC and Spring WebFlux. 2) is linked to insertions/lookups in the `LinkedCaseInsensitiveMap`, which is the data structure behind `HttpHeaders`. Those operations are creating a lot of garbage (including a lot of `String` created by `toLowerCase`). We could choose a more efficient data structure for storing HTTP headers data. As a first step, this commit is focusing on Spring WebFlux and introduces `MultiValueMap` implementations mapped by native HTTP headers for the following servers: Tomcat, Jetty, Netty and Undertow. Such implementations avoid unnecessary copying of the headers and leverages as much as possible optimized operations provided by the native implementations. This change has a few consequences: * `HttpHeaders` can now wrap a `MultiValueMap` directly * The default constructor of `HttpHeaders` is still backed by a `LinkedCaseInsensitiveMap` * The HTTP request headers for the websocket HTTP handshake now need to be cloned, because native headers are likely to be pooled/recycled by the server implementation, hence gone when the initial HTTP exchange is done Issue: SPR-17250 --- .../org/springframework/http/HttpHeaders.java | 61 ++--- .../http/ReadOnlyHttpHeaders.java | 135 ++++++++++ .../server/ServletServerHttpResponse.java | 3 + .../AbstractListenerServerHttpResponse.java | 7 +- .../reactive/AbstractServerHttpResponse.java | 7 +- .../server/reactive/JettyHeadersAdapter.java | 222 ++++++++++++++++ .../reactive/JettyHttpHandlerAdapter.java | 40 ++- .../server/reactive/NettyHeadersAdapter.java | 217 ++++++++++++++++ .../reactive/ReactorServerHttpRequest.java | 7 +- .../reactive/ReactorServerHttpResponse.java | 17 +- .../reactive/ServletServerHttpRequest.java | 36 ++- .../reactive/ServletServerHttpResponse.java | 10 +- .../server/reactive/TomcatHeadersAdapter.java | 237 ++++++++++++++++++ .../reactive/TomcatHttpHandlerAdapter.java | 58 ++++- .../reactive/UndertowHeadersAdapter.java | 222 ++++++++++++++++ .../reactive/UndertowServerHttpRequest.java | 9 +- .../reactive/UndertowServerHttpResponse.java | 14 +- .../cors/reactive/DefaultCorsProcessor.java | 2 +- .../support/HandshakeWebSocketService.java | 5 +- 19 files changed, 1224 insertions(+), 85 deletions(-) create mode 100644 spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java create mode 100644 spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java create mode 100644 spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java create mode 100644 spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java create mode 100644 spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java diff --git a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java index 3e0a81dc65e0..4756017d3691 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java +++ b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java @@ -35,8 +35,6 @@ import java.util.Collections; import java.util.EnumSet; import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.LinkedList; import java.util.List; import java.util.Locale; import java.util.Map; @@ -47,7 +45,9 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedCaseInsensitiveMap; +import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; @@ -78,7 +78,8 @@ public class HttpHeaders implements MultiValueMap, Serializable /** * The empty {@code HttpHeaders} instance (immutable). */ - public static final HttpHeaders EMPTY = new HttpHeaders(new LinkedHashMap<>(), true); + public static final HttpHeaders EMPTY = + new ReadOnlyHttpHeaders(new HttpHeaders(new LinkedMultiValueMap<>(0))); /** * The HTTP {@code Accept} header field name. * @see Section 5.3.2 of RFC 7231 @@ -397,35 +398,27 @@ public class HttpHeaders implements MultiValueMap, Serializable private static final DateTimeFormatter[] DATE_FORMATTERS = new DateTimeFormatter[] { DateTimeFormatter.RFC_1123_DATE_TIME, DateTimeFormatter.ofPattern("EEEE, dd-MMM-yy HH:mm:ss zz", Locale.US), - DateTimeFormatter.ofPattern("EEE MMM dd HH:mm:ss yyyy",Locale.US).withZone(GMT) + DateTimeFormatter.ofPattern("EEE MMM dd HH:mm:ss yyyy", Locale.US).withZone(GMT) }; - private final Map> headers; - - private final boolean readOnly; + final MultiValueMap headers; /** - * Constructs a new, empty instance of the {@code HttpHeaders} object. + * Construct a new, empty instance of the {@code HttpHeaders} object. */ public HttpHeaders() { - this(new LinkedCaseInsensitiveMap<>(8, Locale.ENGLISH), false); + this(CollectionUtils.toMultiValueMap( + new LinkedCaseInsensitiveMap<>(8, Locale.ENGLISH))); } /** - * Private constructor that can create read-only {@code HttpHeader} instances. + * Construct a new {@code HttpHeaders} instance backed by an existing map. */ - private HttpHeaders(Map> headers, boolean readOnly) { - if (readOnly) { - Map> map = new LinkedCaseInsensitiveMap<>(headers.size(), Locale.ENGLISH); - headers.forEach((key, valueList) -> map.put(key, Collections.unmodifiableList(valueList))); - this.headers = Collections.unmodifiableMap(map); - } - else { - this.headers = headers; - } - this.readOnly = readOnly; + public HttpHeaders(MultiValueMap headers) { + Assert.notNull(headers, "headers must not be null"); + this.headers = headers; } @@ -1474,8 +1467,7 @@ protected String toCommaDelimitedString(List headerValues) { @Override @Nullable public String getFirst(String headerName) { - List headerValues = this.headers.get(headerName); - return (headerValues != null ? headerValues.get(0) : null); + return this.headers.getFirst(headerName); } /** @@ -1488,19 +1480,17 @@ public String getFirst(String headerName) { */ @Override public void add(String headerName, @Nullable String headerValue) { - List headerValues = this.headers.computeIfAbsent(headerName, k -> new LinkedList<>()); - headerValues.add(headerValue); + this.headers.add(headerName, headerValue); } @Override public void addAll(String key, List values) { - List currentValues = this.headers.computeIfAbsent(key, k -> new LinkedList<>()); - currentValues.addAll(values); + this.headers.addAll(key, values); } @Override public void addAll(MultiValueMap values) { - values.forEach(this::addAll); + this.headers.addAll(values); } /** @@ -1513,21 +1503,17 @@ public void addAll(MultiValueMap values) { */ @Override public void set(String headerName, @Nullable String headerValue) { - List headerValues = new LinkedList<>(); - headerValues.add(headerValue); - this.headers.put(headerName, headerValues); + this.headers.set(headerName, headerValue); } @Override public void setAll(Map values) { - values.forEach(this::set); + this.headers.setAll(values); } @Override public Map toSingleValueMap() { - LinkedHashMap singleValueMap = new LinkedHashMap<>(this.headers.size()); - this.headers.forEach((key, valueList) -> singleValueMap.put(key, valueList.get(0))); - return singleValueMap; + return this.headers.toSingleValueMap(); } @@ -1623,7 +1609,12 @@ public String toString() { */ public static HttpHeaders readOnlyHttpHeaders(HttpHeaders headers) { Assert.notNull(headers, "HttpHeaders must not be null"); - return (headers.readOnly ? headers : new HttpHeaders(headers, true)); + if (headers instanceof ReadOnlyHttpHeaders) { + return headers; + } + else { + return new ReadOnlyHttpHeaders(headers); + } } } diff --git a/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java b/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java new file mode 100644 index 000000000000..39f64d4b6af9 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/ReadOnlyHttpHeaders.java @@ -0,0 +1,135 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http; + +import java.util.AbstractMap; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code HttpHeaders} object that can only be read, not written to. + * + * @author Brian Clozel + * @since 5.1 + */ +class ReadOnlyHttpHeaders extends HttpHeaders { + + private static final long serialVersionUID = -8578554704772377436L; + + @Nullable + private MediaType cachedContentType; + + ReadOnlyHttpHeaders(HttpHeaders headers) { + super(headers.headers); + } + + @Override + public MediaType getContentType() { + if (this.cachedContentType != null) { + return this.cachedContentType; + } + else { + MediaType contentType = super.getContentType(); + this.cachedContentType = contentType; + return contentType; + } + } + + @Override + public List get(Object key) { + List values = this.headers.get(key); + if (values != null) { + return Collections.unmodifiableList(values); + } + return values; + } + + @Override + public void add(String headerName, @Nullable String headerValue) { + throw new UnsupportedOperationException(); + } + + @Override + public void addAll(String key, List values) { + throw new UnsupportedOperationException(); + } + + @Override + public void addAll(MultiValueMap values) { + throw new UnsupportedOperationException(); + } + + @Override + public void set(String headerName, @Nullable String headerValue) { + throw new UnsupportedOperationException(); + } + + @Override + public void setAll(Map values) { + throw new UnsupportedOperationException(); + } + + @Override + public Map toSingleValueMap() { + return Collections.unmodifiableMap(this.headers.toSingleValueMap()); + } + + @Override + public Set keySet() { + return Collections.unmodifiableSet(this.headers.keySet()); + } + + @Override + public List put(String key, List value) { + throw new UnsupportedOperationException(); + } + + @Override + public List remove(Object key) { + throw new UnsupportedOperationException(); + } + + @Override + public void putAll(Map> map) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + + @Override + public Collection> values() { + return Collections.unmodifiableCollection(this.headers.values()); + } + + @Override + public Set>> entrySet() { + return Collections.unmodifiableSet(this.headers.entrySet().stream() + .map(AbstractMap.SimpleImmutableEntry::new) + .collect(Collectors.toSet())); + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java index a88da3fb2b66..29cd6228d388 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java @@ -153,6 +153,9 @@ public List get(Object key) { Assert.isInstanceOf(String.class, key, "Key must be a String-based header name"); Collection values1 = servletResponse.getHeaders((String) key); + if (headersWritten) { + return new ArrayList<>(values1); + } boolean isEmpty1 = CollectionUtils.isEmpty(values1); List values2 = super.get(key); diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java index 78c537e152e3..39dd29a9e547 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerServerHttpResponse.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpHeaders; /** * Abstract base class for listener-based server responses, e.g. Servlet 3.1 @@ -41,6 +42,10 @@ public AbstractListenerServerHttpResponse(DataBufferFactory dataBufferFactory) { super(dataBufferFactory); } + public AbstractListenerServerHttpResponse(DataBufferFactory dataBufferFactory, HttpHeaders headers) { + super(dataBufferFactory, headers); + } + @Override protected final Mono writeWithInternal(Publisher body) { diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java index 2b21a2ef4d3d..b8356dbd8c41 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java @@ -75,9 +75,14 @@ private enum State {NEW, COMMITTING, COMMITTED} public AbstractServerHttpResponse(DataBufferFactory dataBufferFactory) { + this(dataBufferFactory, new HttpHeaders()); + } + + public AbstractServerHttpResponse(DataBufferFactory dataBufferFactory, HttpHeaders headers) { Assert.notNull(dataBufferFactory, "DataBufferFactory must not be null"); + Assert.notNull(headers, "HttpHeaders must not be null"); this.dataBufferFactory = dataBufferFactory; - this.headers = new HttpHeaders(); + this.headers = headers; this.cookies = new LinkedMultiValueMap<>(); } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java new file mode 100644 index 000000000000..5ffc0779acc2 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHeadersAdapter.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.eclipse.jetty.http.HttpField; +import org.eclipse.jetty.http.HttpFields; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Jetty HTTP headers. + * + * @author Brian Clozel + * @since 5.1 + */ +class JettyHeadersAdapter implements MultiValueMap { + + private final HttpFields headers; + + JettyHeadersAdapter(HttpFields headers) { + this.headers = headers; + } + + @Override + public String getFirst(String key) { + return this.headers.get(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.add(key, value); + } + + @Override + public void addAll(String key, List values) { + values.forEach(value -> add(key, value)); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach(this::addAll); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.put(key, value); + } + + @Override + public void setAll(Map values) { + values.forEach(this::set); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + Iterator iterator = this.headers.iterator(); + iterator.forEachRemaining(field -> { + if (!singleValueMap.containsKey(field.getName())) { + singleValueMap.put(field.getName(), field.getValue()); + } + }); + return singleValueMap; + } + + @Override + public int size() { + return this.headers.getFieldNamesCollection().size(); + } + + @Override + public boolean isEmpty() { + return this.headers.size() == 0; + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String) { + return this.headers.containsKey((String) key); + } + return false; + } + + @Override + public boolean containsValue(Object value) { + if (value instanceof String) { + return this.headers.stream() + .anyMatch(field -> field.contains((String) value)); + } + return false; + } + + @Nullable + @Override + public List get(Object key) { + if (key instanceof String) { + return this.headers.getValuesList((String) key); + } + return null; + } + + @Nullable + @Override + public List put(String key, List value) { + List oldValues = get(key); + this.headers.put(key, value); + return oldValues; + } + + @Nullable + @Override + public List remove(Object key) { + if (key instanceof String) { + List oldValues = get(key); + this.headers.remove((String) key); + return oldValues; + } + return null; + } + + @Override + public void putAll(Map> m) { + m.forEach(this::put); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.getFieldNamesCollection(); + } + + @Override + public Collection> values() { + return this.headers.getFieldNamesCollection().stream() + .map(this.headers::getValuesList).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + private class EntryIterator implements Iterator>> { + + private Enumeration names = headers.getFieldNames(); + + @Override + public boolean hasNext() { + return this.names.hasMoreElements(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.nextElement()); + } + } + + private class HeaderEntry implements Entry> { + + private final String key; + + HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key.toString(); + } + + @Override + public List getValue() { + return headers.getValuesList(this.key); + } + + @Override + public List setValue(List value) { + List previousValues = headers.getValuesList(this.key); + headers.put(this.key, value); + return previousValues; + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java index 76bcef85494d..bd97b1ef6f33 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/JettyHttpHandlerAdapter.java @@ -17,15 +17,21 @@ package org.springframework.http.server.reactive; import java.io.IOException; +import java.net.URISyntaxException; import java.nio.ByteBuffer; import javax.servlet.AsyncContext; import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.eclipse.jetty.http.HttpFields; import org.eclipse.jetty.server.HttpOutput; +import org.eclipse.jetty.server.Request; +import org.eclipse.jetty.server.Response; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; +import org.springframework.http.HttpHeaders; /** * {@link ServletHttpHandlerAdapter} extension that uses Jetty APIs for writing @@ -42,6 +48,12 @@ public JettyHttpHandlerAdapter(HttpHandler httpHandler) { } + @Override + protected ServletServerHttpRequest createRequest(HttpServletRequest request, AsyncContext context) + throws IOException, URISyntaxException { + return new JettyServerHttpRequest(request, context, getServletPath(), getDataBufferFactory(), getBufferSize()); + } + @Override protected ServletServerHttpResponse createResponse(HttpServletResponse response, AsyncContext context, ServletServerHttpRequest request) throws IOException { @@ -50,14 +62,38 @@ protected ServletServerHttpResponse createResponse(HttpServletResponse response, response, context, getDataBufferFactory(), getBufferSize(), request); } + private static final class JettyServerHttpRequest extends ServletServerHttpRequest { + + JettyServerHttpRequest(HttpServletRequest request, AsyncContext asyncContext, + String servletPath, DataBufferFactory bufferFactory, int bufferSize) + throws IOException, URISyntaxException { + + super(createHeaders(request), request, asyncContext, servletPath, bufferFactory, bufferSize); + } + + private static HttpHeaders createHeaders(HttpServletRequest request) { + HttpFields fields = ((Request) request).getMetaData().getFields(); + return new HttpHeaders(new JettyHeadersAdapter(fields)); + } + } + private static final class JettyServerHttpResponse extends ServletServerHttpResponse { - public JettyServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext, + JettyServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext, DataBufferFactory bufferFactory, int bufferSize, ServletServerHttpRequest request) throws IOException { - super(response, asyncContext, bufferFactory, bufferSize, request); + super(createHeaders(response), response, asyncContext, bufferFactory, bufferSize, request); + } + + private static HttpHeaders createHeaders(HttpServletResponse response) { + HttpFields fields = ((Response) response).getHttpFields(); + return new HttpHeaders(new JettyHeadersAdapter(fields)); + } + + @Override + protected void applyHeaders() { } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java new file mode 100644 index 000000000000..6d68ceb1d856 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/NettyHeadersAdapter.java @@ -0,0 +1,217 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import io.netty.handler.codec.http.HttpHeaders; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Netty HTTP headers. + * + * @author Brian Clozel + * @since 5.1 + */ +class NettyHeadersAdapter implements MultiValueMap { + + private final HttpHeaders headers; + + NettyHeadersAdapter(HttpHeaders headers) { + this.headers = headers; + } + + @Override + @Nullable + public String getFirst(String key) { + return this.headers.get(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.add(key, value); + } + + @Override + public void addAll(String key, List values) { + this.headers.add(key, values); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach(this.headers::add); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.set(key, value); + } + + @Override + public void setAll(Map values) { + values.forEach(this.headers::set); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + this.headers.entries() + .forEach(entry -> { + if (!singleValueMap.containsKey(entry.getKey())) { + singleValueMap.put(entry.getKey(), entry.getValue()); + } + }); + return singleValueMap; + } + + @Override + public int size() { + return this.headers.size(); + } + + @Override + public boolean isEmpty() { + return this.headers.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return (key instanceof String) && this.headers.contains((String) key); + } + + @Override + public boolean containsValue(Object value) { + return (value instanceof String) && + this.headers.entries().stream() + .anyMatch(entry -> value != null && value.equals(entry.getValue())); + } + + @Override + @Nullable + public List get(Object key) { + if (key instanceof String) { + return this.headers.getAll((String) key); + } + return null; + } + + @Nullable + @Override + public List put(String key, @Nullable List value) { + List previousValues = this.headers.getAll(key); + this.headers.add(key, value); + return previousValues; + } + + @Nullable + @Override + public List remove(Object key) { + if (key instanceof String) { + List previousValues = this.headers.getAll((String) key); + this.headers.remove((String) key); + return previousValues; + } + return null; + } + + @Override + public void putAll(Map> m) { + m.forEach(this.headers::add); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.names(); + } + + @Override + public Collection> values() { + return this.headers.names().stream() + .map(this.headers::getAll).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + private class EntryIterator implements Iterator>> { + + private Iterator names = headers.names().iterator(); + + @Override + public boolean hasNext() { + return this.names.hasNext(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.next()); + } + } + + private class HeaderEntry implements Entry> { + + private final String key; + + HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public List getValue() { + return headers.getAll(this.key); + } + + @Override + public List setValue(List value) { + List previousValues = headers.getAll(this.key); + headers.set(this.key, value); + return previousValues; + } + } + +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java index fb3a88e3bbbc..884f2ccd5100 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpRequest.java @@ -125,11 +125,8 @@ private static String resolveRequestUri(HttpServerRequest request) { } private static HttpHeaders initHeaders(HttpServerRequest channel) { - HttpHeaders headers = new HttpHeaders(); - for (String name : channel.requestHeaders().names()) { - headers.put(name, channel.requestHeaders().getAll(name)); - } - return headers; + NettyHeadersAdapter headersMap = new NettyHeadersAdapter(channel.requestHeaders()); + return new HttpHeaders(headersMap); } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java index e2dd16ed7f6a..b536a6d9604c 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java @@ -19,6 +19,7 @@ import java.nio.file.Path; import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.cookie.Cookie; import io.netty.handler.codec.http.cookie.DefaultCookie; @@ -30,6 +31,7 @@ import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.NettyDataBufferFactory; +import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseCookie; import org.springframework.http.ZeroCopyHttpOutputMessage; import org.springframework.util.Assert; @@ -47,11 +49,16 @@ class ReactorServerHttpResponse extends AbstractServerHttpResponse implements Ze public ReactorServerHttpResponse(HttpServerResponse response, DataBufferFactory bufferFactory) { - super(bufferFactory); + super(bufferFactory, initHeaders(response)); Assert.notNull(response, "HttpServerResponse must not be null"); this.response = response; } + private static HttpHeaders initHeaders(HttpServerResponse channel) { + channel.responseHeaders().remove(HttpHeaderNames.TRANSFER_ENCODING); + NettyHeadersAdapter headersMap = new NettyHeadersAdapter(channel.responseHeaders()); + return new HttpHeaders(headersMap); + } @SuppressWarnings("unchecked") @Override @@ -80,11 +87,9 @@ protected Mono writeAndFlushWithInternal(Publisher { - for (String value : headerValues) { - this.response.responseHeaders().add(headerName, value); - } - }); + if (getHeaders().getContentLength() == -1) { + this.response.chunkedTransfer(true); + } } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java index b2437e268cf9..e2d6ec263487 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -69,12 +69,18 @@ class ServletServerHttpRequest extends AbstractServerHttpRequest { private final byte[] buffer; - public ServletServerHttpRequest(HttpServletRequest request, AsyncContext asyncContext, String servletPath, DataBufferFactory bufferFactory, int bufferSize) throws IOException, URISyntaxException { - super(initUri(request), request.getContextPath() + servletPath, initHeaders(request)); + this(createDefaultHttpHeaders(request), request, asyncContext, servletPath, bufferFactory, bufferSize); + } + + public ServletServerHttpRequest(HttpHeaders headers, HttpServletRequest request, AsyncContext asyncContext, + String servletPath, DataBufferFactory bufferFactory, int bufferSize) + throws IOException, URISyntaxException { + + super(initUri(request), request.getContextPath() + servletPath, initHeaders(headers, request)); Assert.notNull(bufferFactory, "'bufferFactory' must not be null"); Assert.isTrue(bufferSize > 0, "'bufferSize' must be higher than 0"); @@ -91,6 +97,18 @@ public ServletServerHttpRequest(HttpServletRequest request, AsyncContext asyncCo this.bodyPublisher.registerReadListener(); } + + private static HttpHeaders createDefaultHttpHeaders(HttpServletRequest request) { + HttpHeaders headers = new HttpHeaders(); + for (Enumeration names = request.getHeaderNames(); names.hasMoreElements(); ) { + String name = (String) names.nextElement(); + for (Enumeration values = request.getHeaders(name); values.hasMoreElements(); ) { + headers.add(name, (String) values.nextElement()); + } + } + return headers; + } + private static URI initUri(HttpServletRequest request) throws URISyntaxException { Assert.notNull(request, "'request' must not be null"); StringBuffer url = request.getRequestURL(); @@ -101,16 +119,7 @@ private static URI initUri(HttpServletRequest request) throws URISyntaxException return new URI(url.toString()); } - private static HttpHeaders initHeaders(HttpServletRequest request) { - HttpHeaders headers = new HttpHeaders(); - for (Enumeration names = request.getHeaderNames(); - names.hasMoreElements(); ) { - String name = (String) names.nextElement(); - for (Enumeration values = request.getHeaders(name); - values.hasMoreElements(); ) { - headers.add(name, (String) values.nextElement()); - } - } + private static HttpHeaders initHeaders(HttpHeaders headers, HttpServletRequest request) { MediaType contentType = headers.getContentType(); if (contentType == null) { String requestContentType = request.getContentType(); @@ -231,7 +240,8 @@ public T getNativeRequest() { private final class RequestAsyncListener implements AsyncListener { @Override - public void onStartAsync(AsyncEvent event) {} + public void onStartAsync(AsyncEvent event) { + } @Override public void onTimeout(AsyncEvent event) { diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index 81d9c0c549d3..9d11809031ac 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -33,6 +33,7 @@ import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseCookie; import org.springframework.lang.Nullable; @@ -62,11 +63,16 @@ class ServletServerHttpResponse extends AbstractListenerServerHttpResponse { private final ServletServerHttpRequest request; - public ServletServerHttpResponse(HttpServletResponse response, AsyncContext asyncContext, DataBufferFactory bufferFactory, int bufferSize, ServletServerHttpRequest request) throws IOException { - super(bufferFactory); + this(new HttpHeaders(), response, asyncContext, bufferFactory, bufferSize, request); + } + + public ServletServerHttpResponse(HttpHeaders headers, HttpServletResponse response, AsyncContext asyncContext, + DataBufferFactory bufferFactory, int bufferSize, ServletServerHttpRequest request) throws IOException { + + super(bufferFactory, headers); Assert.notNull(response, "HttpServletResponse must not be null"); Assert.notNull(bufferFactory, "DataBufferFactory must not be null"); diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java new file mode 100644 index 000000000000..e667a3e86c07 --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHeadersAdapter.java @@ -0,0 +1,237 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.tomcat.util.buf.MessageBytes; +import org.apache.tomcat.util.http.MimeHeaders; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Tomcat HTTP headers. + * + * @author Brian Clozel + * @since 5.1 + */ +class TomcatHeadersAdapter implements MultiValueMap { + + private final MimeHeaders headers; + + TomcatHeadersAdapter(MimeHeaders headers) { + this.headers = headers; + } + + @Override + public String getFirst(String key) { + return this.headers.getHeader(key); + } + + @Override + public void add(String key, String value) { + this.headers.addValue(key).setString(value); + } + + @Override + public void addAll(String key, List values) { + values.forEach(value -> add(key, value)); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach(this::addAll); + } + + @Override + public void set(String key, String value) { + this.headers.setValue(key).setString(value); + } + + @Override + public void setAll(Map values) { + values.forEach(this::set); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + this.keySet().forEach(key -> singleValueMap.put(key, getFirst(key))); + return singleValueMap; + } + + @Override + public int size() { + Enumeration names = this.headers.names(); + int size = 0; + while (names.hasMoreElements()) { + size++; + names.nextElement(); + } + return size; + } + + @Override + public boolean isEmpty() { + return this.headers.size() == 0; + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String) { + return this.headers.findHeader((String) key, 0) != -1; + } + return false; + } + + @Override + public boolean containsValue(Object value) { + if (value instanceof String) { + MessageBytes needle = MessageBytes.newInstance(); + needle.setString((String) value); + for (int i = 0; i < this.headers.size(); i++) { + if (this.headers.getValue(i).equals(needle)) { + return true; + } + } + } + return false; + } + + @Override + @Nullable + public List get(Object key) { + if (key instanceof String) { + return Collections.list(this.headers.values((String) key)); + } + return null; + } + + @Override + @Nullable + public List put(String key, List value) { + List previousValues = get(key); + value.forEach(v -> this.headers.addValue(key).setString(v)); + return previousValues; + } + + @Override + @Nullable + public List remove(Object key) { + if (key instanceof String) { + List previousValues = get(key); + this.headers.removeHeader((String) key); + return previousValues; + } + return null; + } + + @Override + public void putAll(Map> m) { + m.forEach(this::put); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + Set result = new HashSet<>(8); + Enumeration names = this.headers.names(); + while (names.hasMoreElements()) { + result.add(names.nextElement()); + } + return result; + } + + @Override + public Collection> values() { + return keySet().stream().map(this::get).collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + private class EntryIterator implements Iterator>> { + + private Enumeration names = headers.names(); + + @Override + public boolean hasNext() { + return this.names.hasMoreElements(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.nextElement()); + } + } + + private final class HeaderEntry implements Entry> { + + private final String key; + + private HeaderEntry(String key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key; + } + + @Nullable + @Override + public List getValue() { + return get(this.key); + } + + @Nullable + @Override + public List setValue(List value) { + List previous = getValue(); + headers.removeHeader(this.key); + addAll(this.key, value); + return previous; + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java index 20412344288b..89851e938ab4 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/TomcatHttpHandlerAdapter.java @@ -17,6 +17,7 @@ package org.springframework.http.server.reactive; import java.io.IOException; +import java.lang.reflect.Field; import java.net.URISyntaxException; import java.nio.ByteBuffer; import javax.servlet.AsyncContext; @@ -27,17 +28,25 @@ import org.apache.catalina.connector.CoyoteInputStream; import org.apache.catalina.connector.CoyoteOutputStream; +import org.apache.catalina.connector.RequestFacade; +import org.apache.catalina.connector.ResponseFacade; +import org.apache.coyote.Request; +import org.apache.coyote.Response; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; /** * {@link ServletHttpHandlerAdapter} extension that uses Tomcat APIs for reading * from the request and writing to the response with {@link ByteBuffer}. * * @author Violeta Georgieva + * @author Brian Clozel + * @author Brian Clozel * @since 5.0 * @see org.springframework.web.server.adapter.AbstractReactiveWebInitializer */ @@ -66,21 +75,39 @@ protected ServletServerHttpResponse createResponse(HttpServletResponse response, response, asyncContext, getDataBufferFactory(), getBufferSize(), request); } + private static final class TomcatServerHttpRequest extends ServletServerHttpRequest { - private final class TomcatServerHttpRequest extends ServletServerHttpRequest { + private static final Field COYOTE_REQUEST_FIELD = ReflectionUtils.findField(RequestFacade.class, "request"); - public TomcatServerHttpRequest(HttpServletRequest request, AsyncContext context, + private final int bufferSize; + + private final DataBufferFactory factory; + + static { + ReflectionUtils.makeAccessible(COYOTE_REQUEST_FIELD); + } + + TomcatServerHttpRequest(HttpServletRequest request, AsyncContext context, String servletPath, DataBufferFactory factory, int bufferSize) throws IOException, URISyntaxException { - super(request, context, servletPath, factory, bufferSize); + super(createTomcatHttpHeaders(request), request, context, servletPath, factory, bufferSize); + this.factory = factory; + this.bufferSize = bufferSize; + } + + private static HttpHeaders createTomcatHttpHeaders(HttpServletRequest request) { + Request tomcatRequest = ((org.apache.catalina.connector.Request) ReflectionUtils + .getField(COYOTE_REQUEST_FIELD, request)).getCoyoteRequest(); + TomcatHeadersAdapter headers = new TomcatHeadersAdapter(tomcatRequest.getMimeHeaders()); + return new HttpHeaders(headers); } @Override protected DataBuffer readFromInputStream() throws IOException { boolean release = true; - int capacity = getBufferSize(); - DataBuffer dataBuffer = getDataBufferFactory().allocateBuffer(capacity); + int capacity = this.bufferSize; + DataBuffer dataBuffer = this.factory.allocateBuffer(capacity); try { ByteBuffer byteBuffer = dataBuffer.asByteBuffer(0, capacity); @@ -111,10 +138,27 @@ else if (read == -1) { private static final class TomcatServerHttpResponse extends ServletServerHttpResponse { - public TomcatServerHttpResponse(HttpServletResponse response, AsyncContext context, + private static final Field COYOTE_RESPONSE_FIELD = ReflectionUtils.findField(ResponseFacade.class, "response"); + + static { + ReflectionUtils.makeAccessible(COYOTE_RESPONSE_FIELD); + } + + TomcatServerHttpResponse(HttpServletResponse response, AsyncContext context, DataBufferFactory factory, int bufferSize, ServletServerHttpRequest request) throws IOException { - super(response, context, factory, bufferSize, request); + super(createTomcatHttpHeaders(response), response, context, factory, bufferSize, request); + } + + private static HttpHeaders createTomcatHttpHeaders(HttpServletResponse response) { + Response tomcatResponse = ((org.apache.catalina.connector.Response) ReflectionUtils + .getField(COYOTE_RESPONSE_FIELD, response)).getCoyoteResponse(); + TomcatHeadersAdapter headers = new TomcatHeadersAdapter(tomcatResponse.getMimeHeaders()); + return new HttpHeaders(headers); + } + + @Override + protected void applyHeaders() { } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java new file mode 100644 index 000000000000..3e817c906f6a --- /dev/null +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowHeadersAdapter.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.http.server.reactive; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import io.undertow.util.HeaderMap; +import io.undertow.util.HeaderValues; +import io.undertow.util.HttpString; + +import org.springframework.lang.Nullable; +import org.springframework.util.MultiValueMap; + +/** + * {@code MultiValueMap} implementation for wrapping Undertow HTTP headers. + * + * @author Brian Clozel + * @since 5.1 + */ +class UndertowHeadersAdapter implements MultiValueMap { + + private final HeaderMap headers; + + UndertowHeadersAdapter(HeaderMap headers) { + this.headers = headers; + } + + @Override + public String getFirst(String key) { + return this.headers.getFirst(key); + } + + @Override + public void add(String key, @Nullable String value) { + this.headers.add(HttpString.tryFromString(key), value); + } + + @Override + @SuppressWarnings("unchecked") + public void addAll(String key, List values) { + this.headers.addAll(HttpString.tryFromString(key), (List) values); + } + + @Override + public void addAll(MultiValueMap values) { + values.forEach((key, list) -> this.headers.addAll(HttpString.tryFromString(key), list)); + } + + @Override + public void set(String key, @Nullable String value) { + this.headers.put(HttpString.tryFromString(key), value); + } + + @Override + public void setAll(Map values) { + values.forEach((key, list) -> this.headers.put(HttpString.tryFromString(key), list)); + } + + @Override + public Map toSingleValueMap() { + Map singleValueMap = new LinkedHashMap<>(this.headers.size()); + this.headers.forEach(values -> + singleValueMap.put(values.getHeaderName().toString(), values.getFirst())); + return singleValueMap; + } + + @Override + public int size() { + return this.headers.size(); + } + + @Override + public boolean isEmpty() { + return this.headers.size() == 0; + } + + @Override + public boolean containsKey(Object key) { + if (key instanceof String) { + return this.headers.contains((String) key); + } + return false; + } + + @Override + public boolean containsValue(Object value) { + if (value instanceof String) { + return this.headers.getHeaderNames().stream() + .map(this.headers::get) + .anyMatch(values -> values.contains(value)); + } + return false; + } + + @Override + @Nullable + public List get(Object key) { + if (key instanceof String) { + return this.headers.get((String) key); + } + return null; + } + + @Override + @Nullable + public List put(String key, List value) { + HeaderValues previousValues = this.headers.get(key); + this.headers.putAll(HttpString.tryFromString(key), value); + return previousValues; + } + + @Override + @Nullable + public List remove(Object key) { + if (key instanceof String) { + this.headers.remove((String) key); + } + return null; + } + + @Override + public void putAll(Map> m) { + m.forEach((key, values) -> + this.headers.putAll(HttpString.tryFromString(key), values)); + } + + @Override + public void clear() { + this.headers.clear(); + } + + @Override + public Set keySet() { + return this.headers.getHeaderNames().stream() + .map(HttpString::toString) + .collect(Collectors.toSet()); + } + + @Override + public Collection> values() { + return this.headers.getHeaderNames().stream() + .map(this.headers::get) + .collect(Collectors.toList()); + } + + @Override + public Set>> entrySet() { + return new AbstractSet>>() { + @Override + public Iterator>> iterator() { + return new EntryIterator(); + } + + @Override + public int size() { + return headers.size(); + } + }; + } + + private class EntryIterator implements Iterator>> { + + private Iterator names = headers.getHeaderNames().iterator(); + + @Override + public boolean hasNext() { + return this.names.hasNext(); + } + + @Override + public Entry> next() { + return new HeaderEntry(this.names.next()); + } + } + + private class HeaderEntry implements Entry> { + + private final HttpString key; + + HeaderEntry(HttpString key) { + this.key = key; + } + + @Override + public String getKey() { + return this.key.toString(); + } + + @Override + public List getValue() { + return headers.get(this.key); + } + + @Override + public List setValue(List value) { + List previousValues = headers.get(this.key); + headers.putAll(this.key, value); + return previousValues; + } + } +} diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java index da6c061391e0..6c68f5a52869 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpRequest.java @@ -30,7 +30,6 @@ import io.undertow.connector.PooledByteBuffer; import io.undertow.server.HttpServerExchange; import io.undertow.server.handlers.Cookie; -import io.undertow.util.HeaderValues; import org.xnio.channels.StreamSourceChannel; import reactor.core.publisher.Flux; @@ -79,11 +78,9 @@ private static URI initUri(HttpServerExchange exchange) throws URISyntaxExceptio } private static HttpHeaders initHeaders(HttpServerExchange exchange) { - HttpHeaders headers = new HttpHeaders(); - for (HeaderValues values : exchange.getRequestHeaders()) { - headers.put(values.getHeaderName().toString(), values); - } - return headers; + UndertowHeadersAdapter headersMap = + new UndertowHeadersAdapter(exchange.getRequestHeaders()); + return new HttpHeaders(headersMap); } @Override diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java index 1ad997896ccc..a9533379cd92 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -25,7 +25,6 @@ import io.undertow.server.HttpServerExchange; import io.undertow.server.handlers.Cookie; import io.undertow.server.handlers.CookieImpl; -import io.undertow.util.HttpString; import org.reactivestreams.Processor; import org.reactivestreams.Publisher; import org.xnio.channels.Channels; @@ -35,6 +34,7 @@ import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DataBufferFactory; import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseCookie; import org.springframework.http.ZeroCopyHttpOutputMessage; import org.springframework.lang.Nullable; @@ -58,15 +58,21 @@ class UndertowServerHttpResponse extends AbstractListenerServerHttpResponse impl private StreamSinkChannel responseChannel; - public UndertowServerHttpResponse( + UndertowServerHttpResponse( HttpServerExchange exchange, DataBufferFactory bufferFactory, UndertowServerHttpRequest request) { - super(bufferFactory); + super(bufferFactory, createHeaders(exchange)); Assert.notNull(exchange, "HttpServerExchange must not be null"); this.exchange = exchange; this.request = request; } + private static HttpHeaders createHeaders(HttpServerExchange exchange) { + UndertowHeadersAdapter headersMap = + new UndertowHeadersAdapter(exchange.getResponseHeaders()); + return new HttpHeaders(headersMap); + } + @SuppressWarnings("unchecked") @Override @@ -85,8 +91,6 @@ protected void applyStatusCode() { @Override protected void applyHeaders() { - getHeaders().forEach((headerName, headerValues) -> - this.exchange.getResponseHeaders().addAll(HttpString.tryFromString(headerName), headerValues)); } @Override diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java index b673e1e463a9..f5027686de29 100644 --- a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java +++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java @@ -87,7 +87,7 @@ public boolean process(@Nullable CorsConfiguration config, ServerWebExchange exc } private boolean responseHasCors(ServerHttpResponse response) { - return (response.getHeaders().getAccessControlAllowOrigin() != null); + return response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN) != null; } /** diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java index f5061f960ee1..d3e353bf9086 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java @@ -272,7 +272,10 @@ private HandshakeInfo createHandshakeInfo(ServerWebExchange exchange, ServerHttp @Nullable String protocol, Map attributes) { URI uri = request.getURI(); - HttpHeaders headers = request.getHeaders(); + // Copy request headers, as they might be pooled and recycled by + // the server implementation once the handshake HTTP exchange is done. + HttpHeaders headers = new HttpHeaders(); + headers.addAll(request.getHeaders()); Mono principal = exchange.getPrincipal(); String logPrefix = exchange.getLogPrefix(); InetSocketAddress remoteAddress = request.getRemoteAddress();