Support mirroring request body

Co-authored-by: Mathieu Lonjaret <mathieu.lonjaret@gmail.com>
Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
This commit is contained in:
Dmytro Tananayskiy 2020-03-05 18:03:08 +01:00 committed by Traefiker Bot
parent 09c07f45ee
commit cf7f0f878a
20 changed files with 454 additions and 44 deletions

View file

@ -65,6 +65,7 @@
[http.services.Service02] [http.services.Service02]
[http.services.Service02.mirroring] [http.services.Service02.mirroring]
service = "foobar" service = "foobar"
maxBodySize = 42
[[http.services.Service02.mirroring.mirrors]] [[http.services.Service02.mirroring.mirrors]]
name = "foobar" name = "foobar"

View file

@ -72,6 +72,7 @@ http:
Service02: Service02:
mirroring: mirroring:
service: foobar service: foobar
maxBodySize: 42
mirrors: mirrors:
- name: foobar - name: foobar
percent: 42 percent: 42

View file

@ -65,6 +65,8 @@ spec:
kind: TraefikService kind: TraefikService
mirrors: mirrors:
- name: s2 - name: s2
# Optional
maxBodySize: 2000000000
# Optional, as it is the default value # Optional, as it is the default value
kind: Service kind: Service
percent: 20 percent: 20

View file

@ -174,6 +174,7 @@
| `traefik/http/services/Service01/loadBalancer/sticky/cookie/httpOnly` | `true` | | `traefik/http/services/Service01/loadBalancer/sticky/cookie/httpOnly` | `true` |
| `traefik/http/services/Service01/loadBalancer/sticky/cookie/name` | `foobar` | | `traefik/http/services/Service01/loadBalancer/sticky/cookie/name` | `foobar` |
| `traefik/http/services/Service01/loadBalancer/sticky/cookie/secure` | `true` | | `traefik/http/services/Service01/loadBalancer/sticky/cookie/secure` | `true` |
| `traefik/http/services/Service02/mirroring/maxBodySize` | `42` |
| `traefik/http/services/Service02/mirroring/mirrors/0/name` | `foobar` | | `traefik/http/services/Service02/mirroring/mirrors/0/name` | `foobar` |
| `traefik/http/services/Service02/mirroring/mirrors/0/percent` | `42` | | `traefik/http/services/Service02/mirroring/mirrors/0/percent` | `42` |
| `traefik/http/services/Service02/mirroring/mirrors/1/name` | `foobar` | | `traefik/http/services/Service02/mirroring/mirrors/1/name` | `foobar` |

View file

@ -462,6 +462,8 @@ http:
### Mirroring (service) ### Mirroring (service)
The mirroring is able to mirror requests sent to a service to other services. The mirroring is able to mirror requests sent to a service to other services.
Please note that by default the whole request is buffered in memory while it is being mirrored.
See the maxBodySize option in the example below for how to modify this behaviour.
!!! info "Supported Providers" !!! info "Supported Providers"
@ -473,6 +475,10 @@ The mirroring is able to mirror requests sent to a service to other services.
[http.services.mirrored-api] [http.services.mirrored-api]
[http.services.mirrored-api.mirroring] [http.services.mirrored-api.mirroring]
service = "appv1" service = "appv1"
# maxBodySize is the maximum size in bytes allowed for the body of the request.
# If the body is larger, the request is not mirrored.
# Default value is -1, which means unlimited size.
maxBodySize = 1024
[[http.services.mirrored-api.mirroring.mirrors]] [[http.services.mirrored-api.mirroring.mirrors]]
name = "appv2" name = "appv2"
percent = 10 percent = 10
@ -495,6 +501,10 @@ http:
mirrored-api: mirrored-api:
mirroring: mirroring:
service: appv1 service: appv1
# maxBodySize is the maximum size allowed for the body of the request.
# If the body is larger, the request is not mirrored.
# Default value is -1, which means unlimited size.
maxBodySize = 1024
mirrors: mirrors:
- name: appv2 - name: appv2
percent: 10 percent: 10

