diff --git a/header.go b/header.go index 76b7d35eb0..5665e79e59 100644 --- a/header.go +++ b/header.go @@ -65,6 +65,7 @@ type RequestHeader struct { noHTTP11 bool connectionClose bool noDefaultContentType bool + disableSpecialHeader bool // These two fields have been moved close to other bool fields // for reducing RequestHeader object size. @@ -360,6 +361,9 @@ func (h *ResponseHeader) SetServerBytes(server []byte) { // ContentType returns Content-Type header value. func (h *RequestHeader) ContentType() []byte { + if h.disableSpecialHeader { + return peekArgBytes(h.h, []byte(HeaderContentType)) + } return h.contentType } @@ -573,6 +577,9 @@ func (h *RequestHeader) MultipartFormBoundary() []byte { // Host returns Host header value. func (h *RequestHeader) Host() []byte { + if h.disableSpecialHeader { + return peekArgBytes(h.h, []byte(HeaderHost)) + } return h.host } @@ -588,6 +595,9 @@ func (h *RequestHeader) SetHostBytes(host []byte) { // UserAgent returns User-Agent header value. func (h *RequestHeader) UserAgent() []byte { + if h.disableSpecialHeader { + return peekArgBytes(h.h, []byte(HeaderUserAgent)) + } return h.userAgent } @@ -882,6 +892,23 @@ func (h *RequestHeader) Len() int { return n } +// DisableSpecialHeader disables special header processing. +// fasthttp will not set any special headers for you, such as Host, Content-Type, User-Agent, etc. +// You must set everything yourself. +// If RequestHeader.Read() is called, special headers will be ignored. +// This can be used to control case and order of special headers. +// This is generally not recommended. +func (h *RequestHeader) DisableSpecialHeader() { + h.disableSpecialHeader = true +} + +// EnableSpecialHeader enables special header processing. +// fasthttp will send Host, Content-Type, User-Agent, etc headers for you. +// This is suggested and enabled by default. +func (h *RequestHeader) EnableSpecialHeader() { + h.disableSpecialHeader = false +} + // DisableNormalizing disables header names' normalization. // // By default all the header names are normalized by uppercasing @@ -1316,7 +1343,7 @@ func (h *ResponseHeader) setNonSpecial(key []byte, value []byte) { // setSpecialHeader handles special headers and return true when a header is processed. func (h *RequestHeader) setSpecialHeader(key, value []byte) bool { - if len(key) == 0 { + if len(key) == 0 || h.disableSpecialHeader { return false } @@ -2471,12 +2498,12 @@ func (h *RequestHeader) AppendBytes(dst []byte) []byte { dst = append(dst, strCRLF...) userAgent := h.UserAgent() - if len(userAgent) > 0 { + if len(userAgent) > 0 && !h.disableSpecialHeader { dst = appendHeaderLine(dst, strUserAgent, userAgent) } host := h.Host() - if len(host) > 0 { + if len(host) > 0 && !h.disableSpecialHeader { dst = appendHeaderLine(dst, strHost, host) } @@ -2484,10 +2511,10 @@ func (h *RequestHeader) AppendBytes(dst []byte) []byte { if !h.noDefaultContentType && len(contentType) == 0 && !h.ignoreBody() { contentType = strDefaultContentType } - if len(contentType) > 0 { + if len(contentType) > 0 && !h.disableSpecialHeader { dst = appendHeaderLine(dst, strContentType, contentType) } - if len(h.contentLengthBytes) > 0 { + if len(h.contentLengthBytes) > 0 && !h.disableSpecialHeader { dst = appendHeaderLine(dst, strContentLength, h.contentLengthBytes) } @@ -2513,14 +2540,14 @@ func (h *RequestHeader) AppendBytes(dst []byte) []byte { // there is no need in h.collectCookies() here, since if cookies aren't collected yet, // they all are located in h.h. n := len(h.cookies) - if n > 0 { + if n > 0 && !h.disableSpecialHeader { dst = append(dst, strCookie...) dst = append(dst, strColonSpace...) dst = appendRequestCookieBytes(dst, h.cookies) dst = append(dst, strCRLF...) } - if h.ConnectionClose() { + if h.ConnectionClose() && !h.disableSpecialHeader { dst = appendHeaderLine(dst, strConnection, strClose) } @@ -2904,6 +2931,11 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { continue } + if h.disableSpecialHeader { + h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) + continue + } + switch s.key[0] | 0x20 { case 'h': if caseInsensitiveCompare(s.key, strHost) { diff --git a/header_test.go b/header_test.go index 94eee8b244..f53959e193 100644 --- a/header_test.go +++ b/header_test.go @@ -348,6 +348,35 @@ func TestRequestRawHeaders(t *testing.T) { }) } +func TestRequestDisableSpecialHeaders(t *testing.T) { + t.Parallel() + + kvs := "Host: foobar\r\n" + + "User-Agent: ua\r\n" + + "Non-Special: val\r\n" + + "\r\n" + + var h RequestHeader + h.DisableSpecialHeader() + + s := "GET / HTTP/1.0\r\n" + kvs + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := h.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + // assert order of all headers preserved + if h.String() != s { + t.Fatalf("Headers not equal: %q. Expecting %q", h.String(), s) + } + h.SetCanonical([]byte("host"), []byte("notfoobar")) + if string(h.Host()) != "foobar" { + t.Fatalf("unexpected: %q. Expecting %q", h.Host(), "foobar") + } + if h.String() != "GET / HTTP/1.0\r\nHost: foobar\r\nUser-Agent: ua\r\nNon-Special: val\r\nhost: notfoobar\r\n\r\n" { + t.Fatalf("custom special header ordering failed: %q", h.String()) + } +} + func TestRequestHeaderSetCookieWithSpecialChars(t *testing.T) { t.Parallel()