Handle Te header when http2

This commit is contained in:
SALLEYRON Julien 2018-08-27 18:10:03 +02:00 committed by Traefiker Bot
parent f586950528
commit 56488d435f
15 changed files with 306 additions and 134 deletions

4
Gopkg.lock generated
View file

@ -706,7 +706,7 @@
branch = "master" branch = "master"
name = "github.com/gorilla/websocket" name = "github.com/gorilla/websocket"
packages = ["."] packages = ["."]
revision = "eb925808374e5ca90c83401a40d711dc08c0c0f6" revision = "66b9c49e59c6c48f0ffce28c2d8b8a5678502c6d"
[[projects]] [[projects]]
name = "github.com/gravitational/trace" name = "github.com/gravitational/trace"
@ -1272,7 +1272,7 @@
"roundrobin", "roundrobin",
"utils" "utils"
] ]
revision = "885e42fe04d8e0efa6c18facad4e0fc5757cde9b" revision = "f6bbeac6d5c4c06f88ba07ed42983ff36a5b407e"
[[projects]] [[projects]]
name = "github.com/vulcand/predicate" name = "github.com/vulcand/predicate"

View file

@ -6,12 +6,14 @@ package websocket
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptrace"
"net/url" "net/url"
"strings" "strings"
"time" "time"
@ -51,6 +53,10 @@ type Dialer struct {
// NetDial is nil, net.Dial is used. // NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error) NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, net.DialContext is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given // Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the // Request. If the function returns a non-nil error, the
// request is aborted with the provided error. // request is aborted with the provided error.
@ -69,6 +75,17 @@ type Dialer struct {
// do not limit the size of the messages that can be sent or received. // do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the client's requested subprotocols. // Subprotocols specifies the client's requested subprotocols.
Subprotocols []string Subprotocols []string
@ -84,6 +101,11 @@ type Dialer struct {
Jar http.CookieJar Jar http.CookieJar
} }
// Dial creates a new client connection by calling DialContext with a background context.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
return d.DialContext(context.Background(), urlStr, requestHeader)
}
var errMalformedURL = errors.New("malformed ws or wss URL") var errMalformedURL = errors.New("malformed ws or wss URL")
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
@ -111,19 +133,20 @@ var DefaultDialer = &Dialer{
} }
// nilDialer is dialer to use when receiver is nil. // nilDialer is dialer to use when receiver is nil.
var nilDialer Dialer = *DefaultDialer var nilDialer = *DefaultDialer
// Dial creates a new client connection. Use requestHeader to specify the // DialContext creates a new client connection. Use requestHeader to specify the
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
// Use the response.Header to get the selected subprotocol // Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
// //
// The context will be used in the request and in the Dialer
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication, // non-nil *http.Response so that callers can handle redirects, authentication,
// etcetera. The response body may not contain the entire response and does not // etcetera. The response body may not contain the entire response and does not
// need to be closed by the application. // need to be closed by the application.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
if d == nil { if d == nil {
d = &nilDialer d = &nilDialer
} }
@ -161,6 +184,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
Header: make(http.Header), Header: make(http.Header),
Host: u.Host, Host: u.Host,
} }
req = req.WithContext(ctx)
// Set the cookies present in the cookie jar of the dialer // Set the cookies present in the cookie jar of the dialer
if d.Jar != nil { if d.Jar != nil {
@ -201,23 +225,33 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
} }
if d.EnableCompression { if d.EnableCompression {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
} }
var deadline time.Time
if d.HandshakeTimeout != 0 { if d.HandshakeTimeout != 0 {
deadline = time.Now().Add(d.HandshakeTimeout) var cancel func()
ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
defer cancel()
} }
// Get network dial function. // Get network dial function.
netDial := d.NetDial var netDial func(network, add string) (net.Conn, error)
if netDial == nil {
netDialer := &net.Dialer{Deadline: deadline} if d.NetDialContext != nil {
netDial = netDialer.Dial netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
} else {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
}
} }
// If needed, wrap the dial function to set the connection deadline. // If needed, wrap the dial function to set the connection deadline.
if !deadline.Equal(time.Time{}) { if deadline, ok := ctx.Deadline(); ok {
forwardDial := netDial forwardDial := netDial
netDial = func(network, addr string) (net.Conn, error) { netDial = func(network, addr string) (net.Conn, error) {
c, err := forwardDial(network, addr) c, err := forwardDial(network, addr)
@ -249,7 +283,17 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
} }
hostPort, hostNoPort := hostPortNoPort(u) hostPort, hostNoPort := hostPortNoPort(u)
trace := httptrace.ContextClientTrace(ctx)
if trace != nil && trace.GetConn != nil {
trace.GetConn(hostPort)
}
netConn, err := netDial("tcp", hostPort) netConn, err := netDial("tcp", hostPort)
if trace != nil && trace.GotConn != nil {
trace.GotConn(httptrace.GotConnInfo{
Conn: netConn,
})
}
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -267,22 +311,31 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
} }
tlsConn := tls.Client(netConn, cfg) tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn netConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
return nil, nil, err var err error
if trace != nil {
err = doHandshakeWithTrace(trace, tlsConn, cfg)
} else {
err = doHandshake(tlsConn, cfg)
} }
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { if err != nil {
return nil, nil, err return nil, nil, err
}
} }
} }
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
if err := req.Write(netConn); err != nil { if err := req.Write(netConn); err != nil {
return nil, nil, err return nil, nil, err
} }
if trace != nil && trace.GotFirstResponseByte != nil {
if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
trace.GotFirstResponseByte()
}
}
resp, err := http.ReadResponse(conn.br, req) resp, err := http.ReadResponse(conn.br, req)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -328,3 +381,15 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
netConn = nil // to avoid close in defer. netConn = nil // to avoid close in defer.
return conn, resp, nil return conn, resp, nil
} }
func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View file

