Merge pull request #1350 from containous/toml-compatible-duration-type

Use TOML-compatible duration type.
This commit is contained in:
Timo Reimann 2017-04-03 19:30:33 +02:00 committed by GitHub
commit 7d256c9bb9
15 changed files with 313 additions and 103 deletions

View file

@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/containous/flaeg"
"github.com/containous/traefik/acme"
"github.com/containous/traefik/provider"
"github.com/containous/traefik/types"
@ -23,7 +24,7 @@ type TraefikConfiguration struct {
// GlobalConfiguration holds global configuration (with providers, etc.).
// It's populated from the traefik configuration file passed as an argument to the binary.
type GlobalConfiguration struct {
GraceTimeOut int64 `short:"g" description:"Duration to give active requests a chance to finish during hot-reload"`
GraceTimeOut flaeg.Duration `short:"g" description:"Duration to give active requests a chance to finish during hot-reload"`
Debug bool `short:"d" description:"Enable debug mode"`
CheckNewVersion bool `description:"Periodically check if a new version has been released"`
AccessLogsFile string `description:"Access logs file"`
@ -34,7 +35,7 @@ type GlobalConfiguration struct {
Constraints types.Constraints `description:"Filter services by constraint, matching with service tags"`
ACME *acme.ACME `description:"Enable ACME (Let's Encrypt): automatic SSL"`
DefaultEntryPoints DefaultEntryPoints `description:"Entrypoints to be used by frontends that do not specify any entrypoint"`
ProvidersThrottleDuration time.Duration `description:"Backends throttle duration: minimum duration between 2 events from providers before applying a new configuration. It avoids unnecessary reloads if multiples events are sent in a short amount of time."`
ProvidersThrottleDuration flaeg.Duration `description:"Backends throttle duration: minimum duration between 2 events from providers before applying a new configuration. It avoids unnecessary reloads if multiples events are sent in a short amount of time."`
MaxIdleConnsPerHost int `description:"If non-zero, controls the maximum idle (keep-alive) to keep per-host. If zero, DefaultMaxIdleConnsPerHost is used"`
InsecureSkipVerify bool `description:"Disable SSL certificate verification"`
Retry *Retry `description:"Enable retry sending request if network error"`
@ -457,14 +458,14 @@ func NewTraefikDefaultPointersConfiguration() *TraefikConfiguration {
func NewTraefikConfiguration() *TraefikConfiguration {
return &TraefikConfiguration{
GlobalConfiguration: GlobalConfiguration{
GraceTimeOut: 10,
GraceTimeOut: flaeg.Duration(10 * time.Second),
AccessLogsFile: "",
TraefikLogsFile: "",
LogLevel: "ERROR",
EntryPoints: map[string]*EntryPoint{},
Constraints: types.Constraints{},
DefaultEntryPoints: []string{},
ProvidersThrottleDuration: time.Duration(2 * time.Second),
ProvidersThrottleDuration: flaeg.Duration(2 * time.Second),
MaxIdleConnsPerHost: 200,
CheckNewVersion: true,
},

View file

@ -9,13 +9,15 @@
# Global configuration
################################################################
# Timeout in seconds.
# Duration to give active requests a chance to finish during hot-reloads
# Duration to give active requests a chance to finish during hot-reloads.
# Can be provided in a format supported by [time.ParseDuration](https://golang.org/pkg/time/#ParseDuration) or as raw
# values (digits). If no units are provided, the value is parsed assuming
# seconds.
#
# Optional
# Default: 10
# Default: "10s"
#
# graceTimeOut = 10
# graceTimeOut = "10s"
# Enable debug mode
#
@ -56,11 +58,14 @@
# Backends throttle duration: minimum duration in seconds between 2 events from providers
# before applying a new configuration. It avoids unnecessary reloads if multiples events
# are sent in a short amount of time.
# Can be provided in a format supported by [time.ParseDuration](https://golang.org/pkg/time/#ParseDuration) or as raw
# values (digits). If no units are provided, the value is parsed assuming
# seconds.
#
# Optional
# Default: "2"
# Default: "2s"
#
# ProvidersThrottleDuration = "5"
# ProvidersThrottleDuration = "2s"
# If non-zero, controls the maximum idle (keep-alive) to keep per-host. If zero, DefaultMaxIdleConnsPerHost is used.
# If you encounter 'too many open files' errors, you can either change this value, or change `ulimit` value.
@ -932,19 +937,25 @@ domain = "marathon.localhost"
# dcosToken = "xxxxxx"
# Override DialerTimeout
# Amount of time in seconds to allow the Marathon provider to wait to open a TCP
# connection to a Marathon master
# Amount of time to allow the Marathon provider to wait to open a TCP connection
# to a Marathon master.
# Can be provided in a format supported by [time.ParseDuration](https://golang.org/pkg/time/#ParseDuration) or as raw
# values (digits). If no units are provided, the value is parsed assuming
# seconds.
#
# Optional
# Default: 60
# dialerTimeout = 5
# Default: "60s"
# dialerTimeout = "60s"
# Set the TCP Keep Alive interval (in seconds) for the Marathon HTTP Client
# Set the TCP Keep Alive interval for the Marathon HTTP Client.
# Can be provided in a format supported by [time.ParseDuration](https://golang.org/pkg/time/#ParseDuration) or as raw
# values (digits). If no units are provided, the value is parsed assuming
# seconds.
#
# Optional
# Default: 10
# Default: "10s"
#
# keepAlive = 10
# keepAlive = "10s"
```
Labels can be used on containers to override default behaviour:

10
glide.lock generated
View file

@ -1,5 +1,5 @@
hash: 56175c5c588abf1169ba2425ac6dd7e3e5a8cf8ab6ad3f75cad51e45f5e94e3a
updated: 2017-03-23T22:43:06.217624505+01:00
hash: 8349c21c53c639aa79752c4e4b07e323ddb4f05be783564d7de687afbb6f2d67
updated: 2017-03-28T22:35:17.448681338+02:00
imports:
- name: bitbucket.org/ww/goautoneg
version: 75cd24fc2f2c2a2088577d12123ddee5f54e0675
@ -87,7 +87,7 @@ imports:
- name: github.com/codegangsta/negroni
version: dc6b9d037e8dab60cbfc09c61d6932537829be8b
- name: github.com/containous/flaeg
version: a731c034dda967333efce5f8d276aeff11f8ff87
version: b5d2dc5878df07c2d74413348186982e7b865871
- name: github.com/containous/mux
version: a819b77bba13f0c0cbe36e437bc2e948411b3996
- name: github.com/containous/staert
@ -239,7 +239,7 @@ imports:
- name: github.com/gorilla/context
version: 1ea25387ff6f684839d82767c1733ff4d4d15d0a
- name: github.com/gorilla/websocket
version: c36f2fe5c330f0ac404b616b96c438b8616b1aaf
version: 4873052237e4eeda85cf50c071ef33836fe8e139
- name: github.com/hashicorp/consul
version: fce7d75609a04eeb9d4bf41c8dc592aac18fc97d
subpackages:
@ -322,7 +322,7 @@ imports:
- name: github.com/pborman/uuid
version: 5007efa264d92316c43112bc573e754bc889b7b1
- name: github.com/pkg/errors
version: 248dadf4e9068a0b3e79f02ed0a610d935de5302
version: bfd5150e4e41705ded2129ec33379de1cb90b513
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:

View file

@ -7,7 +7,6 @@ import:
- package: github.com/Sirupsen/logrus
- package: github.com/cenk/backoff
- package: github.com/containous/flaeg
version: a731c034dda967333efce5f8d276aeff11f8ff87
- package: github.com/vulcand/oxy
version: fcc76b52eb8568540a020b7a99e854d9d752b364
repo: https://github.com/containous/oxy.git

View file

@ -13,6 +13,7 @@ import (
"github.com/BurntSushi/ty/fun"
"github.com/cenk/backoff"
"github.com/containous/flaeg"
"github.com/containous/traefik/job"
"github.com/containous/traefik/log"
"github.com/containous/traefik/safe"
@ -25,15 +26,15 @@ var _ Provider = (*Marathon)(nil)
// Marathon holds configuration of the Marathon provider.
type Marathon struct {
BaseProvider
Endpoint string `description:"Marathon server endpoint. You can also specify multiple endpoint for Marathon"`
Domain string `description:"Default domain used"`
ExposedByDefault bool `description:"Expose Marathon apps by default"`
GroupsAsSubDomains bool `description:"Convert Marathon groups to subdomains"`
DCOSToken string `description:"DCOSToken for DCOS environment, This will override the Authorization header"`
MarathonLBCompatibility bool `description:"Add compatibility with marathon-lb labels"`
TLS *ClientTLS `description:"Enable Docker TLS support"`
DialerTimeout time.Duration `description:"Set a non-default connection timeout for Marathon"`
KeepAlive time.Duration `description:"Set a non-default TCP Keep Alive time in seconds"`
Endpoint string `description:"Marathon server endpoint. You can also specify multiple endpoint for Marathon"`
Domain string `description:"Default domain used"`
ExposedByDefault bool `description:"Expose Marathon apps by default"`
GroupsAsSubDomains bool `description:"Convert Marathon groups to subdomains"`
DCOSToken string `description:"DCOSToken for DCOS environment, This will override the Authorization header"`
MarathonLBCompatibility bool `description:"Add compatibility with marathon-lb labels"`
TLS *ClientTLS `description:"Enable Docker TLS support"`
DialerTimeout flaeg.Duration `description:"Set a non-default connection timeout for Marathon"`
KeepAlive flaeg.Duration `description:"Set a non-default TCP Keep Alive time in seconds"`
Basic *MarathonBasic
marathonClient marathon.Marathon
}
@ -71,8 +72,8 @@ func (provider *Marathon) Provide(configurationChan chan<- types.ConfigMessage,
config.HTTPClient = &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
KeepAlive: provider.KeepAlive * time.Second,
Timeout: time.Second * provider.DialerTimeout,
KeepAlive: time.Duration(provider.KeepAlive),
Timeout: time.Duration(provider.DialerTimeout),
}).DialContext,
TLSClientConfig: TLSConfig,
},

View file

@ -20,6 +20,8 @@ import (
"syscall"
"time"
"sync"
"github.com/codegangsta/negroni"
"github.com/containous/mux"
"github.com/containous/traefik/cluster"
@ -35,7 +37,6 @@ import (
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/roundrobin"
"github.com/vulcand/oxy/utils"
"sync"
)
var oxyLogger = &OxyLogger{}
@ -120,8 +121,9 @@ func (server *Server) Stop() {
wg.Add(1)
go func(serverEntryPointName string, serverEntryPoint *serverEntryPoint) {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(server.globalConfiguration.GraceTimeOut)*time.Second)
log.Debugf("Waiting %d seconds before killing connections on entrypoint %s...", server.globalConfiguration.GraceTimeOut, serverEntryPointName)
graceTimeOut := time.Duration(server.globalConfiguration.GraceTimeOut)
ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut)
log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName)
if err := serverEntryPoint.httpServer.Shutdown(ctx); err != nil {
log.Debugf("Wait is over due to: %s", err)
serverEntryPoint.httpServer.Close()
@ -136,7 +138,7 @@ func (server *Server) Stop() {
// Close destroys the server
func (server *Server) Close() {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(server.globalConfiguration.GraceTimeOut)*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(server.globalConfiguration.GraceTimeOut))
go func(ctx context.Context) {
<-ctx.Done()
if ctx.Err() == context.Canceled {
@ -231,16 +233,17 @@ func (server *Server) listenProviders(stop chan bool) {
} else {
lastConfigs.Set(configMsg.ProviderName, &configMsg)
lastReceivedConfigurationValue := lastReceivedConfiguration.Get().(time.Time)
if time.Now().After(lastReceivedConfigurationValue.Add(time.Duration(server.globalConfiguration.ProvidersThrottleDuration))) {
providersThrottleDuration := time.Duration(server.globalConfiguration.ProvidersThrottleDuration)
if time.Now().After(lastReceivedConfigurationValue.Add(providersThrottleDuration)) {
log.Debugf("Last %s config received more than %s, OK", configMsg.ProviderName, server.globalConfiguration.ProvidersThrottleDuration.String())
// last config received more than n s ago
server.configurationValidatedChan <- configMsg
} else {
log.Debugf("Last %s config received less than %s, waiting...", configMsg.ProviderName, server.globalConfiguration.ProvidersThrottleDuration.String())
safe.Go(func() {
<-time.After(server.globalConfiguration.ProvidersThrottleDuration)
<-time.After(providersThrottleDuration)
lastReceivedConfigurationValue := lastReceivedConfiguration.Get().(time.Time)
if time.Now().After(lastReceivedConfigurationValue.Add(time.Duration(server.globalConfiguration.ProvidersThrottleDuration))) {
if time.Now().After(lastReceivedConfigurationValue.Add(time.Duration(providersThrottleDuration))) {
log.Debugf("Waited for %s config, OK", configMsg.ProviderName)
if lastConfig, ok := lastConfigs.Get(configMsg.ProviderName); ok {
server.configurationValidatedChan <- *lastConfig.(*types.ConfigMessage)

View file

@ -2,13 +2,15 @@
# Global configuration
################################################################
# Timeout in seconds.
# Duration to give active requests a chance to finish during hot-reloads
# Duration to give active requests a chance to finish during hot-reloads.
# Can be provided in a format supported by Go's time.ParseDuration function or
# as raw values (digits). If no units are provided, the value is parsed assuming
# seconds.
#
# Optional
# Default: 10
# Default: "10s"
#
# graceTimeOut = 10
# graceTimeOut = "10s"
# Enable debug mode
#
@ -47,11 +49,14 @@
# Backends throttle duration: minimum duration in seconds between 2 events from providers
# before applying a new configuration. It avoids unnecessary reloads if multiples events
# are sent in a short amount of time.
# Can be provided in a format supported by Go's time.ParseDuration function or
# as raw values (digits). If no units are provided, the value is parsed assuming
# seconds.
#
# Optional
# Default: "2"
# Default: "2s"
#
# ProvidersThrottleDuration = "5"
# ProvidersThrottleDuration = "5s"
# If non-zero, controls the maximum idle (keep-alive) to keep per-host. If zero, DefaultMaxIdleConnsPerHost is used.
# If you encounter 'too many open files' errors, you can either change this value, or change `ulimit` value.
@ -557,12 +562,15 @@
# groupsAsSubDomains = true
# Override DialerTimeout
# Amount of time in seconds to allow the Marathon provider to wait to open a TCP
# connection to a Marathon master
# Amount of time to allow the Marathon provider to wait to open a TCP connection
# to a Marathon master.
# Can be provided in a format supported by Go's time.ParseDuration function or
# as raw values (digits). If no units are provided, the value is parsed assuming
# seconds.
#
# Optional
# Default: 60
# dialerTimeout = 5
# Default: "60s"
# dialerTimeout = "60s"
# Enable Marathon basic authentication
#
@ -579,12 +587,15 @@
# dcosToken = "xxxxxx"
# Set the TCP Keep Alive interval (in seconds) for the Marathon HTTP Client
# Set the TCP Keep Alive interval for the Marathon HTTP Client.
# Can be provided in a format supported by Go's time.ParseDuration function or
# as raw values (digits). If no units are provided, the value is parsed assuming
# seconds.
#
# Optional
# Default: 10
# Default: "10s"
#
# keepAlive = 10
# keepAlive = "10s"
################################################################
# Mesos configuration backend

View file

@ -3,7 +3,6 @@ package flaeg
import (
"errors"
"fmt"
flag "github.com/ogier/pflag"
"io"
"io/ioutil"
"os"
@ -14,6 +13,8 @@ import (
"text/tabwriter"
"text/template"
"time"
flag "github.com/ogier/pflag"
)
// ErrParserNotFound is thrown when a field is flaged but not parser match its type
@ -131,8 +132,8 @@ func loadParsers(customParsers map[reflect.Type]Parser) (map[reflect.Type]Parser
var float64Parser float64Value
parsers[reflect.TypeOf(float64(1.5))] = &float64Parser
var durationParser durationValue
parsers[reflect.TypeOf(time.Second)] = &durationParser
var durationParser Duration
parsers[reflect.TypeOf(Duration(time.Second))] = &durationParser
var timeParser timeValue
parsers[reflect.TypeOf(time.Now())] = &timeParser
@ -473,7 +474,7 @@ func PrintHelpWithCommand(flagmap map[string]reflect.StructField, defaultValmap
// Define a templates
// Using POSXE STD : http://pubs.opengroup.org/onlinepubs/9699919799/
const helper = `{{if .ProgDescription}}{{.ProgDescription}}
{{end}}Usage: {{.ProgName}} [--flag=flag_argument] [-f[flag_argument]] ... set flag_argument to flag(s)
or: {{.ProgName}} [--flag[=true|false| ]] [-f[true|false| ]] ... set true/false to boolean flag(s)
{{if .SubCommands}}

View file

@ -143,21 +143,38 @@ func (f *float64Value) SetValue(val interface{}) {
*f = float64Value(val.(float64))
}
// -- time.Duration Value
type durationValue time.Duration
// Duration is a custom type suitable for parsing duration values.
// It supports `time.ParseDuration`-compatible values and suffix-less digits; in
// the latter case, seconds are assumed.
type Duration time.Duration
// Set sets the duration from the given string value.
func (d *Duration) Set(s string) error {
if v, err := strconv.Atoi(s); err == nil {
*d = Duration(time.Duration(v) * time.Second)
return nil
}
func (d *durationValue) Set(s string) error {
v, err := time.ParseDuration(s)
*d = durationValue(v)
*d = Duration(v)
return err
}
func (d *durationValue) Get() interface{} { return time.Duration(*d) }
// Get returns the duration value.
func (d *Duration) Get() interface{} { return time.Duration(*d) }
func (d *durationValue) String() string { return (*time.Duration)(d).String() }
// String returns a string representation of the duration value.
func (d *Duration) String() string { return (*time.Duration)(d).String() }
func (d *durationValue) SetValue(val interface{}) {
*d = durationValue(val.(time.Duration))
// SetValue sets the duration from the given Duration-asserted value.
func (d *Duration) SetValue(val interface{}) {
*d = Duration(val.(Duration))
}
// UnmarshalText deserializes the given text into a duration value.
// It is meant to support TOML decoding of durations.
func (d *Duration) UnmarshalText(text []byte) error {
return d.Set(string(text))
}
// -- time.Time Value

View file

@ -389,32 +389,3 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
netConn = nil // to avoid close in defer.
return conn, resp, nil
}
// cloneTLSConfig clones all public fields except the fields
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
// config in active use.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

16
vendor/github.com/gorilla/websocket/client_clone.go generated vendored Normal file
View file

@ -0,0 +1,16 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}

View file

@ -0,0 +1,38 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
import "crypto/tls"
// cloneTLSConfig clones all public fields except the fields
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
// config in active use.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View file

@ -265,6 +265,10 @@ type Conn struct {
}
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
}
func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
mu := make(chan bool, 1)
mu <- true
@ -274,13 +278,28 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int)
if readBufferSize < maxControlFramePayloadSize {
readBufferSize = maxControlFramePayloadSize
}
// Reuse the supplied brw.Reader if brw.Reader's buf is the requested size.
var br *bufio.Reader
if brw != nil && brw.Reader != nil {
// This code assumes that peek on a reset reader returns
// bufio.Reader.buf[:0].
brw.Reader.Reset(conn)
if p, err := brw.Reader.Peek(0); err == nil && cap(p) == readBufferSize {
br = brw.Reader
}
}
if br == nil {
br = bufio.NewReaderSize(conn, readBufferSize)
}
if writeBufferSize == 0 {
writeBufferSize = defaultWriteBufferSize
}
c := &Conn{
isServer: isServer,
br: bufio.NewReaderSize(conn, readBufferSize),
br: br,
conn: conn,
mu: mu,
readFinal: true,
@ -659,12 +678,33 @@ func (w *messageWriter) Close() error {
return nil
}
// WritePreparedMessage writes prepared message into connection.
func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error {
frameType, frameData, err := pm.frame(prepareKey{
isServer: c.isServer,
compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType),
compressionLevel: c.compressionLevel,
})
if err != nil {
return err
}
if c.isWriting {
panic("concurrent write to websocket connection")
}
c.isWriting = true
err = c.write(frameType, c.writeDeadline, frameData, nil)
if !c.isWriting {
panic("concurrent write to websocket connection")
}
c.isWriting = false
return err
}
// WriteMessage is a helper method for getting a writer using NextWriter,
// writing the message and closing the writer.
func (c *Conn) WriteMessage(messageType int, data []byte) error {
if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) {
// Fast path with no allocations and single frame.
if err := c.prepWrite(messageType); err != nil {

103
vendor/github.com/gorilla/websocket/prepared.go generated vendored Normal file
View file

@ -0,0 +1,103 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"net"
"sync"
"time"
)
// PreparedMessage caches on the wire representations of a message payload.
// Use PreparedMessage to efficiently send a message payload to multiple
// connections. PreparedMessage is especially useful when compression is used
// because the CPU and memory expensive compression operation can be executed
// once for a given set of compression options.
type PreparedMessage struct {
messageType int
data []byte
err error
mu sync.Mutex
frames map[prepareKey]*preparedFrame
}
// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage.
type prepareKey struct {
isServer bool
compress bool
compressionLevel int
}
// preparedFrame contains data in wire representation.
type preparedFrame struct {
once sync.Once
data []byte
}
// NewPreparedMessage returns an initialized PreparedMessage. You can then send
// it to connection using WritePreparedMessage method. Valid wire
// representation will be calculated lazily only once for a set of current
// connection options.
func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) {
pm := &PreparedMessage{
messageType: messageType,
frames: make(map[prepareKey]*preparedFrame),
data: data,
}
// Prepare a plain server frame.
_, frameData, err := pm.frame(prepareKey{isServer: true, compress: false})
if err != nil {
return nil, err
}
// To protect against caller modifying the data argument, remember the data
// copied to the plain server frame.
pm.data = frameData[len(frameData)-len(data):]
return pm, nil
}
func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
pm.mu.Lock()
frame, ok := pm.frames[key]
if !ok {
frame = &preparedFrame{}
pm.frames[key] = frame
}
pm.mu.Unlock()
var err error
frame.once.Do(func() {
// Prepare a frame using a 'fake' connection.
// TODO: Refactor code in conn.go to allow more direct construction of
// the frame.
mu := make(chan bool, 1)
mu <- true
var nc prepareConn
c := &Conn{
conn: &nc,
mu: mu,
isServer: key.isServer,
compressionLevel: key.compressionLevel,
enableWriteCompression: true,
writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize),
}
if key.compress {
c.newCompressionWriter = compressNoContextTakeover
}
err = c.WriteMessage(pm.messageType, pm.data)
frame.data = nc.buf.Bytes()
})
return pm.messageType, frame.data, err
}
type prepareConn struct {
buf bytes.Buffer
net.Conn
}
func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) }
func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil }

View file

@ -152,7 +152,6 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
var (
netConn net.Conn
br *bufio.Reader
err error
)
@ -160,19 +159,18 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var rw *bufio.ReadWriter
netConn, rw, err = h.Hijack()
var brw *bufio.ReadWriter
netConn, brw, err = h.Hijack()
if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
}
br = rw.Reader
if br.Buffered() > 0 {
if brw.Reader.Buffered() > 0 {
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete")
}
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize)
c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw)
c.subprotocol = subprotocol
if compress {