2021-06-11 15:30:05 +02:00
package tcpmiddleware
import (
"context"
"fmt"
"strings"
2023-02-03 15:24:05 +01:00
"github.com/traefik/traefik/v3/pkg/config/runtime"
"github.com/traefik/traefik/v3/pkg/middlewares/tcp/inflightconn"
"github.com/traefik/traefik/v3/pkg/middlewares/tcp/ipallowlist"
"github.com/traefik/traefik/v3/pkg/server/provider"
"github.com/traefik/traefik/v3/pkg/tcp"
2021-06-11 15:30:05 +02:00
)
// middlewareStackType is an unexported type used as a context key so that the
// key cannot collide with context keys defined by other packages.
type middlewareStackType int

const (
	// middlewareStackKey indexes the []string of middleware names currently
	// being instantiated (maintained by checkRecursion) and is used to detect
	// middlewares that reference themselves directly or indirectly.
	middlewareStackKey middlewareStackType = iota
)
// Builder builds TCP middleware chains from runtime middleware configurations.
type Builder struct {
	// configs maps qualified middleware names to their runtime info. Entries
	// are looked up when building a chain, and construction errors are
	// reported back onto them via AddError.
	configs map[string]*runtime.TCPMiddlewareInfo
}
// NewBuilder returns a Builder that resolves TCP middlewares from configs.
// The map is retained as-is (not copied), so later changes to it are visible
// to the Builder.
func NewBuilder(configs map[string]*runtime.TCPMiddlewareInfo) *Builder {
	builder := new(Builder)
	builder.configs = configs
	return builder
}
// BuildChain creates a middleware chain.
//
// Each name in middlewares is qualified against ctx with
// provider.GetQualifiedName and appended to the chain as a constructor
// closure. When that closure is invoked, it does the following:
//   - looks up the middleware's runtime info;
//   - checks for recursive references;
//   - builds the middleware around next.
//
// Any failure in the last two steps is passed to the runtime info's AddError
// and returned as the constructor's error.
func (b *Builder) BuildChain(ctx context.Context, middlewares []string) *tcp.Chain {
	chain := tcp.NewChain()

	for _, name := range middlewares {
		middlewareName := provider.GetQualifiedName(ctx, name)

		chain = chain.Append(func(next tcp.Handler) (tcp.Handler, error) {
			constructorContext := provider.AddInContext(ctx, middlewareName)

			// A missing key yields nil, so this single guard covers both "not
			// registered" and "registered as a nil entry". Dereferencing a nil
			// entry to reach TCPMiddleware would otherwise panic.
			midInf := b.configs[middlewareName]
			if midInf == nil || midInf.TCPMiddleware == nil {
				return nil, fmt.Errorf("middleware %q does not exist", middlewareName)
			}

			var err error
			if constructorContext, err = checkRecursion(constructorContext, middlewareName); err != nil {
				midInf.AddError(err, true)
				return nil, err
			}

			constructor, err := b.buildConstructor(constructorContext, middlewareName)
			if err != nil {
				midInf.AddError(err, true)
				return nil, err
			}

			handler, err := constructor(next)
			if err != nil {
				midInf.AddError(err, true)
				return nil, err
			}

			return handler, nil
		})
	}

	return &chain
}
// checkRecursion pushes middlewareName onto the middleware stack carried by
// ctx and returns the derived context.
//
// If middlewareName is already on the stack, ctx is returned unchanged along
// with an error describing the cycle (names joined with "->").
//
// A fresh slice is built on every call. Appending directly to the stack read
// from ctx could write into spare capacity of a backing array shared with
// sibling contexts derived from the same parent, silently rewriting their
// recorded stacks.
func checkRecursion(ctx context.Context, middlewareName string) (context.Context, error) {
	// A missing or mistyped value yields a nil stack, which behaves as empty.
	currentStack, _ := ctx.Value(middlewareStackKey).([]string)

	newStack := make([]string, len(currentStack), len(currentStack)+1)
	copy(newStack, currentStack)
	newStack = append(newStack, middlewareName)

	if inSlice(middlewareName, currentStack) {
		return ctx, fmt.Errorf("could not instantiate middleware %s: recursion detected in %s", middlewareName, strings.Join(newStack, "->"))
	}

	return context.WithValue(ctx, middlewareStackKey, newStack), nil
}
func ( b * Builder ) buildConstructor ( ctx context . Context , middlewareName string ) ( tcp . Constructor , error ) {
config := b . configs [ middlewareName ]
if config == nil || config . TCPMiddleware == nil {
return nil , fmt . Errorf ( "invalid middleware %q configuration" , middlewareName )
}
var middleware tcp . Constructor
2021-11-29 17:12:06 +01:00
// InFlightConn
if config . InFlightConn != nil {
middleware = func ( next tcp . Handler ) ( tcp . Handler , error ) {
return inflightconn . New ( ctx , next , * config . InFlightConn , middlewareName )
}
}
2022-10-26 18:16:05 +03:00
// IPAllowList
if config . IPAllowList != nil {
2021-06-11 15:30:05 +02:00
middleware = func ( next tcp . Handler ) ( tcp . Handler , error ) {
2022-10-26 18:16:05 +03:00
return ipallowlist . New ( ctx , next , * config . IPAllowList , middlewareName )
2021-06-11 15:30:05 +02:00
}
}
if middleware == nil {
return nil , fmt . Errorf ( "invalid middleware %q configuration: invalid middleware type or middleware does not exist" , middlewareName )
}
return middleware , nil
}
// inSlice reports whether element occurs anywhere in stack.
// A nil or empty stack contains nothing.
func inSlice(element string, stack []string) bool {
	found := false
	for i := 0; i < len(stack) && !found; i++ {
		found = stack[i] == element
	}
	return found
}