@ -223,6 +223,20 @@ func isValidReceivedCloseCode(code int) bool {
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
} }
// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
// interface. The type of the value stored in a pool is not specified.
type BufferPool interface {
// Get gets a value from the pool or returns nil if the pool is empty.
Get() interface{}
// Put adds a value to the pool.
Put(interface{})
}
// writePoolData is the type added to the write buffer pool. This wrapper is
// used to prevent applications from peeking at and depending on the values
// added to the pool.
type writePoolData struct{ buf []byte }
// The Conn type represents a WebSocket connection. // The Conn type represents a WebSocket connection.
type Conn struct { type Conn struct {
conn net.Conn conn net.Conn
@ -232,6 +246,8 @@ type Conn struct {
// Write fields // Write fields
mu chan bool // used as mutex to protect write to conn mu chan bool // used as mutex to protect write to conn
writeBuf []byte // frame is constructed in this buffer. writeBuf []byte // frame is constructed in this buffer.
writePool BufferPool
writeBufSize int
writeDeadline time.Time writeDeadline time.Time
writer io.WriteCloser // the current writer returned to the application writer io.WriteCloser // the current writer returned to the application
isWriting bool // for best-effort concurrent write detection isWriting bool // for best-effort concurrent write detection
@ -263,64 +279,29 @@ type Conn struct {
newDecompressionReader func(io.Reader) io.ReadCloser newDecompressionReader func(io.Reader) io.ReadCloser
} }
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
}
type writeHook struct {
p []byte
}
func (wh *writeHook) Write(p []byte) (int, error) {
wh.p = p
return len(p), nil
}
func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
mu := make(chan bool, 1)
mu <- true
var br *bufio.Reader
if readBufferSize == 0 && brw != nil && brw.Reader != nil {
// Reuse the supplied bufio.Reader if the buffer has a useful size.
// This code assumes that peek on a reader returns
// bufio.Reader.buf[:0].
brw.Reader.Reset(conn)
if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 {
br = brw.Reader
}
}
if br == nil { if br == nil {
if readBufferSize == 0 { if readBufferSize == 0 {
readBufferSize = defaultReadBufferSize readBufferSize = defaultReadBufferSize
} } else if readBufferSize < maxControlFramePayloadSize {
if readBufferSize < maxControlFramePayloadSize { // must be large enough for control frame
readBufferSize = maxControlFramePayloadSize readBufferSize = maxControlFramePayloadSize
} }
br = bufio.NewReaderSize(conn, readBufferSize) br = bufio.NewReaderSize(conn, readBufferSize)
} }
var writeBuf []byte if writeBufferSize <= 0 {
if writeBufferSize == 0 && brw != nil && brw.Writer != nil { writeBufferSize = defaultWriteBufferSize
// Use the bufio.Writer's buffer if the buffer has a useful size. This }
// code assumes that bufio.Writer.buf[:1] is passed to the writeBufferSize += maxFrameHeaderSize
// bufio.Writer's underlying writer.
var wh writeHook if writeBuf == nil && writeBufferPool == nil {
brw.Writer.Reset(&wh) writeBuf = make([]byte, writeBufferSize)
brw.Writer.WriteByte(0)
brw.Flush()
if cap(wh.p) >= maxFrameHeaderSize+256 {
writeBuf = wh.p[:cap(wh.p)]
}
}
if writeBuf == nil {
if writeBufferSize == 0 {
writeBufferSize = defaultWriteBufferSize
}
writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize)
} }
mu := make(chan bool, 1)
mu <- true
c := &Conn{ c := &Conn{
isServer: isServer, isServer: isServer,
br: br, br: br,
@ -328,6 +309,8 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
mu: mu, mu: mu,
readFinal: true, readFinal: true,
writeBuf: writeBuf, writeBuf: writeBuf,
writePool: writeBufferPool,
writeBufSize: writeBufferSize,
enableWriteCompression: true, enableWriteCompression: true,
compressionLevel: defaultCompressionLevel, compressionLevel: defaultCompressionLevel,
} }
@ -370,6 +353,15 @@ func (c *Conn) writeFatal(err error) error {
return err return err
} }
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
c.br.Discard(len(p))
return p, err
}
func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
<-c.mu <-c.mu
defer func() { c.mu <- true }() defer func() { c.mu <- true }()
@ -475,7 +467,19 @@ func (c *Conn) prepWrite(messageType int) error {
c.writeErrMu.Lock() c.writeErrMu.Lock()
err := c.writeErr err := c.writeErr
c.writeErrMu.Unlock() c.writeErrMu.Unlock()
return err if err != nil {
return err
}
if c.writeBuf == nil {
wpd, ok := c.writePool.Get().(writePoolData)
if ok {
c.writeBuf = wpd.buf
} else {
c.writeBuf = make([]byte, c.writeBufSize)
}
}
return nil
} }
// NextWriter returns a writer for the next message to send. The writer's Close // NextWriter returns a writer for the next message to send. The writer's Close
@ -601,6 +605,10 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
if final { if final {
c.writer = nil c.writer = nil
if c.writePool != nil {
c.writePool.Put(writePoolData{buf: c.writeBuf})
c.writeBuf = nil
}
return nil return nil
} }

