199 lines
4.6 KiB
Go
199 lines
4.6 KiB
Go
|
// Package capture is a middleware that captures requests/responses size, and status.
|
||
|
//
|
||
|
// For another middleware to read those attributes of a request/response, this middleware
// must be added before it in the middleware chain.
|
||
|
//
|
||
|
// handler, _ := NewHandler()
|
||
|
// chain := alice.New().
|
||
|
// Append(WrapHandler(handler)).
|
||
|
// Append(myOtherMiddleware).
|
||
|
// Then(...)
|
||
|
//
|
||
|
// As this middleware stores that data in the request's context, the data can
// be retrieved at any time once this middleware's ServeHTTP has been called.
|
||
|
//
|
||
|
// func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.Handler) {
|
||
|
// capt, err := capture.FromContext(req.Context())
|
||
|
// if err != nil {
|
||
|
// ...
|
||
|
// }
|
||
|
//
|
||
|
// fmt.Println(capt.StatusCode())
|
||
|
// fmt.Println(capt.ResponseSize())
|
||
|
// fmt.Println(capt.RequestSize())
|
||
|
// }
|
||
|
package capture
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
|
||
|
"github.com/containous/alice"
|
||
|
"github.com/traefik/traefik/v2/pkg/middlewares"
|
||
|
)
|
||
|
|
||
|
// key is the unexported type used for this package's context keys, so that
// the stored value cannot collide with keys defined by other packages.
type key string
|
||
|
|
||
|
// capturedData is the context key under which the *Capture of a request is stored.
const capturedData key = "capturedData"
|
||
|
|
||
|
// Handler is the capture middleware. It carries no fields of its own: the
// data it gathers for a request is stored in that request's context.
type Handler struct{}
|
||
|
|
||
|
// WrapHandler wraps capture handler into an Alice Constructor.
|
||
|
func WrapHandler(handler *Handler) alice.Constructor {
|
||
|
return func(next http.Handler) (http.Handler, error) {
|
||
|
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||
|
handler.ServeHTTP(rw, req, next)
|
||
|
}), nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.Handler) {
|
||
|
c := Capture{}
|
||
|
if req.Body != nil {
|
||
|
readCounter := &readCounter{source: req.Body}
|
||
|
c.rr = readCounter
|
||
|
req.Body = readCounter
|
||
|
}
|
||
|
responseWriter := newResponseWriter(rw)
|
||
|
c.rw = responseWriter
|
||
|
ctx := context.WithValue(req.Context(), capturedData, &c)
|
||
|
next.ServeHTTP(responseWriter, req.WithContext(ctx))
|
||
|
}
|
||
|
|
||
|
// Capture is the object populated by the capture middleware,
|
||
|
// allowing to gather information about the request and response.
|
||
|
type Capture struct {
|
||
|
rr *readCounter
|
||
|
rw responseWriter
|
||
|
}
|
||
|
|
||
|
// FromContext returns the Capture value found in ctx, or an empty Capture otherwise.
|
||
|
func FromContext(ctx context.Context) (*Capture, error) {
|
||
|
c := ctx.Value(capturedData)
|
||
|
if c == nil {
|
||
|
return nil, errors.New("value not found")
|
||
|
}
|
||
|
capt, ok := c.(*Capture)
|
||
|
if !ok {
|
||
|
return nil, errors.New("value stored in Context is not a *Capture")
|
||
|
}
|
||
|
return capt, nil
|
||
|
}
|
||
|
|
||
|
func (c Capture) ResponseSize() int64 {
|
||
|
return c.rw.Size()
|
||
|
}
|
||
|
|
||
|
func (c Capture) StatusCode() int {
|
||
|
return c.rw.Status()
|
||
|
}
|
||
|
|
||
|
// RequestSize returns the size of the request's body if it applies,
|
||
|
// zero otherwise.
|
||
|
func (c Capture) RequestSize() int64 {
|
||
|
if c.rr == nil {
|
||
|
return 0
|
||
|
}
|
||
|
return c.rr.size
|
||
|
}
|
||
|
|
||
|
// readCounter wraps an io.ReadCloser and keeps a running total of the bytes
// read through it.
type readCounter struct {
	// source is the wrapped ReadCloser the request body is read from.
	source io.ReadCloser
	// size is the total number of bytes read so far.
	size int64
}

// Read reads from the wrapped source into p and adds the number of bytes
// read to the running total, whether or not the read also returned an error.
func (r *readCounter) Read(p []byte) (int, error) {
	read, err := r.source.Read(p)
	r.size += int64(read)

	return read, err
}