View file

@ -23,6 +23,11 @@
service = "mirror" service = "mirror"
rule = "Path(`/whoami`)" rule = "Path(`/whoami`)"
[http.routers.router2]
service = "mirrorWithMaxBody"
rule = "Path(`/whoamiWithMaxBody`)"
[http.services] [http.services]
[http.services.mirror.mirroring] [http.services.mirror.mirroring]
service = "service1" service = "service1"
@ -33,6 +38,17 @@
name = "mirror2" name = "mirror2"
percent = 50 percent = 50
[http.services.mirrorWithMaxBody.mirroring]
service = "service1"
maxBodySize = 8
[[http.services.mirrorWithMaxBody.mirroring.mirrors]]
name = "mirror1"
percent = 10
[[http.services.mirrorWithMaxBody.mirroring.mirrors]]
name = "mirror2"
percent = 50
[http.services.service1.loadBalancer] [http.services.service1.loadBalancer]
[[http.services.service1.loadBalancer.servers]] [[http.services.service1.loadBalancer.servers]]
url = "{{ .MainServer }}" url = "{{ .MainServer }}"

View file

@ -2,6 +2,7 @@ package integration
import ( import (
"bytes" "bytes"
"crypto/rand"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -777,6 +778,129 @@ func (s *SimpleSuite) TestMirror(c *check.C) {
c.Assert(val2, checker.Equals, int32(5)) c.Assert(val2, checker.Equals, int32(5))
} }
func (s *SimpleSuite) TestMirrorWithBody(c *check.C) {
var count, countMirror1, countMirror2 int32
body20 := make([]byte, 20)
_, err := rand.Read(body20)
c.Assert(err, checker.IsNil)
body5 := make([]byte, 5)
_, err = rand.Read(body5)
c.Assert(err, checker.IsNil)
verifyBody := func(req *http.Request) {
b, _ := ioutil.ReadAll(req.Body)
switch req.Header.Get("Size") {
case "20":
if !bytes.Equal(b, body20) {
c.Fatalf("Not Equals \n%v \n%v", body20, b)
}
case "5":
if !bytes.Equal(b, body5) {
c.Fatalf("Not Equals \n%v \n%v", body5, b)
}
default:
c.Fatal("Size header not present")
}
}
main := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
verifyBody(req)
atomic.AddInt32(&count, 1)
}))
mirror1 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
verifyBody(req)
atomic.AddInt32(&countMirror1, 1)
}))
mirror2 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
verifyBody(req)
atomic.AddInt32(&countMirror2, 1)
}))
mainServer := main.URL
mirror1Server := mirror1.URL
mirror2Server := mirror2.URL
file := s.adaptFile(c, "fixtures/mirror.toml", struct {
MainServer string
Mirror1Server string
Mirror2Server string
}{MainServer: mainServer, Mirror1Server: mirror1Server, Mirror2Server: mirror2Server})
defer os.Remove(file)
cmd, output := s.traefikCmd(withConfigFile(file))
defer output(c)
err = cmd.Start()
c.Assert(err, checker.IsNil)
defer cmd.Process.Kill()
err = try.GetRequest("http://127.0.0.1:8080/api/http/services", 1000*time.Millisecond, try.BodyContains("mirror1", "mirror2", "service1"))
c.Assert(err, checker.IsNil)
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoami", bytes.NewBuffer(body20))
c.Assert(err, checker.IsNil)
req.Header.Set("Size", "20")
for i := 0; i < 10; i++ {
response, err := http.DefaultClient.Do(req)
c.Assert(err, checker.IsNil)
c.Assert(response.StatusCode, checker.Equals, http.StatusOK)
}
countTotal := atomic.LoadInt32(&count)
val1 := atomic.LoadInt32(&countMirror1)
val2 := atomic.LoadInt32(&countMirror2)
c.Assert(countTotal, checker.Equals, int32(10))
c.Assert(val1, checker.Equals, int32(1))
c.Assert(val2, checker.Equals, int32(5))
atomic.StoreInt32(&count, 0)
atomic.StoreInt32(&countMirror1, 0)
atomic.StoreInt32(&countMirror2, 0)
req, err = http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoamiWithMaxBody", bytes.NewBuffer(body5))
req.Header.Set("Size", "5")
c.Assert(err, checker.IsNil)
for i := 0; i < 10; i++ {
response, err := http.DefaultClient.Do(req)
c.Assert(err, checker.IsNil)
c.Assert(response.StatusCode, checker.Equals, http.StatusOK)
}
countTotal = atomic.LoadInt32(&count)
val1 = atomic.LoadInt32(&countMirror1)
val2 = atomic.LoadInt32(&countMirror2)
c.Assert(countTotal, checker.Equals, int32(10))
c.Assert(val1, checker.Equals, int32(1))
c.Assert(val2, checker.Equals, int32(5))
atomic.StoreInt32(&count, 0)
atomic.StoreInt32(&countMirror1, 0)
atomic.StoreInt32(&countMirror2, 0)
req, err = http.NewRequest(http.MethodGet, "http://127.0.0.1:8000/whoamiWithMaxBody", bytes.NewBuffer(body20))
req.Header.Set("Size", "20")
c.Assert(err, checker.IsNil)
for i := 0; i < 10; i++ {
response, err := http.DefaultClient.Do(req)
c.Assert(err, checker.IsNil)
c.Assert(response.StatusCode, checker.Equals, http.StatusOK)
}
countTotal = atomic.LoadInt32(&count)
val1 = atomic.LoadInt32(&countMirror1)
val2 = atomic.LoadInt32(&countMirror2)
c.Assert(countTotal, checker.Equals, int32(10))
c.Assert(val1, checker.Equals, int32(0))
c.Assert(val2, checker.Equals, int32(0))
}
func (s *SimpleSuite) TestMirrorCanceled(c *check.C) { func (s *SimpleSuite) TestMirrorCanceled(c *check.C) {
var count, countMirror1, countMirror2 int32 var count, countMirror1, countMirror2 int32

View file

@ -152,6 +152,7 @@
"mirror@consul": { "mirror@consul": {
"mirroring": { "mirroring": {
"service": "simplesvc", "service": "simplesvc",
"maxBodySize": -1,
"mirrors": [ "mirrors": [
{ {
"name": "srvcA", "name": "srvcA",

View file

@ -152,6 +152,7 @@
"mirror@etcd": { "mirror@etcd": {
"mirroring": { "mirroring": {
"service": "simplesvc", "service": "simplesvc",
"maxBodySize": -1,
"mirrors": [ "mirrors": [
{ {
"name": "srvcA", "name": "srvcA",

View file

@ -152,6 +152,7 @@
"mirror@redis": { "mirror@redis": {
"mirroring": { "mirroring": {
"service": "simplesvc", "service": "simplesvc",
"maxBodySize": -1,
"mirrors": [ "mirrors": [
{ {
"name": "srvcA", "name": "srvcA",

View file

@ -152,6 +152,7 @@
"mirror@zookeeper": { "mirror@zookeeper": {
"mirroring": { "mirroring": {
"service": "simplesvc", "service": "simplesvc",
"maxBodySize": -1,
"mirrors": [ "mirrors": [
{ {
"name": "srvcA", "name": "srvcA",

View file

@ -59,9 +59,16 @@ type RouterTLSConfig struct {
// Mirroring holds the Mirroring configuration. // Mirroring holds the Mirroring configuration.
type Mirroring struct { type Mirroring struct {
Service string `json:"service,omitempty" toml:"service,omitempty" yaml:"service,omitempty"` Service string `json:"service,omitempty" toml:"service,omitempty" yaml:"service,omitempty"`
MaxBodySize *int64 `json:"maxBodySize,omitempty" toml:"maxBodySize,omitempty" yaml:"maxBodySize,omitempty"`
Mirrors []MirrorService `json:"mirrors,omitempty" toml:"mirrors,omitempty" yaml:"mirrors,omitempty"` Mirrors []MirrorService `json:"mirrors,omitempty" toml:"mirrors,omitempty" yaml:"mirrors,omitempty"`
} }
// SetDefaults Default values for a WRRService.
func (m *Mirroring) SetDefaults() {
var defaultMaxBodySize int64 = -1
m.MaxBodySize = &defaultMaxBodySize
}
// +k8s:deepcopy-gen=true // +k8s:deepcopy-gen=true
// MirrorService holds the MirrorService configuration. // MirrorService holds the MirrorService configuration.

View file

@ -762,6 +762,11 @@ func (in *MirrorService) DeepCopy() *MirrorService {
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *Mirroring) DeepCopyInto(out *Mirroring) { func (in *Mirroring) DeepCopyInto(out *Mirroring) {
*out = *in *out = *in
if in.MaxBodySize != nil {
in, out := &in.MaxBodySize, &out.MaxBodySize
*out = new(int64)
**out = **in
}
if in.Mirrors != nil { if in.Mirrors != nil {
in, out := &in.Mirrors, &out.Mirrors in, out := &in.Mirrors, &out.Mirrors
*out = make([]MirrorService, len(*in)) *out = make([]MirrorService, len(*in))

View file

@ -242,6 +242,7 @@ func (c configBuilder) buildMirroring(ctx context.Context, tService *v1alpha1.Tr
Mirroring: &dynamic.Mirroring{ Mirroring: &dynamic.Mirroring{
Service: fullNameMain, Service: fullNameMain,
Mirrors: mirrorServices, Mirrors: mirrorServices,
MaxBodySize: tService.Spec.Mirroring.MaxBodySize,
}, },
} }

View file

@ -44,6 +44,7 @@ type ServiceSpec struct {
// load-balancer, and a list of mirrors. // load-balancer, and a list of mirrors.
type Mirroring struct { type Mirroring struct {
LoadBalancerSpec LoadBalancerSpec
MaxBodySize *int64
Mirrors []MirrorService `json:"mirrors,omitempty"` Mirrors []MirrorService `json:"mirrors,omitempty"`
} }

View file

@ -749,6 +749,11 @@ func (in *MirrorService) DeepCopy() *MirrorService {
func (in *Mirroring) DeepCopyInto(out *Mirroring) { func (in *Mirroring) DeepCopyInto(out *Mirroring) {
*out = *in *out = *in
in.LoadBalancerSpec.DeepCopyInto(&out.LoadBalancerSpec) in.LoadBalancerSpec.DeepCopyInto(&out.LoadBalancerSpec)
if in.MaxBodySize != nil {
in, out := &in.MaxBodySize, &out.MaxBodySize
*out = new(int64)
**out = **in
}
if in.Mirrors != nil { if in.Mirrors != nil {
in, out := &in.Mirrors, &out.Mirrors in, out := &in.Mirrors, &out.Mirrors
*out = make([]MirrorService, len(*in)) *out = make([]MirrorService, len(*in))

View file

@ -56,6 +56,7 @@ func Test_buildConfiguration(t *testing.T) {
"traefik/http/services/Service01/loadBalancer/servers/0/url": "foobar", "traefik/http/services/Service01/loadBalancer/servers/0/url": "foobar",
"traefik/http/services/Service01/loadBalancer/servers/1/url": "foobar", "traefik/http/services/Service01/loadBalancer/servers/1/url": "foobar",
"traefik/http/services/Service02/mirroring/service": "foobar", "traefik/http/services/Service02/mirroring/service": "foobar",
"traefik/http/services/Service02/mirroring/maxBodySize": "42",
"traefik/http/services/Service02/mirroring/mirrors/0/name": "foobar", "traefik/http/services/Service02/mirroring/mirrors/0/name": "foobar",
"traefik/http/services/Service02/mirroring/mirrors/0/percent": "42", "traefik/http/services/Service02/mirroring/mirrors/0/percent": "42",
"traefik/http/services/Service02/mirroring/mirrors/1/name": "foobar", "traefik/http/services/Service02/mirroring/mirrors/1/name": "foobar",
@ -636,6 +637,7 @@ func Test_buildConfiguration(t *testing.T) {
"Service02": { "Service02": {
Mirroring: &dynamic.Mirroring{ Mirroring: &dynamic.Mirroring{
Service: "foobar", Service: "foobar",
MaxBodySize: func(v int64) *int64 { return &v }(42),
Mirrors: []dynamic.MirrorService{ Mirrors: []dynamic.MirrorService{
{ {
Name: "foobar", Name: "foobar",

View file

@ -2,12 +2,17 @@ package mirror
import ( import (
"bufio" "bufio"
"bytes"
"context" "context"
"errors" "errors"
"fmt"
"io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"sync" "sync"
"github.com/containous/traefik/v2/pkg/log"
"github.com/containous/traefik/v2/pkg/middlewares/accesslog" "github.com/containous/traefik/v2/pkg/middlewares/accesslog"
"github.com/containous/traefik/v2/pkg/safe" "github.com/containous/traefik/v2/pkg/safe"
) )
@ -19,16 +24,19 @@ type Mirroring struct {
rw http.ResponseWriter rw http.ResponseWriter
routinePool *safe.Pool routinePool *safe.Pool
maxBodySize int64
lock sync.RWMutex lock sync.RWMutex
total uint64 total uint64
} }
// New returns a new instance of *Mirroring. // New returns a new instance of *Mirroring.
func New(handler http.Handler, pool *safe.Pool) *Mirroring { func New(handler http.Handler, pool *safe.Pool, maxBodySize int64) *Mirroring {
return &Mirroring{ return &Mirroring{
routinePool: pool, routinePool: pool,
handler: handler, handler: handler,
rw: blackholeResponseWriter{}, rw: blackHoleResponseWriter{},
maxBodySize: maxBodySize,
} }
} }
@ -47,41 +55,73 @@ type mirrorHandler struct {
count uint64 count uint64
} }
func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (m *Mirroring) getActiveMirrors() []http.Handler {
m.handler.ServeHTTP(rw, req)
select {
case <-req.Context().Done():
// No mirroring if request has been canceled during main handler ServeHTTP
return
default:
}
m.routinePool.GoCtx(func(_ context.Context) {
total := m.inc() total := m.inc()
var mirrors []http.Handler
for _, handler := range m.mirrorHandlers { for _, handler := range m.mirrorHandlers {
handler.lock.Lock() handler.lock.Lock()
if handler.count*100 < total*uint64(handler.percent) { if handler.count*100 < total*uint64(handler.percent) {
handler.count++ handler.count++
handler.lock.Unlock() handler.lock.Unlock()
mirrors = append(mirrors, handler)
} else {
handler.lock.Unlock()
}
}
return mirrors
}
// In ServeHTTP, we rely on the presence of the accesslog datatable found in the func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// request's context to know whether we should mutate said datatable (and mirrors := m.getActiveMirrors()
// contribute some fields to the log). In this instance, we do not want the mirrors if len(mirrors) == 0 {
// mutating (i.e. changing the service name in) the logs related to the mirrored m.handler.ServeHTTP(rw, req)
// server. Especially since it would result in unguarded concurrent reads/writes on return
// the datatable. Therefore, we reset any potential datatable key in the new }
// context that we pass around.
ctx := context.WithValue(req.Context(), accesslog.DataTableKey, nil) logger := log.FromContext(req.Context())
rr, bytesRead, err := newReusableRequest(req, m.maxBodySize)
if err != nil && err != errBodyTooLarge {
http.Error(rw, http.StatusText(http.StatusInternalServerError)+
fmt.Sprintf("error creating reusable request: %v", err), http.StatusInternalServerError)
return
}
if err == errBodyTooLarge {
req.Body = ioutil.NopCloser(io.MultiReader(bytes.NewReader(bytesRead), req.Body))
m.handler.ServeHTTP(rw, req)
logger.Debugf("no mirroring, request body larger than allowed size")
return
}
m.handler.ServeHTTP(rw, rr.clone(req.Context()))
select {
case <-req.Context().Done():
// No mirroring if request has been canceled during main handler ServeHTTP
logger.Warn("no mirroring, request has been canceled during main handler ServeHTTP")
return
default:
}
m.routinePool.GoCtx(func(_ context.Context) {
for _, handler := range mirrors {
// prepare request, update body from buffer
r := rr.clone(req.Context())
// In ServeHTTP, we rely on the presence of the accessLog datatable found in the request's context
// to know whether we should mutate said datatable (and contribute some fields to the log).
// In this instance, we do not want the mirrors mutating (i.e. changing the service name in)
// the logs related to the mirrored server.
// Especially since it would result in unguarded concurrent reads/writes on the datatable.
// Therefore, we reset any potential datatable key in the new context that we pass around.
ctx := context.WithValue(r.Context(), accesslog.DataTableKey, nil)
// When a request served by m.handler is successful, req.Context will be canceled, // When a request served by m.handler is successful, req.Context will be canceled,
// which would trigger a cancellation of the ongoing mirrored requests. // which would trigger a cancellation of the ongoing mirrored requests.
// Therefore, we give a new, non-cancellable context to each of the mirrored calls, // Therefore, we give a new, non-cancellable context to each of the mirrored calls,
// so they can terminate by themselves. // so they can terminate by themselves.
handler.ServeHTTP(m.rw, req.WithContext(contextStopPropagation{ctx})) handler.ServeHTTP(m.rw, r.WithContext(contextStopPropagation{ctx}))
} else {
handler.lock.Unlock()
}
} }
}) })
} }
@ -95,23 +135,23 @@ func (m *Mirroring) AddMirror(handler http.Handler, percent int) error {
return nil return nil
} }
type blackholeResponseWriter struct{} type blackHoleResponseWriter struct{}
func (b blackholeResponseWriter) Flush() {} func (b blackHoleResponseWriter) Flush() {}
func (b blackholeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (b blackHoleResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errors.New("connection on blackholeResponseWriter cannot be hijacked") return nil, nil, errors.New("connection on blackHoleResponseWriter cannot be hijacked")
} }
func (b blackholeResponseWriter) Header() http.Header { func (b blackHoleResponseWriter) Header() http.Header {
return http.Header{} return http.Header{}
} }
func (b blackholeResponseWriter) Write(bytes []byte) (int, error) { func (b blackHoleResponseWriter) Write(bytes []byte) (int, error) {
return len(bytes), nil return len(bytes), nil
} }
func (b blackholeResponseWriter) WriteHeader(statusCode int) {} func (b blackHoleResponseWriter) WriteHeader(statusCode int) {}
type contextStopPropagation struct { type contextStopPropagation struct {
context.Context context.Context
@ -120,3 +160,65 @@ type contextStopPropagation struct {
func (c contextStopPropagation) Done() <-chan struct{} { func (c contextStopPropagation) Done() <-chan struct{} {
return make(chan struct{}) return make(chan struct{})
} }
// reusableRequest keeps in memory the body of the given request,
// so that the request can be fully cloned by each mirror.
type reusableRequest struct {
req *http.Request
body []byte
}
var errBodyTooLarge = errors.New("request body too large")
// if the returned error is errBodyTooLarge, newReusableRequest also returns the
// bytes that were already consumed from the request's body.
func newReusableRequest(req *http.Request, maxBodySize int64) (*reusableRequest, []byte, error) {
if req == nil {
return nil, nil, errors.New("nil input request")
}
if req.Body == nil {
return &reusableRequest{req: req}, nil, nil
}
// unbounded body size
if maxBodySize < 0 {
body, err := ioutil.ReadAll(req.Body)
if err != nil {
return nil, nil, err
}
return &reusableRequest{
req: req,
body: body,
}, nil, nil
}
// we purposefully try to read _more_ than maxBodySize to detect whether
// the request body is larger than what we allow for the mirrors.
body := make([]byte, maxBodySize+1)
n, err := io.ReadFull(req.Body, body)
if err != nil && err != io.ErrUnexpectedEOF {
return nil, nil, err
}
// we got an ErrUnexpectedEOF, which means there was less than maxBodySize data to read,
// which permits us sending also to all the mirrors later.
if err == io.ErrUnexpectedEOF {
return &reusableRequest{
req: req,
body: body[:n],
}, nil, nil
}
// err == nil , which means data size > maxBodySize
return nil, body[:n], errBodyTooLarge
}
func (rr reusableRequest) clone(ctx context.Context) *http.Request {
req := rr.req.Clone(ctx)
if rr.body != nil {
req.Body = ioutil.NopCloser(bytes.NewReader(rr.body))
}
return req
}

View file

@ -1,7 +1,9 @@
package mirror package mirror
import ( import (
"bytes"
"context" "context"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync/atomic" "sync/atomic"
@ -11,13 +13,15 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
const defaultMaxBodySize int64 = -1
func TestMirroringOn100(t *testing.T) { func TestMirroringOn100(t *testing.T) {
var countMirror1, countMirror2 int32 var countMirror1, countMirror2 int32
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
}) })
pool := safe.NewPool(context.Background()) pool := safe.NewPool(context.Background())
mirror := New(handler, pool) mirror := New(handler, pool, defaultMaxBodySize)
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
atomic.AddInt32(&countMirror1, 1) atomic.AddInt32(&countMirror1, 1)
}), 10) }), 10)
@ -46,7 +50,7 @@ func TestMirroringOn10(t *testing.T) {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
}) })
pool := safe.NewPool(context.Background()) pool := safe.NewPool(context.Background())
mirror := New(handler, pool) mirror := New(handler, pool, defaultMaxBodySize)
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
atomic.AddInt32(&countMirror1, 1) atomic.AddInt32(&countMirror1, 1)
}), 10) }), 10)
@ -70,7 +74,7 @@ func TestMirroringOn10(t *testing.T) {
} }
func TestInvalidPercent(t *testing.T) { func TestInvalidPercent(t *testing.T) {
mirror := New(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), safe.NewPool(context.Background())) mirror := New(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), safe.NewPool(context.Background()), defaultMaxBodySize)
err := mirror.AddMirror(nil, -1) err := mirror.AddMirror(nil, -1)
assert.Error(t, err) assert.Error(t, err)
@ -89,7 +93,7 @@ func TestHijack(t *testing.T) {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
}) })
pool := safe.NewPool(context.Background()) pool := safe.NewPool(context.Background())
mirror := New(handler, pool) mirror := New(handler, pool, defaultMaxBodySize)
var mirrorRequest bool var mirrorRequest bool
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -113,7 +117,7 @@ func TestFlush(t *testing.T) {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
}) })
pool := safe.NewPool(context.Background()) pool := safe.NewPool(context.Background())
mirror := New(handler, pool) mirror := New(handler, pool, defaultMaxBodySize)
var mirrorRequest bool var mirrorRequest bool
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -131,3 +135,121 @@ func TestFlush(t *testing.T) {
pool.Stop() pool.Stop()
assert.Equal(t, true, mirrorRequest) assert.Equal(t, true, mirrorRequest)
} }
func TestMirroringWithBody(t *testing.T) {
const numMirrors = 10
var (
countMirror int32
body = []byte(`body`)
)
pool := safe.NewPool(context.Background())
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
assert.NotNil(t, r.Body)
bb, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, body, bb)
rw.WriteHeader(http.StatusOK)
})
mirror := New(handler, pool, defaultMaxBodySize)
for i := 0; i < numMirrors; i++ {
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
assert.NotNil(t, r.Body)
bb, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, body, bb)
atomic.AddInt32(&countMirror, 1)
}), 100)
assert.NoError(t, err)
}
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(body))
mirror.ServeHTTP(httptest.NewRecorder(), req)
pool.Stop()
val := atomic.LoadInt32(&countMirror)
assert.Equal(t, numMirrors, int(val))
}
func TestCloneRequest(t *testing.T) {
t.Run("http request body is nil", func(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "/", nil)
assert.NoError(t, err)
ctx := req.Context()
rr, _, err := newReusableRequest(req, defaultMaxBodySize)
assert.NoError(t, err)
// first call
cloned := rr.clone(ctx)
assert.Equal(t, cloned, req)
assert.Nil(t, cloned.Body)
// second call
cloned = rr.clone(ctx)
assert.Equal(t, cloned, req)
assert.Nil(t, cloned.Body)
})
t.Run("http request body is not nil", func(t *testing.T) {
bb := []byte(`¯\_(ツ)_/¯`)
contentLength := len(bb)
buf := bytes.NewBuffer(bb)
req, err := http.NewRequest(http.MethodPost, "/", buf)
assert.NoError(t, err)
ctx := req.Context()
req.ContentLength = int64(contentLength)
rr, _, err := newReusableRequest(req, defaultMaxBodySize)
assert.NoError(t, err)
// first call
cloned := rr.clone(ctx)
body, err := ioutil.ReadAll(cloned.Body)
assert.NoError(t, err)
assert.Equal(t, bb, body)
// second call
cloned = rr.clone(ctx)
body, err = ioutil.ReadAll(cloned.Body)
assert.NoError(t, err)
assert.Equal(t, bb, body)
})
t.Run("failed case", func(t *testing.T) {
bb := []byte(`1234567890`)
buf := bytes.NewBuffer(bb)
req, err := http.NewRequest(http.MethodPost, "/", buf)
assert.NoError(t, err)
_, expectedBytes, err := newReusableRequest(req, 2)
assert.Error(t, err)
assert.Equal(t, bb[:3], expectedBytes)
})
t.Run("valid case with maxBodySize", func(t *testing.T) {
bb := []byte(`1234567890`)
buf := bytes.NewBuffer(bb)
req, err := http.NewRequest(http.MethodPost, "/", buf)
assert.NoError(t, err)
_, expectedBytes, err := newReusableRequest(req, 20)
assert.NoError(t, err)
assert.Nil(t, expectedBytes)
})
t.Run("no request given", func(t *testing.T) {
_, _, err := newReusableRequest(nil, defaultMaxBodySize)
assert.Error(t, err)
})
}

View file

@ -33,6 +33,8 @@ const (
defaultHealthCheckTimeout = 5 * time.Second defaultHealthCheckTimeout = 5 * time.Second
) )
const defaultMaxBodySize int64 = -1
// NewManager creates a new Manager // NewManager creates a new Manager
func NewManager(configs map[string]*runtime.ServiceInfo, defaultRoundTripper http.RoundTripper, metricsRegistry metrics.Registry, routinePool *safe.Pool) *Manager { func NewManager(configs map[string]*runtime.ServiceInfo, defaultRoundTripper http.RoundTripper, metricsRegistry metrics.Registry, routinePool *safe.Pool) *Manager {
return &Manager{ return &Manager{
@ -123,7 +125,11 @@ func (m *Manager) getMirrorServiceHandler(ctx context.Context, config *dynamic.M
return nil, err return nil, err
} }
handler := mirror.New(serviceHandler, m.routinePool) maxBodySize := defaultMaxBodySize
if config.MaxBodySize != nil {
maxBodySize = *config.MaxBodySize
}
handler := mirror.New(serviceHandler, m.routinePool, maxBodySize)
for _, mirrorConfig := range config.Mirrors { for _, mirrorConfig := range config.Mirrors {
mirrorHandler, err := m.BuildHTTP(ctx, mirrorConfig.Name, responseModifier) mirrorHandler, err := m.BuildHTTP(ctx, mirrorConfig.Name, responseModifier)
if err != nil { if err != nil {