traefik/pkg/middlewares/requestdecorator/hostresolver.go

134 lines
3 KiB
Go
Raw Normal View History

2018-11-14 10:18:03 +01:00
package requestdecorator
import (
	"context"
	"errors"
	"fmt"
	"net"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/miekg/dns"
	"github.com/patrickmn/go-cache"
	"github.com/rs/zerolog/log"
)
// cnameResolv is one CNAME answer: the target name (Record) and the TTL the
// answer carried, which callers use as a cache lifetime.
type cnameResolv struct {
	TTL    time.Duration
	Record string
}

// byTTL implements sort.Interface ordering answers from the highest TTL to
// the lowest, so after sort.Sort the longest-lived answer sits at index 0.
type byTTL []*cnameResolv

// Len returns the number of answers.
func (s byTTL) Len() int {
	return len(s)
}

// Swap exchanges the answers at positions i and j.
func (s byTTL) Swap(i, j int) {
	s[j], s[i] = s[i], s[j]
}

// Less reports whether answer i has a strictly greater TTL than answer j.
func (s byTTL) Less(i, j int) bool {
	return s[j].TTL < s[i].TTL
}
// Resolver used for host resolver.
type Resolver struct {
CnameFlattening bool
ResolvConfig string
ResolvDepth int
cache *cache.Cache
}
// CNAMEFlatten check if CNAME record exists, flatten if possible.
func (hr *Resolver) CNAMEFlatten(ctx context.Context, host string) string {
if hr.cache == nil {
hr.cache = cache.New(30*time.Minute, 5*time.Minute)
}
result := host
request := host
value, found := hr.cache.Get(host)
if found {
return value.(string)
}
2022-11-21 18:36:05 +01:00
logger := log.Ctx(ctx)
2020-07-07 14:42:03 +02:00
cacheDuration := 0 * time.Second
2024-02-19 15:44:03 +01:00
for depth := range hr.ResolvDepth {
2018-11-14 10:18:03 +01:00
resolv, err := cnameResolve(ctx, request, hr.ResolvConfig)
if err != nil {
2022-11-21 18:36:05 +01:00
logger.Error().Err(err).Send()
2018-11-14 10:18:03 +01:00
break
}
if resolv == nil {
break
}
result = resolv.Record
if depth == 0 {
cacheDuration = resolv.TTL
}
request = resolv.Record
}
hr.cache.Set(host, result, cacheDuration)
2018-11-14 10:18:03 +01:00
return result
}
// cnameResolve resolves CNAME if exists, and return with the highest TTL.
2020-07-07 14:42:03 +02:00
func cnameResolve(ctx context.Context, host, resolvPath string) (*cnameResolv, error) {
2018-11-14 10:18:03 +01:00
config, err := dns.ClientConfigFromFile(resolvPath)
if err != nil {
return nil, fmt.Errorf("invalid resolver configuration file: %s", resolvPath)
}
if net.ParseIP(host) != nil {
return nil, nil
}
2018-11-14 10:18:03 +01:00
client := &dns.Client{Timeout: 30 * time.Second}
m := &dns.Msg{}
m.SetQuestion(dns.Fqdn(host), dns.TypeCNAME)
var result []*cnameResolv
for _, server := range config.Servers {
tempRecord, err := getRecord(client, m, server, config.Port)
if err != nil {
if errors.Is(err, errNoCNAMERecord) {
2023-10-11 16:20:26 +02:00
log.Ctx(ctx).Debug().Err(err).Msgf("CNAME lookup for hostname %q", host)
continue
}
2023-10-11 16:20:26 +02:00
log.Ctx(ctx).Error().Err(err).Msgf("CNAME lookup for hostname %q", host)
2018-11-14 10:18:03 +01:00
continue
}
result = append(result, tempRecord)
}
if len(result) == 0 {
2018-11-14 10:18:03 +01:00
return nil, nil
}
sort.Sort(byTTL(result))
return result[0], nil
}
var (
	// errNoCNAMERecord is wrapped by getRecord when a server returns an empty
	// answer, so cnameResolve can tell "no CNAME" apart from real failures.
	errNoCNAMERecord = errors.New("no CNAME record for host")
)
func getRecord(client *dns.Client, msg *dns.Msg, server, port string) (*cnameResolv, error) {
2018-11-14 10:18:03 +01:00
resp, _, err := client.Exchange(msg, net.JoinHostPort(server, port))
if err != nil {
2020-05-11 12:06:07 +02:00
return nil, fmt.Errorf("exchange error for server %s: %w", server, err)
2018-11-14 10:18:03 +01:00
}
if resp == nil || len(resp.Answer) == 0 {
return nil, fmt.Errorf("%w: %s", errNoCNAMERecord, server)
2018-11-14 10:18:03 +01:00
}
rr, ok := resp.Answer[0].(*dns.CNAME)
if !ok {
return nil, fmt.Errorf("invalid response type for server %s", server)
}
return &cnameResolv{
TTL: time.Duration(rr.Hdr.Ttl) * time.Second,
Record: strings.TrimSuffix(rr.Target, "."),
}, nil
}