View file

@ -1,18 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
c.br.Discard(len(p))
return p, err
}

View file

@ -1,21 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
if len(p) > 0 {
// advance over the bytes just read
io.ReadFull(c.br, p)
}
return p, err
}

View file

@ -19,7 +19,6 @@ import (
type PreparedMessage struct { type PreparedMessage struct {
messageType int messageType int
data []byte data []byte
err error
mu sync.Mutex mu sync.Mutex
frames map[prepareKey]*preparedFrame frames map[prepareKey]*preparedFrame
} }

View file

@ -14,7 +14,7 @@ import (
"strings" "strings"
) )
type netDialerFunc func(netowrk, addr string) (net.Conn, error) type netDialerFunc func(network, addr string) (net.Conn, error)
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr) return fn(network, addr)

View file

@ -7,7 +7,7 @@ package websocket
import ( import (
"bufio" "bufio"
"errors" "errors"
"net" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -33,10 +33,23 @@ type Upgrader struct {
// or received. // or received.
ReadBufferSize, WriteBufferSize int ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the server's supported protocols in order of // Subprotocols specifies the server's supported protocols in order of
// preference. If this field is set, then the Upgrade method negotiates a // preference. If this field is not nil, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol // subprotocol by selecting the first match in this list with a protocol
// requested by the client. // requested by the client. If there's no match, then no protocol is
// negotiated (the Sec-Websocket-Protocol header is not included in the
// handshake response).
Subprotocols []string Subprotocols []string
// Error specifies the function for generating HTTP error responses. If Error // Error specifies the function for generating HTTP error responses. If Error
@ -103,7 +116,7 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
// //
// The responseHeader is included in the response to the client's upgrade // The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the // request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-Websocket-Protocol). // application negotiated subprotocol (Sec-WebSocket-Protocol).
// //
// If the upgrade fails, then Upgrade replies to the client with an HTTP error // If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response. // response.
@ -127,7 +140,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported") return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
} }
checkOrigin := u.CheckOrigin checkOrigin := u.CheckOrigin
@ -140,7 +153,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
challengeKey := r.Header.Get("Sec-Websocket-Key") challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" { if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank") return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank")
} }
subprotocol := u.selectSubprotocol(r, responseHeader) subprotocol := u.selectSubprotocol(r, responseHeader)
@ -157,17 +170,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
} }
var (
netConn net.Conn
err error
)
h, ok := w.(http.Hijacker) h, ok := w.(http.Hijacker)
if !ok { if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
} }
var brw *bufio.ReadWriter var brw *bufio.ReadWriter
netConn, brw, err = h.Hijack() netConn, brw, err := h.Hijack()
if err != nil { if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error()) return u.returnError(w, r, http.StatusInternalServerError, err.Error())
} }
@ -177,7 +185,21 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return nil, errors.New("websocket: client sent data before handshake is complete") return nil, errors.New("websocket: client sent data before handshake is complete")
} }
c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw) var br *bufio.Reader
if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
// Reuse hijacked buffered reader as connection reader.
br = brw.Reader
}
buf := bufioWriterBuffer(netConn, brw.Writer)
var writeBuf []byte
if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
// Reuse hijacked write buffer as connection buffer.
writeBuf = buf
}
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
c.subprotocol = subprotocol c.subprotocol = subprotocol
if compress { if compress {
@ -185,17 +207,23 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
c.newDecompressionReader = decompressNoContextTakeover c.newDecompressionReader = decompressNoContextTakeover
} }
p := c.writeBuf[:0] // Use larger of hijacked buffer and connection write buffer for header.
p := buf
if len(c.writeBuf) > len(p) {
p = c.writeBuf
}
p = p[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...) p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...) p = append(p, "\r\n"...)
if c.subprotocol != "" { if c.subprotocol != "" {
p = append(p, "Sec-Websocket-Protocol: "...) p = append(p, "Sec-WebSocket-Protocol: "...)
p = append(p, c.subprotocol...) p = append(p, c.subprotocol...)
p = append(p, "\r\n"...) p = append(p, "\r\n"...)
} }
if compress { if compress {
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
} }
for k, vs := range responseHeader { for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" { if k == "Sec-Websocket-Protocol" {
@ -296,3 +324,40 @@ func IsWebSocketUpgrade(r *http.Request) bool {
return tokenListContainsValue(r.Header, "Connection", "upgrade") && return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
tokenListContainsValue(r.Header, "Upgrade", "websocket") tokenListContainsValue(r.Header, "Upgrade", "websocket")
} }
// bufioReaderSize size returns the size of a bufio.Reader.
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
// This code assumes that peek on a reset reader returns
// bufio.Reader.buf[:0].
// TODO: Use bufio.Reader.Size() after Go 1.10
br.Reset(originalReader)
if p, err := br.Peek(0); err == nil {
return cap(p)
}
return 0
}
// writeHook is an io.Writer that records the last slice passed to it vio
// io.Writer.Write.
type writeHook struct {
p []byte
}
func (wh *writeHook) Write(p []byte) (int, error) {
wh.p = p
return len(p), nil
}
// bufioWriterBuffer grabs the buffer from a bufio.Writer.
func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
// This code assumes that bufio.Writer.buf[:1] is passed to the
// bufio.Writer's underlying writer.
var wh writeHook
bw.Reset(&wh)
bw.WriteByte(0)
bw.Flush()
bw.Reset(originalWriter)
return wh.p[:cap(wh.p)]
}

