Add proxy protocol
This commit is contained in:
parent
89b0037ec1
commit
c8c31aea62
11 changed files with 333 additions and 29 deletions
|
@ -71,7 +71,8 @@ Run it and forget it!
|
||||||
- Websocket, HTTP/2, GRPC ready
|
- Websocket, HTTP/2, GRPC ready
|
||||||
- Access Logs (JSON, CLF)
|
- Access Logs (JSON, CLF)
|
||||||
- [Let's Encrypt](https://letsencrypt.org) support (Automatic HTTPS with renewal)
|
- [Let's Encrypt](https://letsencrypt.org) support (Automatic HTTPS with renewal)
|
||||||
- High Availability with cluster mode
|
- [Proxy Protocol](https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt) support
|
||||||
|
- High Availability with cluster mode (beta)
|
||||||
|
|
||||||
## Supported backends
|
## Supported backends
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,8 @@
|
||||||
FROM golang:1.8
|
FROM golang:1.9-alpine
|
||||||
|
|
||||||
# Install a more recent version of mercurial to avoid mismatching results
|
RUN apk --update upgrade \
|
||||||
# between glide run on a decently updated host system and the build container.
|
&& apk --no-cache --no-progress add git mercurial bash gcc musl-dev curl tar \
|
||||||
RUN awk '$1 ~ "^deb" { $3 = $3 "-backports"; print; exit }' /etc/apt/sources.list > /etc/apt/sources.list.d/backports.list && \
|
&& rm -rf /var/cache/apk/*
|
||||||
DEBIAN_FRONTEND=noninteractive apt-get update && \
|
|
||||||
DEBIAN_FRONTEND=noninteractive apt-get install -t jessie-backports --yes --no-install-recommends mercurial=3.9.1-1~bpo8+1 && \
|
|
||||||
rm -fr /var/lib/apt/lists/
|
|
||||||
|
|
||||||
RUN go get github.com/jteeuwen/go-bindata/... \
|
RUN go get github.com/jteeuwen/go-bindata/... \
|
||||||
&& go get github.com/golang/lint/golint \
|
&& go get github.com/golang/lint/golint \
|
||||||
|
|
|
@ -193,7 +193,7 @@ func (ep *EntryPoints) String() string {
|
||||||
// Set's argument is a string to be parsed to set the flag.
|
// Set's argument is a string to be parsed to set the flag.
|
||||||
// It's a comma-separated list, so we split it.
|
// It's a comma-separated list, so we split it.
|
||||||
func (ep *EntryPoints) Set(value string) error {
|
func (ep *EntryPoints) Set(value string) error {
|
||||||
regex := regexp.MustCompile(`(?:Name:(?P<Name>\S*))\s*(?:Address:(?P<Address>\S*))?\s*(?:TLS:(?P<TLS>\S*))?\s*((?P<TLSACME>TLS))?\s*(?:CA:(?P<CA>\S*))?\s*(?:Redirect.EntryPoint:(?P<RedirectEntryPoint>\S*))?\s*(?:Redirect.Regex:(?P<RedirectRegex>\\S*))?\s*(?:Redirect.Replacement:(?P<RedirectReplacement>\S*))?\s*(?:Compress:(?P<Compress>\S*))?\s*(?:WhiteListSourceRange:(?P<WhiteListSourceRange>\S*))?`)
|
regex := regexp.MustCompile(`(?:Name:(?P<Name>\S*))\s*(?:Address:(?P<Address>\S*))?\s*(?:TLS:(?P<TLS>\S*))?\s*((?P<TLSACME>TLS))?\s*(?:CA:(?P<CA>\S*))?\s*(?:Redirect.EntryPoint:(?P<RedirectEntryPoint>\S*))?\s*(?:Redirect.Regex:(?P<RedirectRegex>\\S*))?\s*(?:Redirect.Replacement:(?P<RedirectReplacement>\S*))?\s*(?:Compress:(?P<Compress>\S*))?\s*(?:WhiteListSourceRange:(?P<WhiteListSourceRange>\S*))?\s*(?:ProxyProtocol:(?P<ProxyProtocol>\S*))?`)
|
||||||
match := regex.FindAllStringSubmatch(value, -1)
|
match := regex.FindAllStringSubmatch(value, -1)
|
||||||
if match == nil {
|
if match == nil {
|
||||||
return fmt.Errorf("bad EntryPoints format: %s", value)
|
return fmt.Errorf("bad EntryPoints format: %s", value)
|
||||||
|
@ -234,7 +234,9 @@ func (ep *EntryPoints) Set(value string) error {
|
||||||
|
|
||||||
compress := false
|
compress := false
|
||||||
if len(result["Compress"]) > 0 {
|
if len(result["Compress"]) > 0 {
|
||||||
compress = strings.EqualFold(result["Compress"], "enable") || strings.EqualFold(result["Compress"], "on")
|
compress = strings.EqualFold(result["Compress"], "true") ||
|
||||||
|
strings.EqualFold(result["Compress"], "enable") ||
|
||||||
|
strings.EqualFold(result["Compress"], "on")
|
||||||
}
|
}
|
||||||
|
|
||||||
whiteListSourceRange := []string{}
|
whiteListSourceRange := []string{}
|
||||||
|
@ -242,12 +244,20 @@ func (ep *EntryPoints) Set(value string) error {
|
||||||
whiteListSourceRange = strings.Split(result["WhiteListSourceRange"], ",")
|
whiteListSourceRange = strings.Split(result["WhiteListSourceRange"], ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
proxyprotocol := false
|
||||||
|
if len(result["ProxyProtocol"]) > 0 {
|
||||||
|
proxyprotocol = strings.EqualFold(result["ProxyProtocol"], "true") ||
|
||||||
|
strings.EqualFold(result["ProxyProtocol"], "enable") ||
|
||||||
|
strings.EqualFold(result["ProxyProtocol"], "on")
|
||||||
|
}
|
||||||
|
|
||||||
(*ep)[result["Name"]] = &EntryPoint{
|
(*ep)[result["Name"]] = &EntryPoint{
|
||||||
Address: result["Address"],
|
Address: result["Address"],
|
||||||
TLS: configTLS,
|
TLS: configTLS,
|
||||||
Redirect: redirect,
|
Redirect: redirect,
|
||||||
Compress: compress,
|
Compress: compress,
|
||||||
WhitelistSourceRange: whiteListSourceRange,
|
WhitelistSourceRange: whiteListSourceRange,
|
||||||
|
ProxyProtocol: proxyprotocol,
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -277,6 +287,7 @@ type EntryPoint struct {
|
||||||
Auth *types.Auth
|
Auth *types.Auth
|
||||||
WhitelistSourceRange []string
|
WhitelistSourceRange []string
|
||||||
Compress bool
|
Compress bool
|
||||||
|
ProxyProtocol bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redirect configures a redirection of an entry point to another, or to an URL
|
// Redirect configures a redirection of an entry point to another, or to an URL
|
||||||
|
|
|
@ -292,6 +292,12 @@ To write JSON format logs, specify `json` as the format:
|
||||||
# address = ":80"
|
# address = ":80"
|
||||||
# whiteListSourceRange = ["127.0.0.1/32"]
|
# whiteListSourceRange = ["127.0.0.1/32"]
|
||||||
|
|
||||||
|
# To enable ProxyProtocol support (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt):
|
||||||
|
# [entryPoints]
|
||||||
|
# [entryPoints.http]
|
||||||
|
# address = ":80"
|
||||||
|
# proxyprotocol = true
|
||||||
|
|
||||||
[entryPoints]
|
[entryPoints]
|
||||||
[entryPoints.http]
|
[entryPoints.http]
|
||||||
address = ":80"
|
address = ":80"
|
||||||
|
|
6
glide.lock
generated
6
glide.lock
generated
|
@ -1,5 +1,5 @@
|
||||||
hash: 3d5a06016b7b56be08120ed653406a1e8d4ade7e69b4fbc37b31683cb4e9a519
|
hash: 2b042ce06e9c4aed4606f2b8ced5d6c3de537d1254316e8c6611e78d934a024a
|
||||||
updated: 2017-08-21T14:15:06.346751095+02:00
|
updated: 2017-08-24T14:24:42.04425168+02:00
|
||||||
imports:
|
imports:
|
||||||
- name: cloud.google.com/go
|
- name: cloud.google.com/go
|
||||||
version: 2e6a95edb1071d750f6d7db777bf66cd2997af6c
|
version: 2e6a95edb1071d750f6d7db777bf66cd2997af6c
|
||||||
|
@ -10,6 +10,8 @@ imports:
|
||||||
version: 0ddd408d5d60ea76e320503cc7dd091992dee608
|
version: 0ddd408d5d60ea76e320503cc7dd091992dee608
|
||||||
- name: github.com/aokoli/goutils
|
- name: github.com/aokoli/goutils
|
||||||
version: 3391d3790d23d03408670993e957e8f408993c34
|
version: 3391d3790d23d03408670993e957e8f408993c34
|
||||||
|
- name: github.com/armon/go-proxyproto
|
||||||
|
version: 48572f11356f1843b694f21a290d4f1006bc5e47
|
||||||
- name: github.com/ArthurHlt/go-eureka-client
|
- name: github.com/ArthurHlt/go-eureka-client
|
||||||
version: 9d0a49cbd39aa3634ae1977e9f519a262b10adaf
|
version: 9d0a49cbd39aa3634ae1977e9f519a262b10adaf
|
||||||
subpackages:
|
subpackages:
|
||||||
|
|
|
@ -202,6 +202,8 @@ import:
|
||||||
- spew
|
- spew
|
||||||
- package: github.com/Masterminds/sprig
|
- package: github.com/Masterminds/sprig
|
||||||
version: e039e20e500c2c025d9145be375e27cf42a94174
|
version: e039e20e500c2c025d9145be375e27cf42a94174
|
||||||
|
- package: github.com/armon/go-proxyproto
|
||||||
|
version: 48572f11356f1843b694f21a290d4f1006bc5e47
|
||||||
testImport:
|
testImport:
|
||||||
- package: github.com/stvp/go-udp-testing
|
- package: github.com/stvp/go-udp-testing
|
||||||
- package: github.com/docker/libcompose
|
- package: github.com/docker/libcompose
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/armon/go-proxyproto"
|
||||||
"github.com/containous/mux"
|
"github.com/containous/mux"
|
||||||
"github.com/containous/traefik/cluster"
|
"github.com/containous/traefik/cluster"
|
||||||
"github.com/containous/traefik/configuration"
|
"github.com/containous/traefik/configuration"
|
||||||
|
@ -65,6 +66,7 @@ type serverEntryPoints map[string]*serverEntryPoint
|
||||||
|
|
||||||
type serverEntryPoint struct {
|
type serverEntryPoint struct {
|
||||||
httpServer *http.Server
|
httpServer *http.Server
|
||||||
|
listener net.Listener
|
||||||
httpRouter *middlewares.HandlerSwitcher
|
httpRouter *middlewares.HandlerSwitcher
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -259,7 +261,7 @@ func (server *Server) startHTTPServers() {
|
||||||
|
|
||||||
for newServerEntryPointName, newServerEntryPoint := range server.serverEntryPoints {
|
for newServerEntryPointName, newServerEntryPoint := range server.serverEntryPoints {
|
||||||
serverEntryPoint := server.setupServerEntryPoint(newServerEntryPointName, newServerEntryPoint)
|
serverEntryPoint := server.setupServerEntryPoint(newServerEntryPointName, newServerEntryPoint)
|
||||||
go server.startServer(serverEntryPoint.httpServer, server.globalConfiguration)
|
go server.startServer(serverEntryPoint, server.globalConfiguration)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -296,12 +298,13 @@ func (server *Server) setupServerEntryPoint(newServerEntryPointName string, newS
|
||||||
}
|
}
|
||||||
serverMiddlewares = append(serverMiddlewares, ipWhitelistMiddleware)
|
serverMiddlewares = append(serverMiddlewares, ipWhitelistMiddleware)
|
||||||
}
|
}
|
||||||
newSrv, err := server.prepareServer(newServerEntryPointName, server.globalConfiguration.EntryPoints[newServerEntryPointName], newServerEntryPoint.httpRouter, serverMiddlewares...)
|
newSrv, listener, err := server.prepareServer(newServerEntryPointName, server.globalConfiguration.EntryPoints[newServerEntryPointName], newServerEntryPoint.httpRouter, serverMiddlewares...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("Error preparing server: ", err)
|
log.Fatal("Error preparing server: ", err)
|
||||||
}
|
}
|
||||||
serverEntryPoint := server.serverEntryPoints[newServerEntryPointName]
|
serverEntryPoint := server.serverEntryPoints[newServerEntryPointName]
|
||||||
serverEntryPoint.httpServer = newSrv
|
serverEntryPoint.httpServer = newSrv
|
||||||
|
serverEntryPoint.listener = listener
|
||||||
|
|
||||||
return serverEntryPoint
|
return serverEntryPoint
|
||||||
}
|
}
|
||||||
|
@ -611,20 +614,20 @@ func (server *Server) createTLSConfig(entryPointName string, tlsOption *configur
|
||||||
return config, nil
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) startServer(srv *http.Server, globalConfiguration configuration.GlobalConfiguration) {
|
func (server *Server) startServer(serverEntryPoint *serverEntryPoint, globalConfiguration configuration.GlobalConfiguration) {
|
||||||
log.Infof("Starting server on %s", srv.Addr)
|
log.Infof("Starting server on %s", serverEntryPoint.httpServer.Addr)
|
||||||
var err error
|
var err error
|
||||||
if srv.TLSConfig != nil {
|
if serverEntryPoint.httpServer.TLSConfig != nil {
|
||||||
err = srv.ListenAndServeTLS("", "")
|
err = serverEntryPoint.httpServer.ServeTLS(serverEntryPoint.listener, "", "")
|
||||||
} else {
|
} else {
|
||||||
err = srv.ListenAndServe()
|
err = serverEntryPoint.httpServer.Serve(serverEntryPoint.listener)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Error creating server: ", err)
|
log.Error("Error creating server: ", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) prepareServer(entryPointName string, entryPoint *configuration.EntryPoint, router *middlewares.HandlerSwitcher, middlewares ...negroni.Handler) (*http.Server, error) {
|
func (server *Server) prepareServer(entryPointName string, entryPoint *configuration.EntryPoint, router *middlewares.HandlerSwitcher, middlewares ...negroni.Handler) (*http.Server, net.Listener, error) {
|
||||||
readTimeout, writeTimeout, idleTimeout := buildServerTimeouts(server.globalConfiguration)
|
readTimeout, writeTimeout, idleTimeout := buildServerTimeouts(server.globalConfiguration)
|
||||||
log.Infof("Preparing server %s %+v with readTimeout=%s writeTimeout=%s idleTimeout=%s", entryPointName, entryPoint, readTimeout, writeTimeout, idleTimeout)
|
log.Infof("Preparing server %s %+v with readTimeout=%s writeTimeout=%s idleTimeout=%s", entryPointName, entryPoint, readTimeout, writeTimeout, idleTimeout)
|
||||||
|
|
||||||
|
@ -638,7 +641,16 @@ func (server *Server) prepareServer(entryPointName string, entryPoint *configura
|
||||||
tlsConfig, err := server.createTLSConfig(entryPointName, entryPoint.TLS, router)
|
tlsConfig, err := server.createTLSConfig(entryPointName, entryPoint.TLS, router)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Error creating TLS config: %s", err)
|
log.Errorf("Error creating TLS config: %s", err)
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", entryPoint.Address)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Error opening listener ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if entryPoint.ProxyProtocol {
|
||||||
|
listener = &proxyproto.Listener{Listener: listener}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &http.Server{
|
return &http.Server{
|
||||||
|
@ -648,7 +660,9 @@ func (server *Server) prepareServer(entryPointName string, entryPoint *configura
|
||||||
ReadTimeout: readTimeout,
|
ReadTimeout: readTimeout,
|
||||||
WriteTimeout: writeTimeout,
|
WriteTimeout: writeTimeout,
|
||||||
IdleTimeout: idleTimeout,
|
IdleTimeout: idleTimeout,
|
||||||
}, nil
|
},
|
||||||
|
listener,
|
||||||
|
nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildServerTimeouts(globalConfig configuration.GlobalConfiguration) (readTimeout, writeTimeout, idleTimeout time.Duration) {
|
func buildServerTimeouts(globalConfig configuration.GlobalConfiguration) (readTimeout, writeTimeout, idleTimeout time.Duration) {
|
||||||
|
|
|
@ -100,7 +100,7 @@ func TestPrepareServerTimeouts(t *testing.T) {
|
||||||
router := middlewares.NewHandlerSwitcher(mux.NewRouter())
|
router := middlewares.NewHandlerSwitcher(mux.NewRouter())
|
||||||
|
|
||||||
srv := NewServer(test.globalConfig)
|
srv := NewServer(test.globalConfig)
|
||||||
httpServer, err := srv.prepareServer(entryPointName, entryPoint, router)
|
httpServer, _, err := srv.prepareServer(entryPointName, entryPoint, router)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Unexpected error when preparing srv: %s", err)
|
t.Fatalf("Unexpected error when preparing srv: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -344,6 +344,12 @@
|
||||||
# address = ":80"
|
# address = ":80"
|
||||||
# whiteListSourceRange = ["127.0.0.1/32"]
|
# whiteListSourceRange = ["127.0.0.1/32"]
|
||||||
|
|
||||||
|
# To enable ProxyProtocol support (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt):
|
||||||
|
# [entryPoints]
|
||||||
|
# [entryPoints.http]
|
||||||
|
# address = ":80"
|
||||||
|
# proxyprotocol = true
|
||||||
|
|
||||||
# Enable retry sending request if network error
|
# Enable retry sending request if network error
|
||||||
#
|
#
|
||||||
# Optional
|
# Optional
|
||||||
|
|
21
vendor/github.com/armon/go-proxyproto/LICENSE
generated
vendored
Normal file
21
vendor/github.com/armon/go-proxyproto/LICENSE
generated
vendored
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2014 Armon Dadgar
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
244
vendor/github.com/armon/go-proxyproto/protocol.go
generated
vendored
Normal file
244
vendor/github.com/armon/go-proxyproto/protocol.go
generated
vendored
Normal file
|
@ -0,0 +1,244 @@
|
||||||
|
package proxyproto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// prefix is the string we look for at the start of a connection
|
||||||
|
// to check if this connection is using the proxy protocol
|
||||||
|
prefix = []byte("PROXY ")
|
||||||
|
prefixLen = len(prefix)
|
||||||
|
|
||||||
|
ErrInvalidUpstream = errors.New("upstream connection address not trusted for PROXY information")
|
||||||
|
)
|
||||||
|
|
||||||
|
// SourceChecker can be used to decide whether to trust the PROXY info or pass
|
||||||
|
// the original connection address through. If set, the connecting address is
|
||||||
|
// passed in as an argument. If the function returns an error due to the source
|
||||||
|
// being disallowed, it should return ErrInvalidUpstream.
|
||||||
|
//
|
||||||
|
// Behavior is as follows:
|
||||||
|
// * If error is not nil, the call to Accept() will fail. If the reason for
|
||||||
|
// triggering this failure is due to a disallowed source, it should return
|
||||||
|
// ErrInvalidUpstream.
|
||||||
|
// * If bool is true, the PROXY-set address is used.
|
||||||
|
// * If bool is false, the connection's remote address is used, rather than the
|
||||||
|
// address claimed in the PROXY info.
|
||||||
|
type SourceChecker func(net.Addr) (bool, error)
|
||||||
|
|
||||||
|
// Listener is used to wrap an underlying listener,
|
||||||
|
// whose connections may be using the HAProxy Proxy Protocol (version 1).
|
||||||
|
// If the connection is using the protocol, the RemoteAddr() will return
|
||||||
|
// the correct client address.
|
||||||
|
//
|
||||||
|
// Optionally define ProxyHeaderTimeout to set a maximum time to
|
||||||
|
// receive the Proxy Protocol Header. Zero means no timeout.
|
||||||
|
type Listener struct {
|
||||||
|
Listener net.Listener
|
||||||
|
ProxyHeaderTimeout time.Duration
|
||||||
|
SourceCheck SourceChecker
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conn is used to wrap and underlying connection which
|
||||||
|
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
|
||||||
|
// return the address of the client instead of the proxy address.
|
||||||
|
type Conn struct {
|
||||||
|
bufReader *bufio.Reader
|
||||||
|
conn net.Conn
|
||||||
|
dstAddr *net.TCPAddr
|
||||||
|
srcAddr *net.TCPAddr
|
||||||
|
useConnRemoteAddr bool
|
||||||
|
once sync.Once
|
||||||
|
proxyHeaderTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept waits for and returns the next connection to the listener.
|
||||||
|
func (p *Listener) Accept() (net.Conn, error) {
|
||||||
|
// Get the underlying connection
|
||||||
|
conn, err := p.Listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var useConnRemoteAddr bool
|
||||||
|
if p.SourceCheck != nil {
|
||||||
|
allowed, err := p.SourceCheck(conn.RemoteAddr())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
useConnRemoteAddr = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newConn := NewConn(conn, p.ProxyHeaderTimeout)
|
||||||
|
newConn.useConnRemoteAddr = useConnRemoteAddr
|
||||||
|
return newConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying listener.
|
||||||
|
func (p *Listener) Close() error {
|
||||||
|
return p.Listener.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr returns the underlying listener's network address.
|
||||||
|
func (p *Listener) Addr() net.Addr {
|
||||||
|
return p.Listener.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConn is used to wrap a net.Conn that may be speaking
|
||||||
|
// the proxy protocol into a proxyproto.Conn
|
||||||
|
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
|
||||||
|
pConn := &Conn{
|
||||||
|
bufReader: bufio.NewReader(conn),
|
||||||
|
conn: conn,
|
||||||
|
proxyHeaderTimeout: timeout,
|
||||||
|
}
|
||||||
|
return pConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read is check for the proxy protocol header when doing
|
||||||
|
// the initial scan. If there is an error parsing the header,
|
||||||
|
// it is returned and the socket is closed.
|
||||||
|
func (p *Conn) Read(b []byte) (int, error) {
|
||||||
|
var err error
|
||||||
|
p.once.Do(func() { err = p.checkPrefix() })
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return p.bufReader.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Conn) Write(b []byte) (int, error) {
|
||||||
|
return p.conn.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Conn) Close() error {
|
||||||
|
return p.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Conn) LocalAddr() net.Addr {
|
||||||
|
return p.conn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr returns the address of the client if the proxy
|
||||||
|
// protocol is being used, otherwise just returns the address of
|
||||||
|
// the socket peer. If there is an error parsing the header, the
|
||||||
|
// address of the client is not returned, and the socket is closed.
|
||||||
|
// Once implication of this is that the call could block if the
|
||||||
|
// client is slow. Using a Deadline is recommended if this is called
|
||||||
|
// before Read()
|
||||||
|
func (p *Conn) RemoteAddr() net.Addr {
|
||||||
|
p.once.Do(func() {
|
||||||
|
if err := p.checkPrefix(); err != nil && err != io.EOF {
|
||||||
|
log.Printf("[ERR] Failed to read proxy prefix: %v", err)
|
||||||
|
p.Close()
|
||||||
|
p.bufReader = bufio.NewReader(p.conn)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if p.srcAddr != nil && !p.useConnRemoteAddr {
|
||||||
|
return p.srcAddr
|
||||||
|
}
|
||||||
|
return p.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Conn) SetDeadline(t time.Time) error {
|
||||||
|
return p.conn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Conn) SetReadDeadline(t time.Time) error {
|
||||||
|
return p.conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Conn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return p.conn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Conn) checkPrefix() error {
|
||||||
|
if p.proxyHeaderTimeout != 0 {
|
||||||
|
readDeadLine := time.Now().Add(p.proxyHeaderTimeout)
|
||||||
|
p.conn.SetReadDeadline(readDeadLine)
|
||||||
|
defer p.conn.SetReadDeadline(time.Time{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Incrementally check each byte of the prefix
|
||||||
|
for i := 1; i <= prefixLen; i++ {
|
||||||
|
inp, err := p.bufReader.Peek(i)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for a prefix mis-match, quit early
|
||||||
|
if !bytes.Equal(inp, prefix[:i]) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the header line
|
||||||
|
header, err := p.bufReader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
p.conn.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip the carriage return and new line
|
||||||
|
header = header[:len(header)-2]
|
||||||
|
|
||||||
|
// Split on spaces, should be (PROXY <type> <src addr> <dst addr> <src port> <dst port>)
|
||||||
|
parts := strings.Split(header, " ")
|
||||||
|
if len(parts) != 6 {
|
||||||
|
p.conn.Close()
|
||||||
|
return fmt.Errorf("Invalid header line: %s", header)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the type is known
|
||||||
|
switch parts[1] {
|
||||||
|
case "TCP4":
|
||||||
|
case "TCP6":
|
||||||
|
default:
|
||||||
|
p.conn.Close()
|
||||||
|
return fmt.Errorf("Unhandled address type: %s", parts[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse out the source address
|
||||||
|
ip := net.ParseIP(parts[2])
|
||||||
|
if ip == nil {
|
||||||
|
p.conn.Close()
|
||||||
|
return fmt.Errorf("Invalid source ip: %s", parts[2])
|
||||||
|
}
|
||||||
|
port, err := strconv.Atoi(parts[4])
|
||||||
|
if err != nil {
|
||||||
|
p.conn.Close()
|
||||||
|
return fmt.Errorf("Invalid source port: %s", parts[4])
|
||||||
|
}
|
||||||
|
p.srcAddr = &net.TCPAddr{IP: ip, Port: port}
|
||||||
|
|
||||||
|
// Parse out the destination address
|
||||||
|
ip = net.ParseIP(parts[3])
|
||||||
|
if ip == nil {
|
||||||
|
p.conn.Close()
|
||||||
|
return fmt.Errorf("Invalid destination ip: %s", parts[3])
|
||||||
|
}
|
||||||
|
port, err = strconv.Atoi(parts[5])
|
||||||
|
if err != nil {
|
||||||
|
p.conn.Close()
|
||||||
|
return fmt.Errorf("Invalid destination port: %s", parts[5])
|
||||||
|
}
|
||||||
|
p.dstAddr = &net.TCPAddr{IP: ip, Port: port}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
Reference in a new issue