// Close closes the wrapped source and returns its error unchanged.
func (r *readCounter) Close() error {
	return r.source.Close()
}
|
||
|
|
||
|
// Compile-time assertion that *responseWriterWithCloseNotify satisfies middlewares.Stateful.
var _ middlewares.Stateful = &responseWriterWithCloseNotify{}
|
||
|
|
||
|
// responseWriter is an http.ResponseWriter that can also report the status
// and the size it has tracked for the response.
type responseWriter interface {
	http.ResponseWriter

	// Status returns the tracked response status code.
	Status() int
	// Size returns the tracked response size.
	Size() int64
}
|
||
|
|
||
|
func newResponseWriter(rw http.ResponseWriter) responseWriter {
|
||
|
capt := &captureResponseWriter{rw: rw}
|
||
|
if _, ok := rw.(http.CloseNotifier); !ok {
|
||
|
return capt
|
||
|
}
|
||
|
|
||
|
return &responseWriterWithCloseNotify{capt}
|
||
|
}
|
||
|
|
||
|
// captureResponseWriter is a wrapper of type http.ResponseWriter
// that tracks response status and size.
//
// The recorded status is meant to mirror what the wrapped writer sends: once
// a final status has been recorded, later WriteHeader calls are still
// forwarded but no longer change it.
type captureResponseWriter struct {
	// rw is the wrapped writer every call is forwarded to.
	rw http.ResponseWriter
	// status is the recorded status code; zero until a header is recorded.
	status int
	// size is the total number of body bytes rw reported as written.
	size int64
}

// finalStatusRecorded reports whether status already holds a final status
// code. An unset status (0) and provisional 1xx codes are not final, because
// a final header may still follow them. 101 Switching Protocols is the
// exception: it ends the HTTP exchange, so it is treated as final.
func (crw *captureResponseWriter) finalStatusRecorded() bool {
	return crw.status >= http.StatusOK || crw.status == http.StatusSwitchingProtocols
}

// Header returns the header map of the wrapped writer.
func (crw *captureResponseWriter) Header() http.Header {
	return crw.rw.Header()
}

// Size returns the number of body bytes written so far.
func (crw *captureResponseWriter) Size() int64 {
	return crw.size
}

// Status returns the recorded status code, or zero when no header has been
// recorded yet.
func (crw *captureResponseWriter) Status() int {
	return crw.status
}

// Write forwards b to the wrapped writer and adds the number of bytes it
// reports as written to the tracked size. Writing a body before a final
// header makes net/http send an implicit 200, so that is what gets recorded.
func (crw *captureResponseWriter) Write(b []byte) (int, error) {
	if !crw.finalStatusRecorded() {
		crw.status = http.StatusOK
	}

	size, err := crw.rw.Write(b)
	crw.size += int64(size)

	return size, err
}

// WriteHeader forwards s to the wrapped writer and records it, unless a
// final status has already been recorded. In that case the call is a
// superfluous one that net/http ignores, and overwriting the recorded value
// would report a status the client never received.
func (crw *captureResponseWriter) WriteHeader(s int) {
	crw.rw.WriteHeader(s)

	if crw.finalStatusRecorded() {
		return
	}

	crw.status = s
}

// Flush flushes the wrapped writer when it implements http.Flusher, and is a
// no-op otherwise.
func (crw *captureResponseWriter) Flush() {
	f, ok := crw.rw.(http.Flusher)
	if !ok {
		return
	}

	// The standard net/http writers send an implicit 200 header when flushed
	// before any final header was written; record it so Status does not
	// keep reporting 0 (or a provisional 1xx) for a response already started.
	if !crw.finalStatusRecorded() {
		crw.status = http.StatusOK
	}

	f.Flush()
}

// Hijack hands the call over to the wrapped writer when it implements
// http.Hijacker, and returns an error naming the wrapped type otherwise.
func (crw *captureResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if h, ok := crw.rw.(http.Hijacker); ok {
		return h.Hijack()
	}

	return nil, nil, fmt.Errorf("not a hijacker: %T", crw.rw)
}
|
||
|
|
||
|
type responseWriterWithCloseNotify struct {
|
||
|
*captureResponseWriter
|
||
|
}
|
||
|
|
||
|
// CloseNotify returns a channel that receives at most a
|
||
|
// single value (true) when the client connection has gone away.
|
||
|
func (r *responseWriterWithCloseNotify) CloseNotify() <-chan bool {
|
||
|
return r.rw.(http.CloseNotifier).CloseNotify()
|
||
|
}
|