19
vendor/github.com/gorilla/websocket/trace.go generated vendored Normal file
View file

@ -0,0 +1,19 @@
// +build go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
if trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(tlsConn, cfg)
if trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
return err
}

12
vendor/github.com/gorilla/websocket/trace_17.go generated vendored Normal file
View file

@ -0,0 +1,12 @@
// +build !go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
return doHandshake(tlsConn, cfg)
}

View file

@ -178,7 +178,7 @@ headers:
return false return false
} }
// parseExtensiosn parses WebSocket extensions from a header. // parseExtensions parses WebSocket extensions from a header.
func parseExtensions(header http.Header) []map[string]string { func parseExtensions(header http.Header) []map[string]string {
// From RFC 6455: // From RFC 6455:
// //

View file

@ -259,6 +259,8 @@ func New(setters ...optSetter) (*Forwarder, error) {
errorHandler: f.errHandler, errorHandler: f.errHandler,
} }
f.postConfig()
return f, nil return f, nil
} }
@ -342,7 +344,7 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,
// WebSocket is only in http/1.1 // WebSocket is only in http/1.1
dialer.TLSClientConfig.NextProtos = []string{"http/1.1"} dialer.TLSClientConfig.NextProtos = []string{"http/1.1"}
} }
targetConn, resp, err := dialer.Dial(outReq.URL.String(), outReq.Header) targetConn, resp, err := dialer.DialContext(outReq.Context(), outReq.URL.String(), outReq.Header)
if err != nil { if err != nil {
if resp == nil { if resp == nil {
ctx.errHandler.ServeHTTP(w, req, err) ctx.errHandler.ServeHTTP(w, req, err)

5
vendor/github.com/vulcand/oxy/forward/post_config.go generated vendored Normal file
View file

@ -0,0 +1,5 @@
// +build go1.11
package forward
func (f *Forwarder) postConfig() {}

View file

@ -0,0 +1,42 @@
// +build !go1.11
package forward
import (
"context"
"net/http"
)
type key string
const (
teHeader key = "TeHeader"
)
type TeTrailerRoundTripper struct {
http.RoundTripper
}
func (t *TeTrailerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
teHeader := req.Context().Value(teHeader)
if teHeader != nil {
req.Header.Set("Te", teHeader.(string))
}
return t.RoundTripper.RoundTrip(req)
}
type TeTrailerRewriter struct {
ReqRewriter
}
func (t *TeTrailerRewriter) Rewrite(req *http.Request) {
if req.Header.Get("Te") == "trailers" {
*req = *req.WithContext(context.WithValue(req.Context(), teHeader, req.Header.Get("Te")))
}
t.ReqRewriter.Rewrite(req)
}
func (f *Forwarder) postConfig() {
f.roundTripper = &TeTrailerRoundTripper{RoundTripper: f.roundTripper}
f.rewriter = &TeTrailerRewriter{ReqRewriter: f.rewriter}
}

View file

@ -69,12 +69,6 @@ func (rw *HeaderRewriter) Rewrite(req *http.Request) {
if rw.Hostname != "" { if rw.Hostname != "" {
req.Header.Set(XForwardedServer, rw.Hostname) req.Header.Set(XForwardedServer, rw.Hostname)
} }
if !IsWebsocketRequest(req) {
// Remove hop-by-hop headers to the backend. Especially important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us.
utils.RemoveHeaders(req.Header, HopHeaders...)
}
} }
func forwardedPort(req *http.Request) string { func forwardedPort(req *http.Request) string {