Update Lego

This commit is contained in:
Ludovic Fernandez 2019-01-07 18:30:06 +01:00 committed by Traefiker Bot
parent fc8c24e987
commit 9b2423aaba
192 changed files with 11105 additions and 8535 deletions

53
Gopkg.lock generated
View file

@ -430,19 +430,19 @@
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:ec6271918b59b872a2d25e374569a4f75f1839d91e4191470c297b7eaaaf7641" digest = "1:f9adc21a937e5da643ea14a3488cb7506788876737a5e205394e508627a6eec8"
name = "github.com/dimchansky/utfbom" name = "github.com/dimchansky/utfbom"
packages = ["."] packages = ["."]
pruneopts = "NUT" pruneopts = "NUT"
revision = "c410c2305b32ec96caea4e24b4ecbf648e2eeb25" revision = "d2133a1ce379ef6fa992b0514a77146c60db9d1c"
[[projects]] [[projects]]
branch = "master" digest = "1:e856fc44ab196970612bdc8c15e65ccf92ed8d4ccb3a2e65b88dc240a2fe5d0b"
digest = "1:e8055cec2992f8bbf63c390aa6b36d78c2f93b13617e4569168c09219f88c6b0"
name = "github.com/dnsimple/dnsimple-go" name = "github.com/dnsimple/dnsimple-go"
packages = ["dnsimple"] packages = ["dnsimple"]
pruneopts = "NUT" pruneopts = "NUT"
revision = "bbe1a2c87affea187478e24d3aea3cac25f870b3" revision = "f5ead9c20763fd925dea1362f2af5d671ed2a459"
version = "v0.21.0"
[[projects]] [[projects]]
digest = "1:cf7cba074c4d2f8e2a5cc2f10b1f6762c86cff2e39917b9f9a6dbd7df57fe9c9" digest = "1:cf7cba074c4d2f8e2a5cc2f10b1f6762c86cff2e39917b9f9a6dbd7df57fe9c9"
@ -1188,12 +1188,12 @@
source = "https://github.com/containous/mesos-dns.git" source = "https://github.com/containous/mesos-dns.git"
[[projects]] [[projects]]
branch = "master" digest = "1:b83995756f9b1a24c518d40052d80f524f0a9024ee0479d8a8e91ec2548074d1"
digest = "1:68cbf3a326abda169b26327b95302db8bcf297a49e689e85746bfc04d1cd8c33"
name = "github.com/miekg/dns" name = "github.com/miekg/dns"
packages = ["."] packages = ["."]
pruneopts = "NUT" pruneopts = "NUT"
revision = "906238edc6eb0ddface4a1923f6d41ef2a5ca59b" revision = "7586a3cbe8ccfc63f82de3ab2ceeb08c9939af72"
version = "v1.1.1"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -1705,12 +1705,25 @@
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:1c455a900935917ff95c6dfc95461ffc0d7b138a803e9504319f2d9a39e5d6f3" digest = "1:f3f9f7b883b89edc06283285964a8125b15f31c1634a8ccb2860c8c8b3f6231d"
name = "github.com/xenolf/lego" name = "github.com/xenolf/lego"
packages = [ packages = [
"acme", "acme",
"acme/api",
"acme/api/internal/nonces",
"acme/api/internal/secure",
"acme/api/internal/sender",
"certcrypto",
"certificate",
"challenge",
"challenge/dns01",
"challenge/http01",
"challenge/resolver",
"challenge/tlsalpn01",
"lego",
"log", "log",
"platform/config/env", "platform/config/env",
"platform/wait",
"providers/dns", "providers/dns",
"providers/dns/acmedns", "providers/dns/acmedns",
"providers/dns/alidns", "providers/dns/alidns",
@ -1719,10 +1732,13 @@
"providers/dns/bluecat", "providers/dns/bluecat",
"providers/dns/cloudflare", "providers/dns/cloudflare",
"providers/dns/cloudxns", "providers/dns/cloudxns",
"providers/dns/cloudxns/internal",
"providers/dns/conoha", "providers/dns/conoha",
"providers/dns/conoha/internal",
"providers/dns/digitalocean", "providers/dns/digitalocean",
"providers/dns/dnsimple", "providers/dns/dnsimple",
"providers/dns/dnsmadeeasy", "providers/dns/dnsmadeeasy",
"providers/dns/dnsmadeeasy/internal",
"providers/dns/dnspod", "providers/dns/dnspod",
"providers/dns/dreamhost", "providers/dns/dreamhost",
"providers/dns/duckdns", "providers/dns/duckdns",
@ -1746,7 +1762,9 @@
"providers/dns/namecheap", "providers/dns/namecheap",
"providers/dns/namedotcom", "providers/dns/namedotcom",
"providers/dns/netcup", "providers/dns/netcup",
"providers/dns/netcup/internal",
"providers/dns/nifcloud", "providers/dns/nifcloud",
"providers/dns/nifcloud/internal",
"providers/dns/ns1", "providers/dns/ns1",
"providers/dns/otc", "providers/dns/otc",
"providers/dns/ovh", "providers/dns/ovh",
@ -1756,18 +1774,21 @@
"providers/dns/route53", "providers/dns/route53",
"providers/dns/sakuracloud", "providers/dns/sakuracloud",
"providers/dns/selectel", "providers/dns/selectel",
"providers/dns/selectel/internal",
"providers/dns/stackpath", "providers/dns/stackpath",
"providers/dns/transip", "providers/dns/transip",
"providers/dns/vegadns", "providers/dns/vegadns",
"providers/dns/vscale", "providers/dns/vscale",
"providers/dns/vscale/internal",
"providers/dns/vultr", "providers/dns/vultr",
"registration",
] ]
pruneopts = "NUT" pruneopts = "NUT"
revision = "a5f0a3ff8026e05cbdd11c391c0e25122497c736" revision = "43401f2475dd1f6cc2e220908f0caba246ea854e"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:d25ae0946d8bebad296b1a167cd62b92c3079e90c59f91657f69da08fe0cf76c" digest = "1:30c1930f8c9fee79f3af60c8b7cd92edd12a4f22187f5527d53509b1a794f555"
name = "golang.org/x/crypto" name = "golang.org/x/crypto"
packages = [ packages = [
"bcrypt", "bcrypt",
@ -1782,7 +1803,7 @@
"ssh/terminal", "ssh/terminal",
] ]
pruneopts = "NUT" pruneopts = "NUT"
revision = "91a49db82a88618983a78a06c1cbd4e00ab749ab" revision = "505ab145d0a99da450461ae2c1a9f6cd10d1f447"
[[projects]] [[projects]]
branch = "master" branch = "master"
@ -2343,8 +2364,16 @@
"github.com/vulcand/oxy/roundrobin", "github.com/vulcand/oxy/roundrobin",
"github.com/vulcand/oxy/utils", "github.com/vulcand/oxy/utils",
"github.com/xenolf/lego/acme", "github.com/xenolf/lego/acme",
"github.com/xenolf/lego/certcrypto",
"github.com/xenolf/lego/certificate",
"github.com/xenolf/lego/challenge",
"github.com/xenolf/lego/challenge/dns01",
"github.com/xenolf/lego/challenge/http01",
"github.com/xenolf/lego/challenge/tlsalpn01",
"github.com/xenolf/lego/lego",
"github.com/xenolf/lego/log", "github.com/xenolf/lego/log",
"github.com/xenolf/lego/providers/dns", "github.com/xenolf/lego/providers/dns",
"github.com/xenolf/lego/registration",
"golang.org/x/net/http/httpguts", "golang.org/x/net/http/httpguts",
"golang.org/x/net/http2", "golang.org/x/net/http2",
"golang.org/x/net/http2/hpack", "golang.org/x/net/http2/hpack",

View file

@ -246,7 +246,7 @@
revision = "7e6055773c5137efbeb3bd2410d705fe10ab6bfd" revision = "7e6055773c5137efbeb3bd2410d705fe10ab6bfd"
[[override]] [[override]]
branch = "master" version = "v1.1.1"
name = "github.com/miekg/dns" name = "github.com/miekg/dns"
[[constraint]] [[constraint]]

View file

@ -18,15 +18,16 @@ import (
"github.com/containous/traefik/log" "github.com/containous/traefik/log"
acmeprovider "github.com/containous/traefik/provider/acme" acmeprovider "github.com/containous/traefik/provider/acme"
"github.com/containous/traefik/types" "github.com/containous/traefik/types"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/certcrypto"
"github.com/xenolf/lego/registration"
) )
// Account is used to store lets encrypt registration info // Account is used to store lets encrypt registration info
type Account struct { type Account struct {
Email string Email string
Registration *acme.RegistrationResource Registration *registration.Resource
PrivateKey []byte PrivateKey []byte
KeyType acme.KeyType KeyType certcrypto.KeyType
DomainsCertificate DomainsCertificates DomainsCertificate DomainsCertificates
ChallengeCerts map[string]*ChallengeCert ChallengeCerts map[string]*ChallengeCert
HTTPChallenge map[string]map[string][]byte HTTPChallenge map[string]map[string][]byte
@ -101,7 +102,7 @@ func (a *Account) GetEmail() string {
} }
// GetRegistration returns lets encrypt registration resource // GetRegistration returns lets encrypt registration resource
func (a *Account) GetRegistration() *acme.RegistrationResource { func (a *Account) GetRegistration() *registration.Resource {
return a.Registration return a.Registration
} }

View file

@ -17,7 +17,6 @@ import (
"github.com/BurntSushi/ty/fun" "github.com/BurntSushi/ty/fun"
"github.com/cenk/backoff" "github.com/cenk/backoff"
"github.com/containous/flaeg"
"github.com/containous/mux" "github.com/containous/mux"
"github.com/containous/staert" "github.com/containous/staert"
"github.com/containous/traefik/cluster" "github.com/containous/traefik/cluster"
@ -27,9 +26,14 @@ import (
"github.com/containous/traefik/types" "github.com/containous/traefik/types"
"github.com/containous/traefik/version" "github.com/containous/traefik/version"
"github.com/eapache/channels" "github.com/eapache/channels"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/certificate"
"github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/challenge/dns01"
"github.com/xenolf/lego/challenge/http01"
"github.com/xenolf/lego/lego"
legolog "github.com/xenolf/lego/log" legolog "github.com/xenolf/lego/log"
"github.com/xenolf/lego/providers/dns" "github.com/xenolf/lego/providers/dns"
"github.com/xenolf/lego/registration"
) )
var ( var (
@ -53,7 +57,7 @@ type ACME struct {
TLSChallenge *acmeprovider.TLSChallenge `description:"Activate TLS-ALPN-01 Challenge"` TLSChallenge *acmeprovider.TLSChallenge `description:"Activate TLS-ALPN-01 Challenge"`
ACMELogging bool `description:"Enable debug logging of ACME actions."` ACMELogging bool `description:"Enable debug logging of ACME actions."`
OverrideCertificates bool `description:"Enable to override certificates in key-value store when using storeconfig"` OverrideCertificates bool `description:"Enable to override certificates in key-value store when using storeconfig"`
client *acme.Client client *lego.Client
store cluster.Store store cluster.Store
challengeHTTPProvider *challengeHTTPProvider challengeHTTPProvider *challengeHTTPProvider
challengeTLSProvider *challengeTLSProvider challengeTLSProvider *challengeTLSProvider
@ -66,8 +70,6 @@ type ACME struct {
} }
func (a *ACME) init() error { func (a *ACME) init() error {
acme.UserAgent = fmt.Sprintf("containous-traefik/%s", version.Version)
if a.ACMELogging { if a.ACMELogging {
legolog.Logger = log.WithoutContext() legolog.Logger = log.WithoutContext()
} else { } else {
@ -85,7 +87,7 @@ func (a *ACME) init() error {
// AddRoutes add routes on internal router // AddRoutes add routes on internal router
func (a *ACME) AddRoutes(router *mux.Router) { func (a *ACME) AddRoutes(router *mux.Router) {
router.Methods(http.MethodGet). router.Methods(http.MethodGet).
Path(acme.HTTP01ChallengePath("{token}")). Path(http01.ChallengePath("{token}")).
Handler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { Handler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if a.challengeHTTPProvider == nil { if a.challengeHTTPProvider == nil {
rw.WriteHeader(http.StatusNotFound) rw.WriteHeader(http.StatusNotFound)
@ -218,7 +220,7 @@ func (a *ACME) leadershipListener(elected bool) error {
// New users will need to register; be sure to save it // New users will need to register; be sure to save it
log.Debug("Register...") log.Debug("Register...")
reg, err := a.client.Register(true) reg, err := a.client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if err != nil { if err != nil {
return err return err
} }
@ -363,7 +365,7 @@ func (a *ACME) renewCertificates() {
} }
func (a *ACME) renewACMECertificate(certificateResource *DomainsCertificate) (*Certificate, error) { func (a *ACME) renewACMECertificate(certificateResource *DomainsCertificate) (*Certificate, error) {
renewedCert, err := a.client.RenewCertificate(acme.CertificateResource{ renewedCert, err := a.client.Certificate.Renew(certificate.Resource{
Domain: certificateResource.Certificate.Domain, Domain: certificateResource.Certificate.Domain,
CertURL: certificateResource.Certificate.CertURL, CertURL: certificateResource.Certificate.CertURL,
CertStableURL: certificateResource.Certificate.CertStableURL, CertStableURL: certificateResource.Certificate.CertStableURL,
@ -412,28 +414,19 @@ func (a *ACME) storeRenewedCertificate(certificateResource *DomainsCertificate,
return nil return nil
} }
func dnsOverrideDelay(delay flaeg.Duration) error { func (a *ACME) buildACMEClient(account *Account) (*lego.Client, error) {
var err error
if delay > 0 {
log.Debugf("Delaying %d rather than validating DNS propagation", delay)
acme.PreCheckDNS = func(_, _ string) (bool, error) {
time.Sleep(time.Duration(delay))
return true, nil
}
} else if delay < 0 {
err = fmt.Errorf("invalid negative DelayBeforeCheck: %d", delay)
}
return err
}
func (a *ACME) buildACMEClient(account *Account) (*acme.Client, error) {
log.Debug("Building ACME client...") log.Debug("Building ACME client...")
caServer := "https://acme-v02.api.letsencrypt.org/directory" caServer := "https://acme-v02.api.letsencrypt.org/directory"
if len(a.CAServer) > 0 { if len(a.CAServer) > 0 {
caServer = a.CAServer caServer = a.CAServer
} }
client, err := acme.NewClient(caServer, account, account.KeyType) config := lego.NewConfig(account)
config.CADirURL = caServer
config.KeyType = account.KeyType
config.UserAgent = fmt.Sprintf("containous-traefik/%s", version.Version)
client, err := lego.NewClient(config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -442,22 +435,23 @@ func (a *ACME) buildACMEClient(account *Account) (*acme.Client, error) {
if a.DNSChallenge != nil && len(a.DNSChallenge.Provider) > 0 { if a.DNSChallenge != nil && len(a.DNSChallenge.Provider) > 0 {
log.Debugf("Using DNS Challenge provider: %s", a.DNSChallenge.Provider) log.Debugf("Using DNS Challenge provider: %s", a.DNSChallenge.Provider)
err = dnsOverrideDelay(a.DNSChallenge.DelayBeforeCheck) var provider challenge.Provider
if err != nil {
return nil, err
}
acmeprovider.SetRecursiveNameServers(a.DNSChallenge.Resolvers)
acmeprovider.SetPropagationCheck(a.DNSChallenge.DisablePropagationCheck)
var provider acme.ChallengeProvider
provider, err = dns.NewDNSChallengeProviderByName(a.DNSChallenge.Provider) provider, err = dns.NewDNSChallengeProviderByName(a.DNSChallenge.Provider)
if err != nil { if err != nil {
return nil, err return nil, err
} }
client.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSALPN01}) err = client.Challenge.SetDNS01Provider(provider,
err = client.SetChallengeProvider(acme.DNS01, provider) dns01.CondOption(len(a.DNSChallenge.Resolvers) > 0, dns01.AddRecursiveNameservers(a.DNSChallenge.Resolvers)),
dns01.CondOption(a.DNSChallenge.DisablePropagationCheck || a.DNSChallenge.DelayBeforeCheck > 0,
dns01.AddPreCheck(func(_, _ string) (bool, error) {
if a.DNSChallenge.DelayBeforeCheck > 0 {
log.Debugf("Delaying %d rather than validating DNS propagation now.", a.DNSChallenge.DelayBeforeCheck)
time.Sleep(time.Duration(a.DNSChallenge.DelayBeforeCheck))
}
return true, nil
})),
)
return client, err return client, err
} }
@ -465,17 +459,16 @@ func (a *ACME) buildACMEClient(account *Account) (*acme.Client, error) {
if a.HTTPChallenge != nil && len(a.HTTPChallenge.EntryPoint) > 0 { if a.HTTPChallenge != nil && len(a.HTTPChallenge.EntryPoint) > 0 {
log.Debug("Using HTTP Challenge provider.") log.Debug("Using HTTP Challenge provider.")
client.ExcludeChallenges([]acme.Challenge{acme.DNS01, acme.TLSALPN01})
a.challengeHTTPProvider = &challengeHTTPProvider{store: a.store} a.challengeHTTPProvider = &challengeHTTPProvider{store: a.store}
err = client.SetChallengeProvider(acme.HTTP01, a.challengeHTTPProvider) err = client.Challenge.SetHTTP01Provider(a.challengeHTTPProvider)
return client, err return client, err
} }
// TLS Challenge // TLS Challenge
if a.TLSChallenge != nil { if a.TLSChallenge != nil {
log.Debug("Using TLS Challenge provider.") log.Debug("Using TLS Challenge provider.")
client.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.DNS01})
err = client.SetChallengeProvider(acme.TLSALPN01, a.challengeTLSProvider) err = client.Challenge.SetTLSALPN01Provider(a.challengeTLSProvider)
return client, err return client, err
} }
@ -547,7 +540,7 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
a.addResolvingDomains(uncheckedDomains) a.addResolvingDomains(uncheckedDomains)
defer a.removeResolvingDomains(uncheckedDomains) defer a.removeResolvingDomains(uncheckedDomains)
certificate, err := a.getDomainsCertificates(uncheckedDomains) cert, err := a.getDomainsCertificates(uncheckedDomains)
if err != nil { if err != nil {
log.Errorf("Error getting ACME certificates %+v : %v", uncheckedDomains, err) log.Errorf("Error getting ACME certificates %+v : %v", uncheckedDomains, err)
return return
@ -566,7 +559,7 @@ func (a *ACME) LoadCertificateForDomains(domains []string) {
domain = types.Domain{Main: uncheckedDomains[0]} domain = types.Domain{Main: uncheckedDomains[0]}
} }
account = object.(*Account) account = object.(*Account)
_, err = account.DomainsCertificate.addCertificateForDomains(certificate, domain) _, err = account.DomainsCertificate.addCertificateForDomains(cert, domain)
if err != nil { if err != nil {
log.Errorf("Error adding ACME certificates %+v : %v", uncheckedDomains, err) log.Errorf("Error adding ACME certificates %+v : %v", uncheckedDomains, err)
return return
@ -694,7 +687,7 @@ func (a *ACME) getDomainsCertificates(domains []string) (*Certificate, error) {
var cleanDomains []string var cleanDomains []string
for _, domain := range domains { for _, domain := range domains {
canonicalDomain := types.CanonicalDomain(domain) canonicalDomain := types.CanonicalDomain(domain)
cleanDomain := acme.UnFqdn(canonicalDomain) cleanDomain := dns01.UnFqdn(canonicalDomain)
if canonicalDomain != cleanDomain { if canonicalDomain != cleanDomain {
log.Warnf("FQDN detected, please remove the trailing dot: %s", canonicalDomain) log.Warnf("FQDN detected, please remove the trailing dot: %s", canonicalDomain)
} }
@ -704,18 +697,24 @@ func (a *ACME) getDomainsCertificates(domains []string) (*Certificate, error) {
log.Debugf("Loading ACME certificates %s...", cleanDomains) log.Debugf("Loading ACME certificates %s...", cleanDomains)
bundle := true bundle := true
certificate, err := a.client.ObtainCertificate(cleanDomains, bundle, nil, OSCPMustStaple) request := certificate.ObtainRequest{
Domains: cleanDomains,
Bundle: bundle,
MustStaple: OSCPMustStaple,
}
cert, err := a.client.Certificate.Obtain(request)
if err != nil { if err != nil {
return nil, fmt.Errorf("cannot obtain certificates: %+v", err) return nil, fmt.Errorf("cannot obtain certificates: %+v", err)
} }
log.Debugf("Loaded ACME certificates %s", cleanDomains) log.Debugf("Loaded ACME certificates %s", cleanDomains)
return &Certificate{ return &Certificate{
Domain: certificate.Domain, Domain: cert.Domain,
CertURL: certificate.CertURL, CertURL: cert.CertURL,
CertStableURL: certificate.CertStableURL, CertStableURL: cert.CertStableURL,
PrivateKey: certificate.PrivateKey, PrivateKey: cert.PrivateKey,
Certificate: certificate.Certificate, Certificate: cert.Certificate,
}, nil }, nil
} }

View file

@ -15,7 +15,6 @@ import (
"github.com/containous/traefik/tls/generate" "github.com/containous/traefik/tls/generate"
"github.com/containous/traefik/types" "github.com/containous/traefik/types"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/xenolf/lego/acme"
) )
func TestDomainsSet(t *testing.T) { func TestDomainsSet(t *testing.T) {
@ -258,39 +257,10 @@ func TestRemoveDuplicates(t *testing.T) {
} }
} }
func TestNoPreCheckOverride(t *testing.T) {
acme.PreCheckDNS = nil // Irreversable - but not expecting real calls into this during testing process
err := dnsOverrideDelay(0)
if err != nil {
t.Errorf("Error in dnsOverrideDelay :%v", err)
}
if acme.PreCheckDNS != nil {
t.Error("Unexpected change to acme.PreCheckDNS when leaving DNS verification as is.")
}
}
func TestSillyPreCheckOverride(t *testing.T) {
err := dnsOverrideDelay(-5)
if err == nil {
t.Error("Missing expected error in dnsOverrideDelay!")
}
}
func TestPreCheckOverride(t *testing.T) {
acme.PreCheckDNS = nil // Irreversable - but not expecting real calls into this during testing process
err := dnsOverrideDelay(5)
if err != nil {
t.Errorf("Error in dnsOverrideDelay :%v", err)
}
if acme.PreCheckDNS == nil {
t.Error("No change to acme.PreCheckDNS when meant to be adding enforcing override function.")
}
}
func TestAcmeClientCreation(t *testing.T) { func TestAcmeClientCreation(t *testing.T) {
acme.PreCheckDNS = nil // Irreversable - but not expecting real calls into this during testing process
// Lengthy setup to avoid external web requests - oh for easier golang testing! // Lengthy setup to avoid external web requests - oh for easier golang testing!
account := &Account{Email: "f@f"} account := &Account{Email: "f@f"}
account.PrivateKey, _ = base64.StdEncoding.DecodeString(` account.PrivateKey, _ = base64.StdEncoding.DecodeString(`
MIIBPAIBAAJBAMp2Ni92FfEur+CAvFkgC12LT4l9D53ApbBpDaXaJkzzks+KsLw9zyAxvlrfAyTCQ MIIBPAIBAAJBAMp2Ni92FfEur+CAvFkgC12LT4l9D53ApbBpDaXaJkzzks+KsLw9zyAxvlrfAyTCQ
7tDnEnIltAXyQ0uOFUUdcMCAwEAAQJAK1FbipATZcT9cGVa5x7KD7usytftLW14heQUPXYNV80r/3 7tDnEnIltAXyQ0uOFUUdcMCAwEAAQJAK1FbipATZcT9cGVa5x7KD7usytftLW14heQUPXYNV80r/3
@ -298,8 +268,9 @@ lmnpvjL06dffRpwkYeN8DATQF/QOcy3NNNGDw/4QIhAPAKmiZFxA/qmRXsuU8Zhlzf16WrNZ68K64
asn/h3qZrAiEA1+wFR3WXCPIolOvd7AHjfgcTKQNkoMPywU4FYUNQ1AkCIQDv8yk0qPjckD6HVCPJ asn/h3qZrAiEA1+wFR3WXCPIolOvd7AHjfgcTKQNkoMPywU4FYUNQ1AkCIQDv8yk0qPjckD6HVCPJ
llJh9MC0svjevGtNlxJoE3lmEQIhAKXy1wfZ32/XtcrnENPvi6lzxI0T94X7s5pP3aCoPPoJAiEAl llJh9MC0svjevGtNlxJoE3lmEQIhAKXy1wfZ32/XtcrnENPvi6lzxI0T94X7s5pP3aCoPPoJAiEAl
cijFkALeQp/qyeXdFld2v9gUN3eCgljgcl0QweRoIc=---`) cijFkALeQp/qyeXdFld2v9gUN3eCgljgcl0QweRoIc=---`)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{ _, err := w.Write([]byte(`{
"GPHhmRVEDas": "https://community.letsencrypt.org/t/adding-random-entries-to-the-directory/33417", "GPHhmRVEDas": "https://community.letsencrypt.org/t/adding-random-entries-to-the-directory/33417",
"keyChange": "https://foo/acme/key-change", "keyChange": "https://foo/acme/key-change",
"meta": { "meta": {
@ -310,9 +281,20 @@ cijFkALeQp/qyeXdFld2v9gUN3eCgljgcl0QweRoIc=---`)
"newOrder": "https://foo/acme/new-order", "newOrder": "https://foo/acme/new-order",
"revokeCert": "https://foo/acme/revoke-cert" "revokeCert": "https://foo/acme/revoke-cert"
}`)) }`))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
})) }))
defer ts.Close() defer ts.Close()
a := ACME{DNSChallenge: &acmeprovider.DNSChallenge{Provider: "manual", DelayBeforeCheck: 10}, CAServer: ts.URL}
a := ACME{
CAServer: ts.URL,
DNSChallenge: &acmeprovider.DNSChallenge{
Provider: "manual",
DelayBeforeCheck: 10,
DisablePropagationCheck: true,
},
}
client, err := a.buildACMEClient(account) client, err := a.buildACMEClient(account)
if err != nil { if err != nil {
@ -321,9 +303,6 @@ cijFkALeQp/qyeXdFld2v9gUN3eCgljgcl0QweRoIc=---`)
if client == nil { if client == nil {
t.Error("No client from buildACMEClient!") t.Error("No client from buildACMEClient!")
} }
if acme.PreCheckDNS == nil {
t.Error("No change to acme.PreCheckDNS when meant to be adding enforcing override function.")
}
} }
func TestAcme_getUncheckedCertificates(t *testing.T) { func TestAcme_getUncheckedCertificates(t *testing.T) {

View file

@ -9,10 +9,10 @@ import (
"github.com/containous/traefik/cluster" "github.com/containous/traefik/cluster"
"github.com/containous/traefik/log" "github.com/containous/traefik/log"
"github.com/containous/traefik/safe" "github.com/containous/traefik/safe"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/challenge"
) )
var _ acme.ChallengeProviderTimeout = (*challengeHTTPProvider)(nil) var _ challenge.ProviderTimeout = (*challengeHTTPProvider)(nil)
type challengeHTTPProvider struct { type challengeHTTPProvider struct {
store cluster.Store store cluster.Store

View file

@ -11,10 +11,11 @@ import (
"github.com/containous/traefik/cluster" "github.com/containous/traefik/cluster"
"github.com/containous/traefik/log" "github.com/containous/traefik/log"
"github.com/containous/traefik/safe" "github.com/containous/traefik/safe"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/challenge/tlsalpn01"
) )
var _ acme.ChallengeProviderTimeout = (*challengeTLSProvider)(nil) var _ challenge.ProviderTimeout = (*challengeTLSProvider)(nil)
type challengeTLSProvider struct { type challengeTLSProvider struct {
store cluster.Store store cluster.Store
@ -113,7 +114,7 @@ func (c *challengeTLSProvider) Timeout() (timeout, interval time.Duration) {
} }
func tlsALPN01ChallengeCert(domain, keyAuth string) (*ChallengeCert, error) { func tlsALPN01ChallengeCert(domain, keyAuth string) (*ChallengeCert, error) {
tempCertPEM, rsaPrivPEM, err := acme.TLSALPNChallengeBlocks(domain, keyAuth) tempCertPEM, rsaPrivPEM, err := tlsalpn01.ChallengeBlocks(domain, keyAuth)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -31,7 +31,7 @@ import (
"github.com/containous/traefik/tracing/zipkin" "github.com/containous/traefik/tracing/zipkin"
"github.com/containous/traefik/types" "github.com/containous/traefik/types"
"github.com/elazarl/go-bindata-assetfs" "github.com/elazarl/go-bindata-assetfs"
lego "github.com/xenolf/lego/acme" "github.com/xenolf/lego/challenge/dns01"
) )
const ( const (
@ -374,11 +374,11 @@ func convertACMEChallenge(oldACMEChallenge *acme.ACME) *acmeprovider.Configurati
} }
for _, domain := range oldACMEChallenge.Domains { for _, domain := range oldACMEChallenge.Domains {
if domain.Main != lego.UnFqdn(domain.Main) { if domain.Main != dns01.UnFqdn(domain.Main) {
log.Warnf("FQDN detected, please remove the trailing dot: %s", domain.Main) log.Warnf("FQDN detected, please remove the trailing dot: %s", domain.Main)
} }
for _, san := range domain.SANs { for _, san := range domain.SANs {
if san != lego.UnFqdn(san) { if san != dns01.UnFqdn(san) {
log.Warnf("FQDN detected, please remove the trailing dot: %s", san) log.Warnf("FQDN detected, please remove the trailing dot: %s", san)
} }
} }

View file

@ -332,6 +332,14 @@ Here is a list of supported `provider`s, that can automate the DNS verification,
Use custom DNS servers to resolve the FQDN authority. Use custom DNS servers to resolve the FQDN authority.
```toml
[acme]
# ...
[acme.dnsChallenge]
# ...
resolvers = ["1.1.1.1:53", "8.8.8.8:53"]
```
### `domains` ### `domains`
You can provide SANs (alternative domains) to each main domain. You can provide SANs (alternative domains) to each main domain.

View file

@ -34,7 +34,7 @@ import (
acmeprovider "github.com/containous/traefik/provider/acme" acmeprovider "github.com/containous/traefik/provider/acme"
newtypes "github.com/containous/traefik/types" newtypes "github.com/containous/traefik/types"
"github.com/pkg/errors" "github.com/pkg/errors"
lego "github.com/xenolf/lego/acme" "github.com/xenolf/lego/challenge/dns01"
) )
const ( const (
@ -414,11 +414,11 @@ func convertACMEChallenge(oldACMEChallenge *acme.ACME) *acmeprovider.Configurati
} }
for _, domain := range oldACMEChallenge.Domains { for _, domain := range oldACMEChallenge.Domains {
if domain.Main != lego.UnFqdn(domain.Main) { if domain.Main != dns01.UnFqdn(domain.Main) {
log.Warnf("FQDN detected, please remove the trailing dot: %s", domain.Main) log.Warnf("FQDN detected, please remove the trailing dot: %s", domain.Main)
} }
for _, san := range domain.SANs { for _, san := range domain.SANs {
if san != lego.UnFqdn(san) { if san != dns01.UnFqdn(san) {
log.Warnf("FQDN detected, please remove the trailing dot: %s", san) log.Warnf("FQDN detected, please remove the trailing dot: %s", san)
} }
} }

View file

@ -8,15 +8,16 @@ import (
"crypto/x509" "crypto/x509"
"github.com/containous/traefik/log" "github.com/containous/traefik/log"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/certcrypto"
"github.com/xenolf/lego/registration"
) )
// Account is used to store lets encrypt registration info // Account is used to store lets encrypt registration info
type Account struct { type Account struct {
Email string Email string
Registration *acme.RegistrationResource Registration *registration.Resource
PrivateKey []byte PrivateKey []byte
KeyType acme.KeyType KeyType certcrypto.KeyType
} }
const ( const (
@ -47,7 +48,7 @@ func (a *Account) GetEmail() string {
} }
// GetRegistration returns lets encrypt registration resource // GetRegistration returns lets encrypt registration resource
func (a *Account) GetRegistration() *acme.RegistrationResource { func (a *Account) GetRegistration() *registration.Resource {
return a.Registration return a.Registration
} }
@ -64,25 +65,25 @@ func (a *Account) GetPrivateKey() crypto.PrivateKey {
} }
// GetKeyType used to determine which algo to used // GetKeyType used to determine which algo to used
func GetKeyType(ctx context.Context, value string) acme.KeyType { func GetKeyType(ctx context.Context, value string) certcrypto.KeyType {
logger := log.FromContext(ctx) logger := log.FromContext(ctx)
switch value { switch value {
case "EC256": case "EC256":
return acme.EC256 return certcrypto.EC256
case "EC384": case "EC384":
return acme.EC384 return certcrypto.EC384
case "RSA2048": case "RSA2048":
return acme.RSA2048 return certcrypto.RSA2048
case "RSA4096": case "RSA4096":
return acme.RSA4096 return certcrypto.RSA4096
case "RSA8192": case "RSA8192":
return acme.RSA8192 return certcrypto.RSA8192
case "": case "":
logger.Infof("The key type is empty. Use default key type %v.", acme.RSA4096) logger.Infof("The key type is empty. Use default key type %v.", certcrypto.RSA4096)
return acme.RSA4096 return certcrypto.RSA4096
default: default:
logger.Infof("Unable to determine the key type value %q: falling back on %v.", value, acme.RSA4096) logger.Infof("Unable to determine the key type value %q: falling back on %v.", value, certcrypto.RSA4096)
return acme.RSA4096 return certcrypto.RSA4096
} }
} }

View file

@ -10,10 +10,11 @@ import (
"github.com/containous/mux" "github.com/containous/mux"
"github.com/containous/traefik/log" "github.com/containous/traefik/log"
"github.com/containous/traefik/safe" "github.com/containous/traefik/safe"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/challenge/http01"
) )
var _ acme.ChallengeProviderTimeout = (*challengeHTTP)(nil) var _ challenge.ProviderTimeout = (*challengeHTTP)(nil)
type challengeHTTP struct { type challengeHTTP struct {
Store Store Store Store
@ -37,7 +38,7 @@ func (c *challengeHTTP) Timeout() (timeout, interval time.Duration) {
// Append adds routes on internal router // Append adds routes on internal router
func (p *Provider) Append(router *mux.Router) { func (p *Provider) Append(router *mux.Router) {
router.Methods(http.MethodGet). router.Methods(http.MethodGet).
Path(acme.HTTP01ChallengePath("{token}")). Path(http01.ChallengePath("{token}")).
Handler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { Handler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req) vars := mux.Vars(req)

View file

@ -5,10 +5,11 @@ import (
"github.com/containous/traefik/log" "github.com/containous/traefik/log"
"github.com/containous/traefik/types" "github.com/containous/traefik/types"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/challenge/tlsalpn01"
) )
var _ acme.ChallengeProvider = (*challengeTLSALPN)(nil) var _ challenge.Provider = (*challengeTLSALPN)(nil)
type challengeTLSALPN struct { type challengeTLSALPN struct {
Store Store Store Store
@ -18,7 +19,7 @@ func (c *challengeTLSALPN) Present(domain, token, keyAuth string) error {
log.WithoutContext().WithField(log.ProviderName, "acme"). log.WithoutContext().WithField(log.ProviderName, "acme").
Debugf("TLS Challenge Present temp certificate for %s", domain) Debugf("TLS Challenge Present temp certificate for %s", domain)
certPEMBlock, keyPEMBlock, err := acme.TLSALPNChallengeBlocks(domain, keyAuth) certPEMBlock, keyPEMBlock, err := tlsalpn01.ChallengeBlocks(domain, keyAuth)
if err != nil { if err != nil {
return err return err
} }

View file

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
fmtlog "log" fmtlog "log"
"net"
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
@ -25,9 +24,13 @@ import (
"github.com/containous/traefik/version" "github.com/containous/traefik/version"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/certificate"
"github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/challenge/dns01"
"github.com/xenolf/lego/lego"
legolog "github.com/xenolf/lego/log" legolog "github.com/xenolf/lego/log"
"github.com/xenolf/lego/providers/dns" "github.com/xenolf/lego/providers/dns"
"github.com/xenolf/lego/registration"
) )
var ( var (
@ -83,7 +86,7 @@ type Provider struct {
Store Store Store Store
certificates []*Certificate certificates []*Certificate
account *Account account *Account
client *acme.Client client *lego.Client
certsChan chan *Certificate certsChan chan *Certificate
configurationChan chan<- config.Message configurationChan chan<- config.Message
certificateStore *traefiktls.CertificateStore certificateStore *traefiktls.CertificateStore
@ -118,9 +121,9 @@ func (p *Provider) ListenRequest(domain string) (*tls.Certificate, error) {
return nil, err return nil, err
} }
certificate, err := tls.X509KeyPair(acmeCert.Certificate, acmeCert.PrivateKey) cert, err := tls.X509KeyPair(acmeCert.Certificate, acmeCert.PrivateKey)
return &certificate, err return &cert, err
} }
// Init for compatibility reason the BaseProvider implements an empty Init // Init for compatibility reason the BaseProvider implements an empty Init
@ -128,8 +131,6 @@ func (p *Provider) Init() error {
ctx := log.With(context.Background(), log.Str(log.ProviderName, "acme")) ctx := log.With(context.Background(), log.Str(log.ProviderName, "acme"))
logger := log.FromContext(ctx) logger := log.FromContext(ctx)
acme.UserAgent = fmt.Sprintf("containous-traefik/%s", version.Version)
if p.ACMELogging { if p.ACMELogging {
legolog.Logger = fmtlog.New(logger.WriterLevel(logrus.InfoLevel), "legolog: ", 0) legolog.Logger = fmtlog.New(logger.WriterLevel(logrus.InfoLevel), "legolog: ", 0)
} else { } else {
@ -223,7 +224,7 @@ func (p *Provider) Provide(configurationChan chan<- config.Message, pool *safe.P
return nil return nil
} }
func (p *Provider) getClient() (*acme.Client, error) { func (p *Provider) getClient() (*lego.Client, error) {
p.clientMutex.Lock() p.clientMutex.Lock()
defer p.clientMutex.Unlock() defer p.clientMutex.Unlock()
@ -247,7 +248,12 @@ func (p *Provider) getClient() (*acme.Client, error) {
} }
logger.Debug(caServer) logger.Debug(caServer)
client, err := acme.NewClient(caServer, account, account.KeyType) config := lego.NewConfig(account)
config.CADirURL = caServer
config.KeyType = account.KeyType
config.UserAgent = fmt.Sprintf("containous-traefik/%s", version.Version)
client, err := lego.NewClient(config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -256,7 +262,7 @@ func (p *Provider) getClient() (*acme.Client, error) {
if account.GetRegistration() == nil { if account.GetRegistration() == nil {
logger.Info("Register...") logger.Info("Register...")
reg, errR := client.Register(true) reg, errR := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
if errR != nil { if errR != nil {
return nil, errR return nil, errR
} }
@ -274,23 +280,23 @@ func (p *Provider) getClient() (*acme.Client, error) {
if p.DNSChallenge != nil && len(p.DNSChallenge.Provider) > 0 { if p.DNSChallenge != nil && len(p.DNSChallenge.Provider) > 0 {
logger.Debugf("Using DNS Challenge provider: %s", p.DNSChallenge.Provider) logger.Debugf("Using DNS Challenge provider: %s", p.DNSChallenge.Provider)
SetRecursiveNameServers(p.DNSChallenge.Resolvers) var provider challenge.Provider
SetPropagationCheck(p.DNSChallenge.DisablePropagationCheck)
err = dnsOverrideDelay(ctx, p.DNSChallenge.DelayBeforeCheck)
if err != nil {
return nil, err
}
var provider acme.ChallengeProvider
provider, err = dns.NewDNSChallengeProviderByName(p.DNSChallenge.Provider) provider, err = dns.NewDNSChallengeProviderByName(p.DNSChallenge.Provider)
if err != nil { if err != nil {
return nil, err return nil, err
} }
client.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSALPN01}) err = client.Challenge.SetDNS01Provider(provider,
dns01.CondOption(len(p.DNSChallenge.Resolvers) > 0, dns01.AddRecursiveNameservers(p.DNSChallenge.Resolvers)),
err = client.SetChallengeProvider(acme.DNS01, provider) dns01.CondOption(p.DNSChallenge.DisablePropagationCheck || p.DNSChallenge.DelayBeforeCheck > 0,
dns01.AddPreCheck(func(_, _ string) (bool, error) {
if p.DNSChallenge.DelayBeforeCheck > 0 {
log.Debugf("Delaying %d rather than validating DNS propagation now.", p.DNSChallenge.DelayBeforeCheck)
time.Sleep(time.Duration(p.DNSChallenge.DelayBeforeCheck))
}
return true, nil
})),
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -300,25 +306,21 @@ func (p *Provider) getClient() (*acme.Client, error) {
p.DNSChallenge.preCheckInterval = 2 * time.Second p.DNSChallenge.preCheckInterval = 2 * time.Second
// Set the precheck timeout into the DNSChallenge provider // Set the precheck timeout into the DNSChallenge provider
if challengeProviderTimeout, ok := provider.(acme.ChallengeProviderTimeout); ok { if challengeProviderTimeout, ok := provider.(challenge.ProviderTimeout); ok {
p.DNSChallenge.preCheckTimeout, p.DNSChallenge.preCheckInterval = challengeProviderTimeout.Timeout() p.DNSChallenge.preCheckTimeout, p.DNSChallenge.preCheckInterval = challengeProviderTimeout.Timeout()
} }
} else if p.HTTPChallenge != nil && len(p.HTTPChallenge.EntryPoint) > 0 { } else if p.HTTPChallenge != nil && len(p.HTTPChallenge.EntryPoint) > 0 {
logger.Debug("Using HTTP Challenge provider.") logger.Debug("Using HTTP Challenge provider.")
client.ExcludeChallenges([]acme.Challenge{acme.DNS01, acme.TLSALPN01}) err = client.Challenge.SetHTTP01Provider(&challengeHTTP{Store: p.Store})
err = client.SetChallengeProvider(acme.HTTP01, &challengeHTTP{Store: p.Store})
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else if p.TLSChallenge != nil { } else if p.TLSChallenge != nil {
logger.Debug("Using TLS Challenge provider.") logger.Debug("Using TLS Challenge provider.")
client.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.DNS01}) err = client.Challenge.SetTLSALPN01Provider(&challengeTLSALPN{Store: p.Store})
err = client.SetChallengeProvider(acme.TLSALPN01, &challengeTLSALPN{Store: p.Store})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -383,7 +385,6 @@ func (p *Provider) watchNewDomains(ctx context.Context) {
} }
}) })
} }
} }
case <-stop: case <-stop:
return return
@ -392,7 +393,7 @@ func (p *Provider) watchNewDomains(ctx context.Context) {
}) })
} }
func (p *Provider) resolveCertificate(ctx context.Context, domain types.Domain, domainFromConfigurationFile bool) (*acme.CertificateResource, error) { func (p *Provider) resolveCertificate(ctx context.Context, domain types.Domain, domainFromConfigurationFile bool) (*certificate.Resource, error) {
domains, err := p.getValidDomains(ctx, domain, domainFromConfigurationFile) domains, err := p.getValidDomains(ctx, domain, domainFromConfigurationFile)
if err != nil { if err != nil {
return nil, err return nil, err
@ -415,22 +416,27 @@ func (p *Provider) resolveCertificate(ctx context.Context, domain types.Domain,
return nil, fmt.Errorf("cannot get ACME client %v", err) return nil, fmt.Errorf("cannot get ACME client %v", err)
} }
var certificate *acme.CertificateResource var cert *certificate.Resource
bundle := true bundle := true
if p.useCertificateWithRetry(uncheckedDomains) { if p.useCertificateWithRetry(uncheckedDomains) {
certificate, err = obtainCertificateWithRetry(ctx, domains, client, p.DNSChallenge.preCheckTimeout, p.DNSChallenge.preCheckInterval, bundle) cert, err = obtainCertificateWithRetry(ctx, domains, client, p.DNSChallenge.preCheckTimeout, p.DNSChallenge.preCheckInterval, bundle)
} else { } else {
certificate, err = client.ObtainCertificate(domains, bundle, nil, oscpMustStaple) request := certificate.ObtainRequest{
Domains: domains,
Bundle: bundle,
MustStaple: oscpMustStaple,
}
cert, err = client.Certificate.Obtain(request)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to generate a certificate for the domains %v: %v", uncheckedDomains, err) return nil, fmt.Errorf("unable to generate a certificate for the domains %v: %v", uncheckedDomains, err)
} }
if certificate == nil { if cert == nil {
return nil, fmt.Errorf("domains %v do not generate a certificate", uncheckedDomains) return nil, fmt.Errorf("domains %v do not generate a certificate", uncheckedDomains)
} }
if len(certificate.Certificate) == 0 || len(certificate.PrivateKey) == 0 { if len(cert.Certificate) == 0 || len(cert.PrivateKey) == 0 {
return nil, fmt.Errorf("domains %v generate certificate with no value: %v", uncheckedDomains, certificate) return nil, fmt.Errorf("domains %v generate certificate with no value: %v", uncheckedDomains, cert)
} }
logger.Debugf("Certificates obtained for domains %+v", uncheckedDomains) logger.Debugf("Certificates obtained for domains %+v", uncheckedDomains)
@ -440,9 +446,9 @@ func (p *Provider) resolveCertificate(ctx context.Context, domain types.Domain,
} else { } else {
domain = types.Domain{Main: uncheckedDomains[0]} domain = types.Domain{Main: uncheckedDomains[0]}
} }
p.addCertificateForDomain(domain, certificate.Certificate, certificate.PrivateKey) p.addCertificateForDomain(domain, cert.Certificate, cert.PrivateKey)
return certificate, nil return cert, nil
} }
func (p *Provider) removeResolvingDomains(resolvingDomains []string) { func (p *Provider) removeResolvingDomains(resolvingDomains []string) {
@ -489,14 +495,19 @@ func (p *Provider) useCertificateWithRetry(domains []string) bool {
return false return false
} }
func obtainCertificateWithRetry(ctx context.Context, domains []string, client *acme.Client, timeout, interval time.Duration, bundle bool) (*acme.CertificateResource, error) { func obtainCertificateWithRetry(ctx context.Context, domains []string, client *lego.Client, timeout, interval time.Duration, bundle bool) (*certificate.Resource, error) {
logger := log.FromContext(ctx) logger := log.FromContext(ctx)
var certificate *acme.CertificateResource var cert *certificate.Resource
var err error var err error
operation := func() error { operation := func() error {
certificate, err = client.ObtainCertificate(domains, bundle, nil, oscpMustStaple) request := certificate.ObtainRequest{
Domains: domains,
Bundle: bundle,
MustStaple: oscpMustStaple,
}
cert, err = client.Certificate.Obtain(request)
return err return err
} }
@ -516,25 +527,7 @@ func obtainCertificateWithRetry(ctx context.Context, domains []string, client *a
return nil, err return nil, err
} }
return certificate, nil return cert, nil
}
func dnsOverrideDelay(ctx context.Context, delay parse.Duration) error {
if delay == 0 {
return nil
}
if delay > 0 {
log.FromContext(ctx).Debugf("Delaying %d rather than validating DNS propagation now.", delay)
acme.PreCheckDNS = func(_, _ string) (bool, error) {
time.Sleep(time.Duration(delay))
return true, nil
}
} else {
return fmt.Errorf("delayBeforeCheck: %d cannot be less than 0", delay)
}
return nil
} }
func (p *Provider) addCertificateForDomain(domain types.Domain, certificate []byte, key []byte) { func (p *Provider) addCertificateForDomain(domain types.Domain, certificate []byte, key []byte) {
@ -649,8 +642,8 @@ func (p *Provider) refreshCertificates() {
} }
for _, cert := range p.certificates { for _, cert := range p.certificates {
certificate := &traefiktls.Certificate{CertFile: traefiktls.FileOrContent(cert.Certificate), KeyFile: traefiktls.FileOrContent(cert.Key)} cert := &traefiktls.Certificate{CertFile: traefiktls.FileOrContent(cert.Certificate), KeyFile: traefiktls.FileOrContent(cert.Key)}
conf.Configuration.TLS = append(conf.Configuration.TLS, &traefiktls.Configuration{Certificate: certificate, EntryPoints: []string{p.EntryPoint}}) conf.Configuration.TLS = append(conf.Configuration.TLS, &traefiktls.Configuration{Certificate: cert, EntryPoints: []string{p.EntryPoint}})
} }
p.configurationChan <- conf p.configurationChan <- conf
} }
@ -659,36 +652,36 @@ func (p *Provider) renewCertificates(ctx context.Context) {
logger := log.FromContext(ctx) logger := log.FromContext(ctx)
logger.Info("Testing certificate renew...") logger.Info("Testing certificate renew...")
for _, certificate := range p.certificates { for _, cert := range p.certificates {
crt, err := getX509Certificate(ctx, certificate) crt, err := getX509Certificate(ctx, cert)
// If there's an error, we assume the cert is broken, and needs update // If there's an error, we assume the cert is broken, and needs update
// <= 30 days left, renew certificate // <= 30 days left, renew certificate
if err != nil || crt == nil || crt.NotAfter.Before(time.Now().Add(24*30*time.Hour)) { if err != nil || crt == nil || crt.NotAfter.Before(time.Now().Add(24*30*time.Hour)) {
client, err := p.getClient() client, err := p.getClient()
if err != nil { if err != nil {
logger.Infof("Error renewing certificate from LE : %+v, %v", certificate.Domain, err) logger.Infof("Error renewing certificate from LE : %+v, %v", cert.Domain, err)
continue continue
} }
logger.Infof("Renewing certificate from LE : %+v", certificate.Domain) logger.Infof("Renewing certificate from LE : %+v", cert.Domain)
renewedCert, err := client.RenewCertificate(acme.CertificateResource{ renewedCert, err := client.Certificate.Renew(certificate.Resource{
Domain: certificate.Domain.Main, Domain: cert.Domain.Main,
PrivateKey: certificate.Key, PrivateKey: cert.Key,
Certificate: certificate.Certificate, Certificate: cert.Certificate,
}, true, oscpMustStaple) }, true, oscpMustStaple)
if err != nil { if err != nil {
logger.Errorf("Error renewing certificate from LE: %v, %v", certificate.Domain, err) logger.Errorf("Error renewing certificate from LE: %v, %v", cert.Domain, err)
continue continue
} }
if len(renewedCert.Certificate) == 0 || len(renewedCert.PrivateKey) == 0 { if len(renewedCert.Certificate) == 0 || len(renewedCert.PrivateKey) == 0 {
logger.Errorf("domains %v renew certificate with no value: %v", certificate.Domain.ToStrArray(), certificate) logger.Errorf("domains %v renew certificate with no value: %v", cert.Domain.ToStrArray(), cert)
continue continue
} }
p.addCertificateForDomain(certificate.Domain, renewedCert.Certificate, renewedCert.PrivateKey) p.addCertificateForDomain(cert.Domain, renewedCert.Certificate, renewedCert.PrivateKey)
} }
} }
} }
@ -704,8 +697,8 @@ func (p *Provider) getUncheckedDomains(ctx context.Context, domainsToCheck []str
allDomains := p.certificateStore.GetAllDomains() allDomains := p.certificateStore.GetAllDomains()
// Get ACME certificates // Get ACME certificates
for _, certificate := range p.certificates { for _, cert := range p.certificates {
allDomains = append(allDomains, strings.Join(certificate.Domain.ToStrArray(), ",")) allDomains = append(allDomains, strings.Join(cert.Domain.ToStrArray(), ","))
} }
// Get currently resolved domains // Get currently resolved domains
@ -740,12 +733,12 @@ func searchUncheckedDomains(ctx context.Context, domainsToCheck []string, existe
return uncheckedDomains return uncheckedDomains
} }
func getX509Certificate(ctx context.Context, certificate *Certificate) (*x509.Certificate, error) { func getX509Certificate(ctx context.Context, cert *Certificate) (*x509.Certificate, error) {
logger := log.FromContext(ctx) logger := log.FromContext(ctx)
tlsCert, err := tls.X509KeyPair(certificate.Certificate, certificate.Key) tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.Key)
if err != nil { if err != nil {
logger.Errorf("Failed to load TLS key pair from ACME certificate for domain %q (SAN : %q), certificate will be renewed : %v", certificate.Domain.Main, strings.Join(certificate.Domain.SANs, ","), err) logger.Errorf("Failed to load TLS key pair from ACME certificate for domain %q (SAN : %q), certificate will be renewed : %v", cert.Domain.Main, strings.Join(cert.Domain.SANs, ","), err)
return nil, err return nil, err
} }
@ -753,7 +746,7 @@ func getX509Certificate(ctx context.Context, certificate *Certificate) (*x509.Ce
if crt == nil { if crt == nil {
crt, err = x509.ParseCertificate(tlsCert.Certificate[0]) crt, err = x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil { if err != nil {
logger.Errorf("Failed to parse TLS key pair from ACME certificate for domain %q (SAN : %q), certificate will be renewed : %v", certificate.Domain.Main, strings.Join(certificate.Domain.SANs, ","), err) logger.Errorf("Failed to parse TLS key pair from ACME certificate for domain %q (SAN : %q), certificate will be renewed : %v", cert.Domain.Main, strings.Join(cert.Domain.SANs, ","), err)
} }
} }
@ -790,7 +783,7 @@ func (p *Provider) getValidDomains(ctx context.Context, domain types.Domain, wil
var cleanDomains []string var cleanDomains []string
for _, domain := range domains { for _, domain := range domains {
canonicalDomain := types.CanonicalDomain(domain) canonicalDomain := types.CanonicalDomain(domain)
cleanDomain := acme.UnFqdn(canonicalDomain) cleanDomain := dns01.UnFqdn(canonicalDomain)
if canonicalDomain != cleanDomain { if canonicalDomain != cleanDomain {
log.FromContext(ctx).Warnf("FQDN detected, please remove the trailing dot: %s", canonicalDomain) log.FromContext(ctx).Warnf("FQDN detected, please remove the trailing dot: %s", canonicalDomain)
} }
@ -810,37 +803,3 @@ func isDomainAlreadyChecked(domainToCheck string, existentDomains []string) bool
} }
return false return false
} }
// SetPropagationCheck to disable the Lego PreCheck.
func SetPropagationCheck(disable bool) {
if disable {
acme.PreCheckDNS = func(_, _ string) (bool, error) {
return true, nil
}
}
}
// SetRecursiveNameServers to provide a custom DNS resolver.
func SetRecursiveNameServers(dnsResolvers []string) {
resolvers := normaliseDNSResolvers(dnsResolvers)
if len(resolvers) > 0 {
acme.RecursiveNameservers = resolvers
log.Infof("Validating FQDN authority with DNS using %+v", resolvers)
}
}
// ensure all servers have a port number
func normaliseDNSResolvers(dnsResolvers []string) []string {
var normalisedResolvers []string
for _, server := range dnsResolvers {
srv := strings.TrimSpace(server)
if len(srv) > 0 {
if host, port, err := net.SplitHostPort(srv); err != nil {
normalisedResolvers = append(normalisedResolvers, net.JoinHostPort(srv, "53"))
} else {
normalisedResolvers = append(normalisedResolvers, net.JoinHostPort(host, port))
}
}
}
return normalisedResolvers
}

View file

@ -9,7 +9,7 @@ import (
traefiktls "github.com/containous/traefik/tls" traefiktls "github.com/containous/traefik/tls"
"github.com/containous/traefik/types" "github.com/containous/traefik/types"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/certcrypto"
) )
func TestGetUncheckedCertificates(t *testing.T) { func TestGetUncheckedCertificates(t *testing.T) {
@ -592,11 +592,11 @@ func TestInitAccount(t *testing.T) {
desc: "Existing account with all information", desc: "Existing account with all information",
account: &Account{ account: &Account{
Email: "foo@foo.net", Email: "foo@foo.net",
KeyType: acme.EC256, KeyType: certcrypto.EC256,
}, },
expectedAccount: &Account{ expectedAccount: &Account{
Email: "foo@foo.net", Email: "foo@foo.net",
KeyType: acme.EC256, KeyType: certcrypto.EC256,
}, },
}, },
{ {
@ -605,19 +605,19 @@ func TestInitAccount(t *testing.T) {
keyType: "EC256", keyType: "EC256",
expectedAccount: &Account{ expectedAccount: &Account{
Email: "foo@foo.net", Email: "foo@foo.net",
KeyType: acme.EC256, KeyType: certcrypto.EC256,
}, },
}, },
{ {
desc: "Existing account with no email", desc: "Existing account with no email",
account: &Account{ account: &Account{
KeyType: acme.RSA4096, KeyType: certcrypto.RSA4096,
}, },
email: "foo@foo.net", email: "foo@foo.net",
keyType: "EC256", keyType: "EC256",
expectedAccount: &Account{ expectedAccount: &Account{
Email: "foo@foo.net", Email: "foo@foo.net",
KeyType: acme.EC256, KeyType: certcrypto.EC256,
}, },
}, },
{ {
@ -629,7 +629,7 @@ func TestInitAccount(t *testing.T) {
keyType: "EC256", keyType: "EC256",
expectedAccount: &Account{ expectedAccount: &Account{
Email: "foo@foo.net", Email: "foo@foo.net",
KeyType: acme.EC256, KeyType: certcrypto.EC256,
}, },
}, },
{ {
@ -640,7 +640,7 @@ func TestInitAccount(t *testing.T) {
email: "bar@foo.net", email: "bar@foo.net",
expectedAccount: &Account{ expectedAccount: &Account{
Email: "foo@foo.net", Email: "foo@foo.net",
KeyType: acme.RSA4096, KeyType: certcrypto.RSA4096,
}, },
}, },
} }

View file

@ -22,7 +22,7 @@ import (
"github.com/containous/traefik/tls/generate" "github.com/containous/traefik/tls/generate"
"github.com/containous/traefik/types" "github.com/containous/traefik/types"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/xenolf/lego/acme" "github.com/xenolf/lego/challenge/tlsalpn01"
) )
// EntryPoints map of EntryPoint // EntryPoints map of EntryPoint
@ -380,7 +380,7 @@ func buildTLSConfig(tlsOption traefiktls.TLS) (*tls.Config, error) {
conf := &tls.Config{} conf := &tls.Config{}
// ensure http2 enabled // ensure http2 enabled
conf.NextProtos = []string{"h2", "http/1.1", acme.ACMETLS1Protocol} conf.NextProtos = []string{"h2", "http/1.1", tlsalpn01.ACMETLS1Protocol}
if len(tlsOption.ClientCA.Files) > 0 { if len(tlsOption.ClientCA.Files) > 0 {
pool := x509.NewCertPool() pool := x509.NewCertPool()

View file

@ -32,6 +32,24 @@ const (
UTF32LittleEndian UTF32LittleEndian
) )
// String returns a user-friendly string representation of the encoding. Satisfies fmt.Stringer interface.
func (e Encoding) String() string {
switch e {
case UTF8:
return "UTF8"
case UTF16BigEndian:
return "UTF16BigEndian"
case UTF16LittleEndian:
return "UTF16LittleEndian"
case UTF32BigEndian:
return "UTF32BigEndian"
case UTF32LittleEndian:
return "UTF32LittleEndian"
default:
return "Unknown"
}
}
const maxConsecutiveEmptyReads = 100 const maxConsecutiveEmptyReads = 100
// Skip creates Reader which automatically detects BOM (Unicode Byte Order Mark) and removes it as necessary. // Skip creates Reader which automatically detects BOM (Unicode Byte Order Mark) and removes it as necessary.

View file

@ -1,68 +1,52 @@
package dnsimple package dnsimple
import ( import (
"encoding/base64" "net/http"
) )
const ( // BasicAuthTransport is an http.RoundTripper that authenticates all requests
httpHeaderDomainToken = "X-DNSimple-Domain-Token" // using HTTP Basic Authentication with the provided username and password.
httpHeaderApiToken = "X-DNSimple-Token" type BasicAuthTransport struct {
httpHeaderAuthorization = "Authorization" Username string
) Password string
// Provides credentials that can be used for authenticating with DNSimple. // Transport is the transport RoundTripper used to make HTTP requests.
// // If nil, http.DefaultTransport is used.
// See https://developer.dnsimple.com/v2/#authentication Transport http.RoundTripper
type Credentials interface {
// Returns the HTTP headers that should be set
// to authenticate the HTTP Request.
Headers() map[string]string
} }
// Domain token authentication // RoundTrip implements the RoundTripper interface. We just add the
type domainTokenCredentials struct { // basic auth and return the RoundTripper for this transport type.
domainToken string func (t *BasicAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req2 := cloneRequest(req) // per RoundTripper contract
req2.SetBasicAuth(t.Username, t.Password)
return t.transport().RoundTrip(req2)
} }
// NewDomainTokenCredentials construct Credentials using the DNSimple Domain Token method. // Client returns an *http.Client that uses the BasicAuthTransport transport
func NewDomainTokenCredentials(domainToken string) Credentials { // to authenticate the request via HTTP Basic Auth.
return &domainTokenCredentials{domainToken: domainToken} func (t *BasicAuthTransport) Client() *http.Client {
return &http.Client{Transport: t}
} }
func (c *domainTokenCredentials) Headers() map[string]string { func (t *BasicAuthTransport) transport() http.RoundTripper {
return map[string]string{httpHeaderDomainToken: c.domainToken} if t.Transport != nil {
return t.Transport
}
return http.DefaultTransport
} }
// HTTP basic authentication // cloneRequest returns a clone of the provided *http.Request.
type httpBasicCredentials struct { // The clone is a shallow copy of the struct and its Header map.
email string func cloneRequest(r *http.Request) *http.Request {
password string // shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
} }
return r2
// NewHTTPBasicCredentials construct Credentials using HTTP Basic Auth.
func NewHTTPBasicCredentials(email, password string) Credentials {
return &httpBasicCredentials{email, password}
}
func (c *httpBasicCredentials) Headers() map[string]string {
return map[string]string{httpHeaderAuthorization: "Basic " + c.basicAuth(c.email, c.password)}
}
func (c *httpBasicCredentials) basicAuth(username, password string) string {
auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth))
}
// OAuth token authentication
type oauthTokenCredentials struct {
oauthToken string
}
// NewOauthTokenCredentials construct Credentials using the OAuth access token.
func NewOauthTokenCredentials(oauthToken string) Credentials {
return &oauthTokenCredentials{oauthToken: oauthToken}
}
func (c *oauthTokenCredentials) Headers() map[string]string {
return map[string]string{httpHeaderAuthorization: "Bearer " + c.oauthToken}
} }

View file

@ -23,7 +23,7 @@ const (
// This is a pro-forma convention given that Go dependencies // This is a pro-forma convention given that Go dependencies
// tends to be fetched directly from the repo. // tends to be fetched directly from the repo.
// It is also used in the user-agent identify the client. // It is also used in the user-agent identify the client.
Version = "0.16.0" Version = "0.21.0"
// defaultBaseURL to the DNSimple production API. // defaultBaseURL to the DNSimple production API.
defaultBaseURL = "https://api.dnsimple.com" defaultBaseURL = "https://api.dnsimple.com"
@ -37,12 +37,9 @@ const (
// Client represents a client to the DNSimple API. // Client represents a client to the DNSimple API.
type Client struct { type Client struct {
// HttpClient is the underlying HTTP client // httpClient is the underlying HTTP client
// used to communicate with the API. // used to communicate with the API.
HttpClient *http.Client httpClient *http.Client
// Credentials used for accessing the DNSimple API
Credentials Credentials
// BaseURL for API requests. // BaseURL for API requests.
// Defaults to the public DNSimple API, but can be set to a different endpoint (e.g. the sandbox). // Defaults to the public DNSimple API, but can be set to a different endpoint (e.g. the sandbox).
@ -85,9 +82,12 @@ type ListOptions struct {
Sort string `url:"sort,omitempty"` Sort string `url:"sort,omitempty"`
} }
// NewClient returns a new DNSimple API client using the given credentials. // NewClient returns a new DNSimple API client.
func NewClient(credentials Credentials) *Client { //
c := &Client{Credentials: credentials, HttpClient: &http.Client{}, BaseURL: defaultBaseURL} // To authenticate you must provide an http.Client that will perform authentication
// for you with one of the currently supported mechanisms: OAuth or HTTP Basic.
func NewClient(httpClient *http.Client) *Client {
c := &Client{httpClient: httpClient, BaseURL: defaultBaseURL}
c.Identity = &IdentityService{client: c} c.Identity = &IdentityService{client: c}
c.Accounts = &AccountsService{client: c} c.Accounts = &AccountsService{client: c}
c.Certificates = &CertificatesService{client: c} c.Certificates = &CertificatesService{client: c}
@ -126,9 +126,6 @@ func (c *Client) NewRequest(method, path string, payload interface{}) (*http.Req
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Add("Accept", "application/json") req.Header.Add("Accept", "application/json")
req.Header.Add("User-Agent", formatUserAgent(c.UserAgent)) req.Header.Add("User-Agent", formatUserAgent(c.UserAgent))
for key, value := range c.Credentials.Headers() {
req.Header.Add(key, value)
}
return req, nil return req, nil
} }
@ -212,7 +209,7 @@ func (c *Client) Do(req *http.Request, obj interface{}) (*http.Response, error)
log.Printf("Executing request (%v): %#v", req.URL, req) log.Printf("Executing request (%v): %#v", req.URL, req)
} }
resp, err := c.HttpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -231,7 +228,7 @@ func (c *Client) Do(req *http.Request, obj interface{}) (*http.Response, error)
// the response body is decoded into v. // the response body is decoded into v.
if obj != nil { if obj != nil {
if w, ok := obj.(io.Writer); ok { if w, ok := obj.(io.Writer); ok {
io.Copy(w, resp.Body) _, err = io.Copy(w, resp.Body)
} else { } else {
err = json.NewDecoder(resp.Body).Decode(obj) err = json.NewDecoder(resp.Body).Decode(obj)
} }

View file

@ -72,7 +72,7 @@ func (s *OauthService) ExchangeAuthorizationForToken(authorization *ExchangeAuth
return nil, err return nil, err
} }
resp, err := s.client.HttpClient.Do(req) resp, err := s.client.httpClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -0,0 +1,46 @@
package dnsimple
import "fmt"
// ZoneDistribution is the result of the zone distribution check.
type ZoneDistribution struct {
Distributed bool `json:"distributed"`
}
// zoneDistributionResponse represents a response from an API method that returns a ZoneDistribution struct.
type zoneDistributionResponse struct {
Response
Data *ZoneDistribution `json:"data"`
}
// CheckZoneDistribution checks if a zone is fully distributed across DNSimple nodes.
//
// See https://developer.dnsimple.com/v2/zones/#checkZoneDistribution
func (s *ZonesService) CheckZoneDistribution(accountID string, zoneName string) (*zoneDistributionResponse, error) {
path := versioned(fmt.Sprintf("/%v/zones/%v/distribution", accountID, zoneName))
zoneDistributionResponse := &zoneDistributionResponse{}
resp, err := s.client.get(path, zoneDistributionResponse)
if err != nil {
return nil, err
}
zoneDistributionResponse.HttpResponse = resp
return zoneDistributionResponse, nil
}
// CheckZoneRecordDistribution checks if a zone is fully distributed across DNSimple nodes.
//
// See https://developer.dnsimple.com/v2/zones/#checkZoneRecordDistribution
func (s *ZonesService) CheckZoneRecordDistribution(accountID string, zoneName string, recordID int64) (*zoneDistributionResponse, error) {
path := versioned(fmt.Sprintf("/%v/zones/%v/records/%v/distribution", accountID, zoneName, recordID))
zoneDistributionResponse := &zoneDistributionResponse{}
resp, err := s.client.get(path, zoneDistributionResponse)
if err != nil {
return nil, err
}
zoneDistributionResponse.HttpResponse = resp
return zoneDistributionResponse, nil
}

54
vendor/github.com/miekg/dns/acceptfunc.go generated vendored Normal file
View file

@ -0,0 +1,54 @@
package dns
// MsgAcceptFunc is used early in the server code to accept or reject a message with RcodeFormatError.
// It returns a MsgAcceptAction to indicate what should happen with the message.
type MsgAcceptFunc func(dh Header) MsgAcceptAction
// DefaultMsgAcceptFunc checks the request and will reject if:
//
// * isn't a request (don't respond in that case).
// * opcode isn't OpcodeQuery or OpcodeNotify
// * Zero bit isn't zero
// * has more than 1 question in the question section
// * has more than 0 RRs in the Answer section
// * has more than 0 RRs in the Authority section
// * has more than 2 RRs in the Additional section
var DefaultMsgAcceptFunc MsgAcceptFunc = defaultMsgAcceptFunc
// MsgAcceptAction represents the action to be taken.
type MsgAcceptAction int
const (
MsgAccept MsgAcceptAction = iota // Accept the message
MsgReject // Reject the message with a RcodeFormatError
MsgIgnore // Ignore the error and send nothing back.
)
var defaultMsgAcceptFunc = func(dh Header) MsgAcceptAction {
if isResponse := dh.Bits&_QR != 0; isResponse {
return MsgIgnore
}
// Don't allow dynamic updates, because then the sections can contain a whole bunch of RRs.
opcode := int(dh.Bits>>11) & 0xF
if opcode != OpcodeQuery && opcode != OpcodeNotify {
return MsgReject
}
if isZero := dh.Bits&_Z != 0; isZero {
return MsgReject
}
if dh.Qdcount != 1 {
return MsgReject
}
if dh.Ancount != 0 {
return MsgReject
}
if dh.Nscount != 0 {
return MsgReject
}
if dh.Arcount > 2 {
return MsgReject
}
return MsgAccept
}

View file

@ -13,16 +13,16 @@ import (
"time" "time"
) )
const dnsTimeout time.Duration = 2 * time.Second const (
const tcpIdleTimeout time.Duration = 8 * time.Second dnsTimeout time.Duration = 2 * time.Second
tcpIdleTimeout time.Duration = 8 * time.Second
)
// A Conn represents a connection to a DNS server. // A Conn represents a connection to a DNS server.
type Conn struct { type Conn struct {
net.Conn // a net.Conn holding the connection net.Conn // a net.Conn holding the connection
UDPSize uint16 // minimum receive buffer for UDP messages UDPSize uint16 // minimum receive buffer for UDP messages
TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) TsigSecret map[string]string // secret(s) for Tsig map[<zonename>]<base64 secret>, zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2)
rtt time.Duration
t time.Time
tsigRequestMAC string tsigRequestMAC string
} }
@ -83,33 +83,22 @@ func (c *Client) Dial(address string) (conn *Conn, err error) {
// create a new dialer with the appropriate timeout // create a new dialer with the appropriate timeout
var d net.Dialer var d net.Dialer
if c.Dialer == nil { if c.Dialer == nil {
d = net.Dialer{} d = net.Dialer{Timeout: c.getTimeoutForRequest(c.dialTimeout())}
} else { } else {
d = net.Dialer(*c.Dialer) d = *c.Dialer
} }
d.Timeout = c.getTimeoutForRequest(c.writeTimeout())
network := "udp" network := c.Net
useTLS := false if network == "" {
network = "udp"
}
switch c.Net { useTLS := strings.HasPrefix(network, "tcp") && strings.HasSuffix(network, "-tls")
case "tcp-tls":
network = "tcp"
useTLS = true
case "tcp4-tls":
network = "tcp4"
useTLS = true
case "tcp6-tls":
network = "tcp6"
useTLS = true
default:
if c.Net != "" {
network = c.Net
}
}
conn = new(Conn) conn = new(Conn)
if useTLS { if useTLS {
network = strings.TrimSuffix(network, "-tls")
conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig) conn.Conn, err = tls.DialWithDialer(&d, network, address, c.TLSConfig)
} else { } else {
conn.Conn, err = d.Dial(network, address) conn.Conn, err = d.Dial(network, address)
@ -117,6 +106,7 @@ func (c *Client) Dial(address string) (conn *Conn, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return conn, nil return conn, nil
} }
@ -177,8 +167,9 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
} }
co.TsigSecret = c.TsigSecret co.TsigSecret = c.TsigSecret
t := time.Now()
// write with the appropriate write timeout // write with the appropriate write timeout
co.SetWriteDeadline(time.Now().Add(c.getTimeoutForRequest(c.writeTimeout()))) co.SetWriteDeadline(t.Add(c.getTimeoutForRequest(c.writeTimeout())))
if err = co.WriteMsg(m); err != nil { if err = co.WriteMsg(m); err != nil {
return nil, 0, err return nil, 0, err
} }
@ -188,7 +179,8 @@ func (c *Client) exchange(m *Msg, a string) (r *Msg, rtt time.Duration, err erro
if err == nil && r.Id != m.Id { if err == nil && r.Id != m.Id {
err = ErrId err = ErrId
} }
return r, co.rtt, err rtt = time.Since(t)
return r, rtt, err
} }
// ReadMsg reads a message from the connection co. // ReadMsg reads a message from the connection co.
@ -240,7 +232,6 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
} }
p = make([]byte, l) p = make([]byte, l)
n, err = tcpRead(r, p) n, err = tcpRead(r, p)
co.rtt = time.Since(co.t)
default: default:
if co.UDPSize > MinMsgSize { if co.UDPSize > MinMsgSize {
p = make([]byte, co.UDPSize) p = make([]byte, co.UDPSize)
@ -248,7 +239,6 @@ func (co *Conn) ReadMsgHeader(hdr *Header) ([]byte, error) {
p = make([]byte, MinMsgSize) p = make([]byte, MinMsgSize)
} }
n, err = co.Read(p) n, err = co.Read(p)
co.rtt = time.Since(co.t)
} }
if err != nil { if err != nil {
@ -361,7 +351,6 @@ func (co *Conn) WriteMsg(m *Msg) (err error) {
if err != nil { if err != nil {
return err return err
} }
co.t = time.Now()
if _, err = co.Write(out); err != nil { if _, err = co.Write(out); err != nil {
return err return err
} }
@ -497,10 +486,11 @@ func (c *Client) ExchangeContext(ctx context.Context, m *Msg, a string) (r *Msg,
if deadline, ok := ctx.Deadline(); !ok { if deadline, ok := ctx.Deadline(); !ok {
timeout = 0 timeout = 0
} else { } else {
timeout = deadline.Sub(time.Now()) timeout = time.Until(deadline)
} }
// not passing the context to the underlying calls, as the API does not support // not passing the context to the underlying calls, as the API does not support
// context. For timeouts you should set up Client.Dialer and call Client.Exchange. // context. For timeouts you should set up Client.Dialer and call Client.Exchange.
// TODO(tmthrgd,miekg): this is a race condition.
c.Dialer = &net.Dialer{Timeout: timeout} c.Dialer = &net.Dialer{Timeout: timeout}
return c.Exchange(m, a) return c.Exchange(m, a)
} }

View file

@ -1,188 +0,0 @@
//+build ignore
// compression_generate.go is meant to run with go generate. It will use
// go/{importer,types} to track down all the RR struct types. Then for each type
// it will look to see if there are (compressible) names, if so it will add that
// type to compressionLenHelperType and comressionLenSearchType which "fake" the
// compression so that Len() is fast.
package main
import (
"bytes"
"fmt"
"go/format"
"go/importer"
"go/types"
"log"
"os"
)
var packageHdr = `
// Code generated by "go run compress_generate.go"; DO NOT EDIT.
package dns
`
// getTypeStruct will take a type and the package scope, and return the
// (innermost) struct if the type is considered a RR type (currently defined as
// those structs beginning with a RR_Header, could be redefined as implementing
// the RR interface). The bool return value indicates if embedded structs were
// resolved.
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
st, ok := t.Underlying().(*types.Struct)
if !ok {
return nil, false
}
if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
return st, false
}
if st.Field(0).Anonymous() {
st, _ := getTypeStruct(st.Field(0).Type(), scope)
return st, true
}
return nil, false
}
func main() {
// Import and type-check the package
pkg, err := importer.Default().Import("github.com/miekg/dns")
fatalIfErr(err)
scope := pkg.Scope()
var domainTypes []string // Types that have a domain name in them (either compressible or not).
var cdomainTypes []string // Types that have a compressible domain name in them (subset of domainType)
Names:
for _, name := range scope.Names() {
o := scope.Lookup(name)
if o == nil || !o.Exported() {
continue
}
st, _ := getTypeStruct(o.Type(), scope)
if st == nil {
continue
}
if name == "PrivateRR" {
continue
}
if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
log.Fatalf("Constant Type%s does not exist.", o.Name())
}
for i := 1; i < st.NumFields(); i++ {
if _, ok := st.Field(i).Type().(*types.Slice); ok {
if st.Tag(i) == `dns:"domain-name"` {
domainTypes = append(domainTypes, o.Name())
continue Names
}
if st.Tag(i) == `dns:"cdomain-name"` {
cdomainTypes = append(cdomainTypes, o.Name())
domainTypes = append(domainTypes, o.Name())
continue Names
}
continue
}
switch {
case st.Tag(i) == `dns:"domain-name"`:
domainTypes = append(domainTypes, o.Name())
continue Names
case st.Tag(i) == `dns:"cdomain-name"`:
cdomainTypes = append(cdomainTypes, o.Name())
domainTypes = append(domainTypes, o.Name())
continue Names
}
}
}
b := &bytes.Buffer{}
b.WriteString(packageHdr)
// compressionLenHelperType - all types that have domain-name/cdomain-name can be used for compressing names
fmt.Fprint(b, "func compressionLenHelperType(c map[string]int, r RR) {\n")
fmt.Fprint(b, "switch x := r.(type) {\n")
for _, name := range domainTypes {
o := scope.Lookup(name)
st, _ := getTypeStruct(o.Type(), scope)
fmt.Fprintf(b, "case *%s:\n", name)
for i := 1; i < st.NumFields(); i++ {
out := func(s string) { fmt.Fprintf(b, "compressionLenHelper(c, x.%s)\n", st.Field(i).Name()) }
if _, ok := st.Field(i).Type().(*types.Slice); ok {
switch st.Tag(i) {
case `dns:"domain-name"`:
fallthrough
case `dns:"cdomain-name"`:
// For HIP we need to slice over the elements in this slice.
fmt.Fprintf(b, `for i := range x.%s {
compressionLenHelper(c, x.%s[i])
}
`, st.Field(i).Name(), st.Field(i).Name())
}
continue
}
switch {
case st.Tag(i) == `dns:"cdomain-name"`:
fallthrough
case st.Tag(i) == `dns:"domain-name"`:
out(st.Field(i).Name())
}
}
}
fmt.Fprintln(b, "}\n}\n\n")
// compressionLenSearchType - search cdomain-tags types for compressible names.
fmt.Fprint(b, "func compressionLenSearchType(c map[string]int, r RR) (int, bool) {\n")
fmt.Fprint(b, "switch x := r.(type) {\n")
for _, name := range cdomainTypes {
o := scope.Lookup(name)
st, _ := getTypeStruct(o.Type(), scope)
fmt.Fprintf(b, "case *%s:\n", name)
j := 1
for i := 1; i < st.NumFields(); i++ {
out := func(s string, j int) {
fmt.Fprintf(b, "k%d, ok%d := compressionLenSearch(c, x.%s)\n", j, j, st.Field(i).Name())
}
// There are no slice types with names that can be compressed.
switch {
case st.Tag(i) == `dns:"cdomain-name"`:
out(st.Field(i).Name(), j)
j++
}
}
k := "k1"
ok := "ok1"
for i := 2; i < j; i++ {
k += fmt.Sprintf(" + k%d", i)
ok += fmt.Sprintf(" && ok%d", i)
}
fmt.Fprintf(b, "return %s, %s\n", k, ok)
}
fmt.Fprintln(b, "}\nreturn 0, false\n}\n\n")
// gofmt
res, err := format.Source(b.Bytes())
if err != nil {
b.WriteTo(os.Stderr)
log.Fatal(err)
}
f, err := os.Create("zcompress.go")
fatalIfErr(err)
defer f.Close()
f.Write(res)
}
func fatalIfErr(err error) {
if err != nil {
log.Fatal(err)
}
}

View file

@ -166,7 +166,7 @@ func (dns *Msg) IsEdns0() *OPT {
// label fits in 63 characters, but there is no length check for the entire // label fits in 63 characters, but there is no length check for the entire
// string s. I.e. a domain name longer than 255 characters is considered valid. // string s. I.e. a domain name longer than 255 characters is considered valid.
func IsDomainName(s string) (labels int, ok bool) { func IsDomainName(s string) (labels int, ok bool) {
_, labels, err := packDomainName(s, nil, 0, nil, false) _, labels, err := packDomainName(s, nil, 0, compressionMap{}, false)
return labels, err == nil return labels, err == nil
} }

38
vendor/github.com/miekg/dns/dns.go generated vendored
View file

@ -34,10 +34,15 @@ type RR interface {
// copy returns a copy of the RR // copy returns a copy of the RR
copy() RR copy() RR
// len returns the length (in octets) of the uncompressed RR in wire format.
len() int // len returns the length (in octets) of the compressed or uncompressed RR in wire format.
//
// If compression is nil, the uncompressed size will be returned, otherwise the compressed
// size will be returned and domain names will be added to the map for future compression.
len(off int, compression map[string]struct{}) int
// pack packs an RR into wire format. // pack packs an RR into wire format.
pack([]byte, int, map[string]int, bool) (int, error) pack(msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error)
} }
// RR_Header is the header all DNS resource records share. // RR_Header is the header all DNS resource records share.
@ -55,16 +60,6 @@ func (h *RR_Header) Header() *RR_Header { return h }
// Just to implement the RR interface. // Just to implement the RR interface.
func (h *RR_Header) copy() RR { return nil } func (h *RR_Header) copy() RR { return nil }
func (h *RR_Header) copyHeader() *RR_Header {
r := new(RR_Header)
r.Name = h.Name
r.Rrtype = h.Rrtype
r.Class = h.Class
r.Ttl = h.Ttl
r.Rdlength = h.Rdlength
return r
}
func (h *RR_Header) String() string { func (h *RR_Header) String() string {
var s string var s string
@ -80,28 +75,29 @@ func (h *RR_Header) String() string {
return s return s
} }
func (h *RR_Header) len() int { func (h *RR_Header) len(off int, compression map[string]struct{}) int {
l := len(h.Name) + 1 l := domainNameLen(h.Name, off, compression, true)
l += 10 // rrtype(2) + class(2) + ttl(4) + rdlength(2) l += 10 // rrtype(2) + class(2) + ttl(4) + rdlength(2)
return l return l
} }
// ToRFC3597 converts a known RR to the unknown RR representation from RFC 3597. // ToRFC3597 converts a known RR to the unknown RR representation from RFC 3597.
func (rr *RFC3597) ToRFC3597(r RR) error { func (rr *RFC3597) ToRFC3597(r RR) error {
buf := make([]byte, r.len()*2) buf := make([]byte, Len(r)*2)
off, err := PackRR(r, buf, 0, nil, false) headerEnd, off, err := packRR(r, buf, 0, compressionMap{}, false)
if err != nil { if err != nil {
return err return err
} }
buf = buf[:off] buf = buf[:off]
if int(r.Header().Rdlength) > off {
return ErrBuf
}
rfc3597, _, err := unpackRFC3597(*r.Header(), buf, off-int(r.Header().Rdlength)) hdr := *r.Header()
hdr.Rdlength = uint16(off - headerEnd)
rfc3597, _, err := unpackRFC3597(hdr, buf, headerEnd)
if err != nil { if err != nil {
return err return err
} }
*rr = *rfc3597.(*RFC3597) *rr = *rfc3597.(*RFC3597)
return nil return nil
} }

View file

@ -73,6 +73,7 @@ var StringToAlgorithm = reverseInt8(AlgorithmToString)
// AlgorithmToHash is a map of algorithm crypto hash IDs to crypto.Hash's. // AlgorithmToHash is a map of algorithm crypto hash IDs to crypto.Hash's.
var AlgorithmToHash = map[uint8]crypto.Hash{ var AlgorithmToHash = map[uint8]crypto.Hash{
RSAMD5: crypto.MD5, // Deprecated in RFC 6725 RSAMD5: crypto.MD5, // Deprecated in RFC 6725
DSA: crypto.SHA1,
RSASHA1: crypto.SHA1, RSASHA1: crypto.SHA1,
RSASHA1NSEC3SHA1: crypto.SHA1, RSASHA1NSEC3SHA1: crypto.SHA1,
RSASHA256: crypto.SHA256, RSASHA256: crypto.SHA256,
@ -172,7 +173,7 @@ func (k *DNSKEY) KeyTag() uint16 {
keytag += int(v) << 8 keytag += int(v) << 8
} }
} }
keytag += (keytag >> 16) & 0xFFFF keytag += keytag >> 16 & 0xFFFF
keytag &= 0xFFFF keytag &= 0xFFFF
} }
return uint16(keytag) return uint16(keytag)
@ -239,7 +240,7 @@ func (k *DNSKEY) ToDS(h uint8) *DS {
// ToCDNSKEY converts a DNSKEY record to a CDNSKEY record. // ToCDNSKEY converts a DNSKEY record to a CDNSKEY record.
func (k *DNSKEY) ToCDNSKEY() *CDNSKEY { func (k *DNSKEY) ToCDNSKEY() *CDNSKEY {
c := &CDNSKEY{DNSKEY: *k} c := &CDNSKEY{DNSKEY: *k}
c.Hdr = *k.Hdr.copyHeader() c.Hdr = k.Hdr
c.Hdr.Rrtype = TypeCDNSKEY c.Hdr.Rrtype = TypeCDNSKEY
return c return c
} }
@ -247,7 +248,7 @@ func (k *DNSKEY) ToCDNSKEY() *CDNSKEY {
// ToCDS converts a DS record to a CDS record. // ToCDS converts a DS record to a CDS record.
func (d *DS) ToCDS() *CDS { func (d *DS) ToCDS() *CDS {
c := &CDS{DS: *d} c := &CDS{DS: *d}
c.Hdr = *d.Hdr.copyHeader() c.Hdr = d.Hdr
c.Hdr.Rrtype = TypeCDS c.Hdr.Rrtype = TypeCDS
return c return c
} }
@ -400,7 +401,7 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
if rr.Algorithm != k.Algorithm { if rr.Algorithm != k.Algorithm {
return ErrKey return ErrKey
} }
if strings.ToLower(rr.SignerName) != strings.ToLower(k.Hdr.Name) { if !strings.EqualFold(rr.SignerName, k.Hdr.Name) {
return ErrKey return ErrKey
} }
if k.Protocol != 3 { if k.Protocol != 3 {
@ -511,8 +512,8 @@ func (rr *RRSIG) ValidityPeriod(t time.Time) bool {
} }
modi := (int64(rr.Inception) - utc) / year68 modi := (int64(rr.Inception) - utc) / year68
mode := (int64(rr.Expiration) - utc) / year68 mode := (int64(rr.Expiration) - utc) / year68
ti := int64(rr.Inception) + (modi * year68) ti := int64(rr.Inception) + modi*year68
te := int64(rr.Expiration) + (mode * year68) te := int64(rr.Expiration) + mode*year68
return ti <= utc && utc <= te return ti <= utc && utc <= te
} }
@ -532,6 +533,11 @@ func (k *DNSKEY) publicKeyRSA() *rsa.PublicKey {
return nil return nil
} }
if len(keybuf) < 1+1+64 {
// Exponent must be at least 1 byte and modulus at least 64
return nil
}
// RFC 2537/3110, section 2. RSA Public KEY Resource Records // RFC 2537/3110, section 2. RSA Public KEY Resource Records
// Length is in the 0th byte, unless its zero, then it // Length is in the 0th byte, unless its zero, then it
// it in bytes 1 and 2 and its a 16 bit number // it in bytes 1 and 2 and its a 16 bit number
@ -541,25 +547,36 @@ func (k *DNSKEY) publicKeyRSA() *rsa.PublicKey {
explen = uint16(keybuf[1])<<8 | uint16(keybuf[2]) explen = uint16(keybuf[1])<<8 | uint16(keybuf[2])
keyoff = 3 keyoff = 3
} }
if explen > 4 || explen == 0 || keybuf[keyoff] == 0 {
// Exponent larger than supported by the crypto package,
// empty, or contains prohibited leading zero.
return nil
}
modoff := keyoff + int(explen)
modlen := len(keybuf) - modoff
if modlen < 64 || modlen > 512 || keybuf[modoff] == 0 {
// Modulus is too small, large, or contains prohibited leading zero.
return nil
}
pubkey := new(rsa.PublicKey) pubkey := new(rsa.PublicKey)
pubkey.N = big.NewInt(0)
shift := uint64((explen - 1) * 8)
expo := uint64(0) expo := uint64(0)
for i := int(explen - 1); i > 0; i-- { for i := 0; i < int(explen); i++ {
expo += uint64(keybuf[keyoff+i]) << shift expo <<= 8
shift -= 8 expo |= uint64(keybuf[keyoff+i])
} }
// Remainder if expo > 1<<31-1 {
expo += uint64(keybuf[keyoff]) // Larger exponent than supported by the crypto package.
if expo > (2<<31)+1 {
// Larger expo than supported.
// println("dns: F5 primes (or larger) are not supported")
return nil return nil
} }
pubkey.E = int(expo) pubkey.E = int(expo)
pubkey.N.SetBytes(keybuf[keyoff+int(explen):]) pubkey.N = big.NewInt(0)
pubkey.N.SetBytes(keybuf[modoff:])
return pubkey return pubkey
} }
@ -707,7 +724,7 @@ func rawSignatureData(rrset []RR, s *RRSIG) (buf []byte, err error) {
x.Target = strings.ToLower(x.Target) x.Target = strings.ToLower(x.Target)
} }
// 6.2. Canonical RR Form. (5) - origTTL // 6.2. Canonical RR Form. (5) - origTTL
wire := make([]byte, r1.len()+1) // +1 to be safe(r) wire := make([]byte, Len(r1)+1) // +1 to be safe(r)
off, err1 := PackRR(r1, wire, 0, nil, false) off, err1 := PackRR(r1, wire, 0, nil, false)
if err1 != nil { if err1 != nil {
return nil, err1 return nil, err1

View file

@ -1,7 +1,7 @@
package dns package dns
import ( import (
"bytes" "bufio"
"crypto" "crypto"
"crypto/dsa" "crypto/dsa"
"crypto/ecdsa" "crypto/ecdsa"
@ -181,22 +181,10 @@ func readPrivateKeyED25519(m map[string]string) (ed25519.PrivateKey, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(p1) != 32 { if len(p1) != ed25519.SeedSize {
return nil, ErrPrivKey return nil, ErrPrivKey
} }
// RFC 8080 and Golang's x/crypto/ed25519 differ as to how the p = ed25519.NewKeyFromSeed(p1)
// private keys are represented. RFC 8080 specifies that private
// keys be stored solely as the seed value (p1 above) while the
// ed25519 package represents them as the seed value concatenated
// to the public key, which is derived from the seed value.
//
// ed25519.GenerateKey reads exactly 32 bytes from the passed in
// io.Reader and uses them as the seed. It also derives the
// public key and produces a compatible private key.
_, p, err = ed25519.GenerateKey(bytes.NewReader(p1))
if err != nil {
return nil, err
}
case "created", "publish", "activate": case "created", "publish", "activate":
/* not used in Go (yet) */ /* not used in Go (yet) */
} }
@ -207,23 +195,12 @@ func readPrivateKeyED25519(m map[string]string) (ed25519.PrivateKey, error) {
// parseKey reads a private key from r. It returns a map[string]string, // parseKey reads a private key from r. It returns a map[string]string,
// with the key-value pairs, or an error when the file is not correct. // with the key-value pairs, or an error when the file is not correct.
func parseKey(r io.Reader, file string) (map[string]string, error) { func parseKey(r io.Reader, file string) (map[string]string, error) {
s, cancel := scanInit(r)
m := make(map[string]string) m := make(map[string]string)
c := make(chan lex) var k string
k := ""
defer func() { c := newKLexer(r)
cancel()
// zlexer can send up to two tokens, the next one and possibly 1 remainders. for l, ok := c.Next(); ok; l, ok = c.Next() {
// Do a non-blocking read.
_, ok := <-c
_, ok = <-c
if !ok {
// too bad
}
}()
// Start the lexer
go klexer(s, c)
for l := range c {
// It should alternate // It should alternate
switch l.value { switch l.value {
case zKey: case zKey:
@ -232,41 +209,111 @@ func parseKey(r io.Reader, file string) (map[string]string, error) {
if k == "" { if k == "" {
return nil, &ParseError{file, "no private key seen", l} return nil, &ParseError{file, "no private key seen", l}
} }
//println("Setting", strings.ToLower(k), "to", l.token, "b")
m[strings.ToLower(k)] = l.token m[strings.ToLower(k)] = l.token
k = "" k = ""
} }
} }
// Surface any read errors from r.
if err := c.Err(); err != nil {
return nil, &ParseError{file: file, err: err.Error()}
}
return m, nil return m, nil
} }
// klexer scans the sourcefile and returns tokens on the channel c. type klexer struct {
func klexer(s *scan, c chan lex) { br io.ByteReader
var l lex
str := "" // Hold the current read text readErr error
commt := false
key := true line int
x, err := s.tokenText() column int
defer close(c)
for err == nil { key bool
l.column = s.position.Column
l.line = s.position.Line eol bool // end-of-line
}
func newKLexer(r io.Reader) *klexer {
br, ok := r.(io.ByteReader)
if !ok {
br = bufio.NewReaderSize(r, 1024)
}
return &klexer{
br: br,
line: 1,
key: true,
}
}
func (kl *klexer) Err() error {
if kl.readErr == io.EOF {
return nil
}
return kl.readErr
}
// readByte returns the next byte from the input
func (kl *klexer) readByte() (byte, bool) {
if kl.readErr != nil {
return 0, false
}
c, err := kl.br.ReadByte()
if err != nil {
kl.readErr = err
return 0, false
}
// delay the newline handling until the next token is delivered,
// fixes off-by-one errors when reporting a parse error.
if kl.eol {
kl.line++
kl.column = 0
kl.eol = false
}
if c == '\n' {
kl.eol = true
} else {
kl.column++
}
return c, true
}
func (kl *klexer) Next() (lex, bool) {
var (
l lex
str strings.Builder
commt bool
)
for x, ok := kl.readByte(); ok; x, ok = kl.readByte() {
l.line, l.column = kl.line, kl.column
switch x { switch x {
case ':': case ':':
if commt { if commt || !kl.key {
break break
} }
l.token = str
if key { kl.key = false
l.value = zKey
c <- l
// Next token is a space, eat it // Next token is a space, eat it
s.tokenText() kl.readByte()
key = false
str = "" l.value = zKey
} else { l.token = str.String()
l.value = zValue return l, true
}
case ';': case ';':
commt = true commt = true
case '\n': case '\n':
@ -274,24 +321,32 @@ func klexer(s *scan, c chan lex) {
// Reset a comment // Reset a comment
commt = false commt = false
} }
kl.key = true
l.value = zValue l.value = zValue
l.token = str l.token = str.String()
c <- l return l, true
str = ""
commt = false
key = true
default: default:
if commt { if commt {
break break
} }
str += string(x)
str.WriteByte(x)
} }
x, err = s.tokenText()
} }
if len(str) > 0 {
if kl.readErr != nil && kl.readErr != io.EOF {
// Don't return any tokens after a read error occurs.
return lex{value: zEOF}, false
}
if str.Len() > 0 {
// Send remainder // Send remainder
l.token = str
l.value = zValue l.value = zValue
c <- l l.token = str.String()
return l, true
} }
return lex{value: zEOF}, false
} }

View file

@ -82,7 +82,7 @@ func (r *DNSKEY) PrivateKeyString(p crypto.PrivateKey) string {
"Public_value(y): " + pub + "\n" "Public_value(y): " + pub + "\n"
case ed25519.PrivateKey: case ed25519.PrivateKey:
private := toBase64(p[:32]) private := toBase64(p.Seed())
return format + return format +
"Algorithm: " + algorithm + "\n" + "Algorithm: " + algorithm + "\n" +
"PrivateKey: " + private + "\n" "PrivateKey: " + private + "\n"

111
vendor/github.com/miekg/dns/doc.go generated vendored
View file

@ -1,20 +1,20 @@
/* /*
Package dns implements a full featured interface to the Domain Name System. Package dns implements a full featured interface to the Domain Name System.
Server- and client-side programming is supported. Both server- and client-side programming is supported. The package allows
The package allows complete control over what is sent out to the DNS. The package complete control over what is sent out to the DNS. The API follows the
API follows the less-is-more principle, by presenting a small, clean interface. less-is-more principle, by presenting a small, clean interface.
The package dns supports (asynchronous) querying/replying, incoming/outgoing zone transfers, It supports (asynchronous) querying/replying, incoming/outgoing zone transfers,
TSIG, EDNS0, dynamic updates, notifies and DNSSEC validation/signing. TSIG, EDNS0, dynamic updates, notifies and DNSSEC validation/signing.
Note that domain names MUST be fully qualified, before sending them, unqualified
Note that domain names MUST be fully qualified before sending them, unqualified
names in a message will result in a packing failure. names in a message will result in a packing failure.
Resource records are native types. They are not stored in wire format. Resource records are native types. They are not stored in wire format. Basic
Basic usage pattern for creating a new resource record: usage pattern for creating a new resource record:
r := new(dns.MX) r := new(dns.MX)
r.Hdr = dns.RR_Header{Name: "miek.nl.", Rrtype: dns.TypeMX, r.Hdr = dns.RR_Header{Name: "miek.nl.", Rrtype: dns.TypeMX, Class: dns.ClassINET, Ttl: 3600}
Class: dns.ClassINET, Ttl: 3600}
r.Preference = 10 r.Preference = 10
r.Mx = "mx.miek.nl." r.Mx = "mx.miek.nl."
@ -30,8 +30,8 @@ Or even:
mx, err := dns.NewRR("$ORIGIN nl.\nmiek 1H IN MX 10 mx.miek") mx, err := dns.NewRR("$ORIGIN nl.\nmiek 1H IN MX 10 mx.miek")
In the DNS messages are exchanged, these messages contain resource In the DNS messages are exchanged, these messages contain resource records
records (sets). Use pattern for creating a message: (sets). Use pattern for creating a message:
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("miek.nl.", dns.TypeMX) m.SetQuestion("miek.nl.", dns.TypeMX)
@ -40,8 +40,8 @@ Or when not certain if the domain name is fully qualified:
m.SetQuestion(dns.Fqdn("miek.nl"), dns.TypeMX) m.SetQuestion(dns.Fqdn("miek.nl"), dns.TypeMX)
The message m is now a message with the question section set to ask The message m is now a message with the question section set to ask the MX
the MX records for the miek.nl. zone. records for the miek.nl. zone.
The following is slightly more verbose, but more flexible: The following is slightly more verbose, but more flexible:
@ -51,9 +51,8 @@ The following is slightly more verbose, but more flexible:
m1.Question = make([]dns.Question, 1) m1.Question = make([]dns.Question, 1)
m1.Question[0] = dns.Question{"miek.nl.", dns.TypeMX, dns.ClassINET} m1.Question[0] = dns.Question{"miek.nl.", dns.TypeMX, dns.ClassINET}
After creating a message it can be sent. After creating a message it can be sent. Basic use pattern for synchronous
Basic use pattern for synchronous querying the DNS at a querying the DNS at a server configured on 127.0.0.1 and port 53:
server configured on 127.0.0.1 and port 53:
c := new(dns.Client) c := new(dns.Client)
in, rtt, err := c.Exchange(m1, "127.0.0.1:53") in, rtt, err := c.Exchange(m1, "127.0.0.1:53")
@ -73,11 +72,11 @@ and port to use for the connection:
Port: 12345, Port: 12345,
Zone: "", Zone: "",
} }
d := net.Dialer{ c.Dialer := &net.Dialer{
Timeout: 200 * time.Millisecond, Timeout: 200 * time.Millisecond,
LocalAddr: &laddr, LocalAddr: &laddr,
} }
in, rtt, err := c.ExchangeWithDialer(&d, m1, "8.8.8.8:53") in, rtt, err := c.Exchange(m1, "8.8.8.8:53")
If these "advanced" features are not needed, a simple UDP query can be sent, If these "advanced" features are not needed, a simple UDP query can be sent,
with: with:
@ -99,25 +98,24 @@ the Answer section:
Domain Name and TXT Character String Representations Domain Name and TXT Character String Representations
Both domain names and TXT character strings are converted to presentation Both domain names and TXT character strings are converted to presentation form
form both when unpacked and when converted to strings. both when unpacked and when converted to strings.
For TXT character strings, tabs, carriage returns and line feeds will be For TXT character strings, tabs, carriage returns and line feeds will be
converted to \t, \r and \n respectively. Back slashes and quotations marks converted to \t, \r and \n respectively. Back slashes and quotations marks will
will be escaped. Bytes below 32 and above 127 will be converted to \DDD be escaped. Bytes below 32 and above 127 will be converted to \DDD form.
form.
For domain names, in addition to the above rules brackets, periods, For domain names, in addition to the above rules brackets, periods, spaces,
spaces, semicolons and the at symbol are escaped. semicolons and the at symbol are escaped.
DNSSEC DNSSEC
DNSSEC (DNS Security Extension) adds a layer of security to the DNS. It DNSSEC (DNS Security Extension) adds a layer of security to the DNS. It uses
uses public key cryptography to sign resource records. The public key cryptography to sign resource records. The public keys are stored in
public keys are stored in DNSKEY records and the signatures in RRSIG records. DNSKEY records and the signatures in RRSIG records.
Requesting DNSSEC information for a zone is done by adding the DO (DNSSEC OK) bit Requesting DNSSEC information for a zone is done by adding the DO (DNSSEC OK)
to a request. bit to a request.
m := new(dns.Msg) m := new(dns.Msg)
m.SetEdns0(4096, true) m.SetEdns0(4096, true)
@ -126,9 +124,9 @@ Signature generation, signature verification and key generation are all supporte
DYNAMIC UPDATES DYNAMIC UPDATES
Dynamic updates reuses the DNS message format, but renames three of Dynamic updates reuses the DNS message format, but renames three of the
the sections. Question is Zone, Answer is Prerequisite, Authority is sections. Question is Zone, Answer is Prerequisite, Authority is Update, only
Update, only the Additional is not renamed. See RFC 2136 for the gory details. the Additional is not renamed. See RFC 2136 for the gory details.
You can set a rather complex set of rules for the existence of absence of You can set a rather complex set of rules for the existence of absence of
certain resource records or names in a zone to specify if resource records certain resource records or names in a zone to specify if resource records
@ -145,10 +143,9 @@ DNS function shows which functions exist to specify the prerequisites.
NONE rrset empty RRset does not exist dns.RRsetNotUsed NONE rrset empty RRset does not exist dns.RRsetNotUsed
zone rrset rr RRset exists (value dep) dns.Used zone rrset rr RRset exists (value dep) dns.Used
The prerequisite section can also be left empty. The prerequisite section can also be left empty. If you have decided on the
If you have decided on the prerequisites you can tell what RRs should prerequisites you can tell what RRs should be added or deleted. The next table
be added or deleted. The next table shows the options you have and shows the options you have and what functions to call.
what functions to call.
3.4.2.6 - Table Of Metavalues Used In Update Section 3.4.2.6 - Table Of Metavalues Used In Update Section
@ -181,10 +178,10 @@ changes to the RRset after calling SetTsig() the signature will be incorrect.
... ...
// When sending the TSIG RR is calculated and filled in before sending // When sending the TSIG RR is calculated and filled in before sending
When requesting an zone transfer (almost all TSIG usage is when requesting zone transfers), with When requesting an zone transfer (almost all TSIG usage is when requesting zone
TSIG, this is the basic use pattern. In this example we request an AXFR for transfers), with TSIG, this is the basic use pattern. In this example we
miek.nl. with TSIG key named "axfr." and secret "so6ZGir4GPAqINNh9U5c3A==" request an AXFR for miek.nl. with TSIG key named "axfr." and secret
and using the server 176.58.119.54: "so6ZGir4GPAqINNh9U5c3A==" and using the server 176.58.119.54:
t := new(dns.Transfer) t := new(dns.Transfer)
m := new(dns.Msg) m := new(dns.Msg)
@ -194,8 +191,8 @@ and using the server 176.58.119.54:
c, err := t.In(m, "176.58.119.54:53") c, err := t.In(m, "176.58.119.54:53")
for r := range c { ... } for r := range c { ... }
You can now read the records from the transfer as they come in. Each envelope is checked with TSIG. You can now read the records from the transfer as they come in. Each envelope
If something is not correct an error is returned. is checked with TSIG. If something is not correct an error is returned.
Basic use pattern validating and replying to a message that has TSIG set. Basic use pattern validating and replying to a message that has TSIG set.
@ -220,29 +217,30 @@ Basic use pattern validating and replying to a message that has TSIG set.
PRIVATE RRS PRIVATE RRS
RFC 6895 sets aside a range of type codes for private use. This range RFC 6895 sets aside a range of type codes for private use. This range is 65,280
is 65,280 - 65,534 (0xFF00 - 0xFFFE). When experimenting with new Resource Records these - 65,534 (0xFF00 - 0xFFFE). When experimenting with new Resource Records these
can be used, before requesting an official type code from IANA. can be used, before requesting an official type code from IANA.
see http://miek.nl/2014/September/21/idn-and-private-rr-in-go-dns/ for more See https://miek.nl/2014/September/21/idn-and-private-rr-in-go-dns/ for more
information. information.
EDNS0 EDNS0
EDNS0 is an extension mechanism for the DNS defined in RFC 2671 and updated EDNS0 is an extension mechanism for the DNS defined in RFC 2671 and updated by
by RFC 6891. It defines an new RR type, the OPT RR, which is then completely RFC 6891. It defines an new RR type, the OPT RR, which is then completely
abused. abused.
Basic use pattern for creating an (empty) OPT RR: Basic use pattern for creating an (empty) OPT RR:
o := new(dns.OPT) o := new(dns.OPT)
o.Hdr.Name = "." // MUST be the root zone, per definition. o.Hdr.Name = "." // MUST be the root zone, per definition.
o.Hdr.Rrtype = dns.TypeOPT o.Hdr.Rrtype = dns.TypeOPT
The rdata of an OPT RR consists out of a slice of EDNS0 (RFC 6891) The rdata of an OPT RR consists out of a slice of EDNS0 (RFC 6891) interfaces.
interfaces. Currently only a few have been standardized: EDNS0_NSID Currently only a few have been standardized: EDNS0_NSID (RFC 5001) and
(RFC 5001) and EDNS0_SUBNET (draft-vandergaast-edns-client-subnet-02). Note EDNS0_SUBNET (draft-vandergaast-edns-client-subnet-02). Note that these options
that these options may be combined in an OPT RR. may be combined in an OPT RR. Basic use pattern for a server to check if (and
Basic use pattern for a server to check if (and which) options are set: which) options are set:
// o is a dns.OPT // o is a dns.OPT
for _, s := range o.Option { for _, s := range o.Option {
@ -262,10 +260,9 @@ From RFC 2931:
... protection for glue records, DNS requests, protection for message headers ... protection for glue records, DNS requests, protection for message headers
on requests and responses, and protection of the overall integrity of a response. on requests and responses, and protection of the overall integrity of a response.
It works like TSIG, except that SIG(0) uses public key cryptography, instead of the shared It works like TSIG, except that SIG(0) uses public key cryptography, instead of
secret approach in TSIG. the shared secret approach in TSIG. Supported algorithms: DSA, ECDSAP256SHA256,
Supported algorithms: DSA, ECDSAP256SHA256, ECDSAP384SHA384, RSASHA1, RSASHA256 and ECDSAP384SHA384, RSASHA1, RSASHA256 and RSASHA512.
RSASHA512.
Signing subsequent messages in multi-message sessions is not implemented. Signing subsequent messages in multi-message sessions is not implemented.
*/ */

25
vendor/github.com/miekg/dns/duplicate.go generated vendored Normal file
View file

@ -0,0 +1,25 @@
package dns
//go:generate go run duplicate_generate.go
// IsDuplicate checks of r1 and r2 are duplicates of each other, excluding the TTL.
// So this means the header data is equal *and* the RDATA is the same. Return true
// is so, otherwise false.
// It's is a protocol violation to have identical RRs in a message.
func IsDuplicate(r1, r2 RR) bool {
if r1.Header().Class != r2.Header().Class {
return false
}
if r1.Header().Rrtype != r2.Header().Rrtype {
return false
}
if !isDulicateName(r1.Header().Name, r2.Header().Name) {
return false
}
// ignore TTL
return isDuplicateRdata(r1, r2)
}
// isDulicateName checks if the domain names s1 and s2 are equal.
func isDulicateName(s1, s2 string) bool { return equal(s1, s2) }

158
vendor/github.com/miekg/dns/duplicate_generate.go generated vendored Normal file
View file

@ -0,0 +1,158 @@
//+build ignore
// types_generate.go is meant to run with go generate. It will use
// go/{importer,types} to track down all the RR struct types. Then for each type
// it will generate conversion tables (TypeToRR and TypeToString) and banal
// methods (len, Header, copy) based on the struct tags. The generated source is
// written to ztypes.go, and is meant to be checked into git.
package main
import (
"bytes"
"fmt"
"go/format"
"go/importer"
"go/types"
"log"
"os"
)
var packageHdr = `
// Code generated by "go run duplicate_generate.go"; DO NOT EDIT.
package dns
`
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
st, ok := t.Underlying().(*types.Struct)
if !ok {
return nil, false
}
if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
return st, false
}
if st.Field(0).Anonymous() {
st, _ := getTypeStruct(st.Field(0).Type(), scope)
return st, true
}
return nil, false
}
func main() {
// Import and type-check the package
pkg, err := importer.Default().Import("github.com/miekg/dns")
fatalIfErr(err)
scope := pkg.Scope()
// Collect actual types (*X)
var namedTypes []string
for _, name := range scope.Names() {
o := scope.Lookup(name)
if o == nil || !o.Exported() {
continue
}
if st, _ := getTypeStruct(o.Type(), scope); st == nil {
continue
}
if name == "PrivateRR" || name == "RFC3597" {
continue
}
if name == "OPT" || name == "ANY" || name == "IXFR" || name == "AXFR" {
continue
}
namedTypes = append(namedTypes, o.Name())
}
b := &bytes.Buffer{}
b.WriteString(packageHdr)
// Generate the giant switch that calls the correct function for each type.
fmt.Fprint(b, "// isDuplicateRdata calls the rdata specific functions\n")
fmt.Fprint(b, "func isDuplicateRdata(r1, r2 RR) bool {\n")
fmt.Fprint(b, "switch r1.Header().Rrtype {\n")
for _, name := range namedTypes {
o := scope.Lookup(name)
_, isEmbedded := getTypeStruct(o.Type(), scope)
if isEmbedded {
continue
}
fmt.Fprintf(b, "case Type%s:\nreturn isDuplicate%s(r1.(*%s), r2.(*%s))\n", name, name, name, name)
}
fmt.Fprintf(b, "}\nreturn false\n}\n")
// Generate the duplicate check for each type.
fmt.Fprint(b, "// isDuplicate() functions\n\n")
for _, name := range namedTypes {
o := scope.Lookup(name)
st, isEmbedded := getTypeStruct(o.Type(), scope)
if isEmbedded {
continue
}
fmt.Fprintf(b, "func isDuplicate%s(r1, r2 *%s) bool {\n", name, name)
for i := 1; i < st.NumFields(); i++ {
field := st.Field(i).Name()
o2 := func(s string) { fmt.Fprintf(b, s+"\n", field, field) }
o3 := func(s string) { fmt.Fprintf(b, s+"\n", field, field, field) }
// For some reason, a and aaaa don't pop up as *types.Slice here (mostly like because the are
// *indirectly* defined as a slice in the net package).
if _, ok := st.Field(i).Type().(*types.Slice); ok || st.Tag(i) == `dns:"a"` || st.Tag(i) == `dns:"aaaa"` {
o2("if len(r1.%s) != len(r2.%s) {\nreturn false\n}")
if st.Tag(i) == `dns:"cdomain-name"` || st.Tag(i) == `dns:"domain-name"` {
o3(`for i := 0; i < len(r1.%s); i++ {
if !isDulicateName(r1.%s[i], r2.%s[i]) {
return false
}
}`)
continue
}
o3(`for i := 0; i < len(r1.%s); i++ {
if r1.%s[i] != r2.%s[i] {
return false
}
}`)
continue
}
switch st.Tag(i) {
case `dns:"-"`:
// ignored
case `dns:"cdomain-name"`, `dns:"domain-name"`:
o2("if !isDulicateName(r1.%s, r2.%s) {\nreturn false\n}")
default:
o2("if r1.%s != r2.%s {\nreturn false\n}")
}
}
fmt.Fprintf(b, "return true\n}\n\n")
}
// gofmt
res, err := format.Source(b.Bytes())
if err != nil {
b.WriteTo(os.Stderr)
log.Fatal(err)
}
// write result
f, err := os.Create("zduplicate.go")
fatalIfErr(err)
defer f.Close()
f.Write(res)
}
func fatalIfErr(err error) {
if err != nil {
log.Fatal(err)
}
}

34
vendor/github.com/miekg/dns/edns.go generated vendored
View file

@ -78,8 +78,8 @@ func (rr *OPT) String() string {
return s return s
} }
func (rr *OPT) len() int { func (rr *OPT) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
for i := 0; i < len(rr.Option); i++ { for i := 0; i < len(rr.Option); i++ {
l += 4 // Account for 2-byte option code and 2-byte option length. l += 4 // Account for 2-byte option code and 2-byte option length.
lo, _ := rr.Option[i].pack() lo, _ := rr.Option[i].pack()
@ -92,22 +92,24 @@ func (rr *OPT) len() int {
// Version returns the EDNS version used. Only zero is defined. // Version returns the EDNS version used. Only zero is defined.
func (rr *OPT) Version() uint8 { func (rr *OPT) Version() uint8 {
return uint8((rr.Hdr.Ttl & 0x00FF0000) >> 16) return uint8(rr.Hdr.Ttl & 0x00FF0000 >> 16)
} }
// SetVersion sets the version of EDNS. This is usually zero. // SetVersion sets the version of EDNS. This is usually zero.
func (rr *OPT) SetVersion(v uint8) { func (rr *OPT) SetVersion(v uint8) {
rr.Hdr.Ttl = rr.Hdr.Ttl&0xFF00FFFF | (uint32(v) << 16) rr.Hdr.Ttl = rr.Hdr.Ttl&0xFF00FFFF | uint32(v)<<16
} }
// ExtendedRcode returns the EDNS extended RCODE field (the upper 8 bits of the TTL). // ExtendedRcode returns the EDNS extended RCODE field (the upper 8 bits of the TTL).
func (rr *OPT) ExtendedRcode() int { func (rr *OPT) ExtendedRcode() int {
return int((rr.Hdr.Ttl & 0xFF000000) >> 24) return int(rr.Hdr.Ttl&0xFF000000>>24) << 4
} }
// SetExtendedRcode sets the EDNS extended RCODE field. // SetExtendedRcode sets the EDNS extended RCODE field.
func (rr *OPT) SetExtendedRcode(v uint8) { //
rr.Hdr.Ttl = rr.Hdr.Ttl&0x00FFFFFF | (uint32(v) << 24) // If the RCODE is not an extended RCODE, will reset the extended RCODE field to 0.
func (rr *OPT) SetExtendedRcode(v uint16) {
rr.Hdr.Ttl = rr.Hdr.Ttl&0x00FFFFFF | uint32(v>>4)<<24
} }
// UDPSize returns the UDP buffer size. // UDPSize returns the UDP buffer size.
@ -271,22 +273,16 @@ func (e *EDNS0_SUBNET) unpack(b []byte) error {
if e.SourceNetmask > net.IPv4len*8 || e.SourceScope > net.IPv4len*8 { if e.SourceNetmask > net.IPv4len*8 || e.SourceScope > net.IPv4len*8 {
return errors.New("dns: bad netmask") return errors.New("dns: bad netmask")
} }
addr := make([]byte, net.IPv4len) addr := make(net.IP, net.IPv4len)
for i := 0; i < net.IPv4len && 4+i < len(b); i++ { copy(addr, b[4:])
addr[i] = b[4+i] e.Address = addr.To16()
}
e.Address = net.IPv4(addr[0], addr[1], addr[2], addr[3])
case 2: case 2:
if e.SourceNetmask > net.IPv6len*8 || e.SourceScope > net.IPv6len*8 { if e.SourceNetmask > net.IPv6len*8 || e.SourceScope > net.IPv6len*8 {
return errors.New("dns: bad netmask") return errors.New("dns: bad netmask")
} }
addr := make([]byte, net.IPv6len) addr := make(net.IP, net.IPv6len)
for i := 0; i < net.IPv6len && 4+i < len(b); i++ { copy(addr, b[4:])
addr[i] = b[4+i] e.Address = addr
}
e.Address = net.IP{addr[0], addr[1], addr[2], addr[3], addr[4],
addr[5], addr[6], addr[7], addr[8], addr[9], addr[10],
addr[11], addr[12], addr[13], addr[14], addr[15]}
default: default:
return errors.New("dns: bad address family") return errors.New("dns: bad address family")
} }

View file

@ -2,8 +2,8 @@ package dns
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io"
"strconv" "strconv"
"strings" "strings"
) )
@ -18,142 +18,225 @@ import (
// * rhs (rdata) // * rhs (rdata)
// But we are lazy here, only the range is parsed *all* occurrences // But we are lazy here, only the range is parsed *all* occurrences
// of $ after that are interpreted. // of $ after that are interpreted.
// Any error are returned as a string value, the empty string signals func (zp *ZoneParser) generate(l lex) (RR, bool) {
// "no error". token := l.token
func generate(l lex, c chan lex, t chan *Token, o string) string {
step := 1 step := 1
if i := strings.IndexAny(l.token, "/"); i != -1 { if i := strings.IndexByte(token, '/'); i >= 0 {
if i+1 == len(l.token) { if i+1 == len(token) {
return "bad step in $GENERATE range" return zp.setParseError("bad step in $GENERATE range", l)
} }
if s, err := strconv.Atoi(l.token[i+1:]); err == nil {
if s < 0 { s, err := strconv.Atoi(token[i+1:])
return "bad step in $GENERATE range" if err != nil || s <= 0 {
return zp.setParseError("bad step in $GENERATE range", l)
} }
step = s step = s
} else { token = token[:i]
return "bad step in $GENERATE range"
} }
l.token = l.token[:i]
} sx := strings.SplitN(token, "-", 2)
sx := strings.SplitN(l.token, "-", 2)
if len(sx) != 2 { if len(sx) != 2 {
return "bad start-stop in $GENERATE range" return zp.setParseError("bad start-stop in $GENERATE range", l)
} }
start, err := strconv.Atoi(sx[0]) start, err := strconv.Atoi(sx[0])
if err != nil { if err != nil {
return "bad start in $GENERATE range" return zp.setParseError("bad start in $GENERATE range", l)
} }
end, err := strconv.Atoi(sx[1]) end, err := strconv.Atoi(sx[1])
if err != nil { if err != nil {
return "bad stop in $GENERATE range" return zp.setParseError("bad stop in $GENERATE range", l)
} }
if end < 0 || start < 0 || end < start { if end < 0 || start < 0 || end < start {
return "bad range in $GENERATE range" return zp.setParseError("bad range in $GENERATE range", l)
} }
<-c // _BLANK zp.c.Next() // _BLANK
// Create a complete new string, which we then parse again. // Create a complete new string, which we then parse again.
s := "" var s string
BuildRR: for l, ok := zp.c.Next(); ok; l, ok = zp.c.Next() {
l = <-c if l.err {
if l.value != zNewline && l.value != zEOF { return zp.setParseError("bad data in $GENERATE directive", l)
s += l.token }
goto BuildRR if l.value == zNewline {
break
} }
for i := start; i <= end; i += step {
var (
escape bool
dom bytes.Buffer
mod string
err error
offset int
)
for j := 0; j < len(s); j++ { // No 'range' because we need to jump around s += l.token
switch s[j] { }
r := &generateReader{
s: s,
cur: start,
start: start,
end: end,
step: step,
file: zp.file,
lex: &l,
}
zp.sub = NewZoneParser(r, zp.origin, zp.file)
zp.sub.includeDepth, zp.sub.includeAllowed = zp.includeDepth, zp.includeAllowed
zp.sub.SetDefaultTTL(defaultTtl)
return zp.subNext()
}
type generateReader struct {
s string
si int
cur int
start int
end int
step int
mod bytes.Buffer
escape bool
eof bool
file string
lex *lex
}
func (r *generateReader) parseError(msg string, end int) *ParseError {
r.eof = true // Make errors sticky.
l := *r.lex
l.token = r.s[r.si-1 : end]
l.column += r.si // l.column starts one zBLANK before r.s
return &ParseError{r.file, msg, l}
}
func (r *generateReader) Read(p []byte) (int, error) {
// NewZLexer, through NewZoneParser, should use ReadByte and
// not end up here.
panic("not implemented")
}
func (r *generateReader) ReadByte() (byte, error) {
if r.eof {
return 0, io.EOF
}
if r.mod.Len() > 0 {
return r.mod.ReadByte()
}
if r.si >= len(r.s) {
r.si = 0
r.cur += r.step
r.eof = r.cur > r.end || r.cur < 0
return '\n', nil
}
si := r.si
r.si++
switch r.s[si] {
case '\\': case '\\':
if escape { if r.escape {
dom.WriteByte('\\') r.escape = false
escape = false return '\\', nil
continue
} }
escape = true
r.escape = true
return r.ReadByte()
case '$': case '$':
mod = "%d" if r.escape {
offset = 0 r.escape = false
if escape { return '$', nil
dom.WriteByte('$')
escape = false
continue
} }
escape = false
if j+1 >= len(s) { // End of the string mod := "%d"
dom.WriteString(fmt.Sprintf(mod, i+offset))
continue if si >= len(r.s)-1 {
} else { // End of the string
if s[j+1] == '$' { fmt.Fprintf(&r.mod, mod, r.cur)
dom.WriteByte('$') return r.mod.ReadByte()
j++
continue
} }
if r.s[si+1] == '$' {
r.si++
return '$', nil
} }
var offset int
// Search for { and } // Search for { and }
if s[j+1] == '{' { // Modifier block if r.s[si+1] == '{' {
sep := strings.Index(s[j+2:], "}") // Modifier block
if sep == -1 { sep := strings.Index(r.s[si+2:], "}")
return "bad modifier in $GENERATE" if sep < 0 {
return 0, r.parseError("bad modifier in $GENERATE", len(r.s))
} }
mod, offset, err = modToPrintf(s[j+2 : j+2+sep])
if err != nil { var errMsg string
return err.Error() mod, offset, errMsg = modToPrintf(r.s[si+2 : si+2+sep])
if errMsg != "" {
return 0, r.parseError(errMsg, si+3+sep)
} }
j += 2 + sep // Jump to it if r.start+offset < 0 || r.end+offset > 1<<31-1 {
return 0, r.parseError("bad offset in $GENERATE", si+3+sep)
} }
dom.WriteString(fmt.Sprintf(mod, i+offset))
r.si += 2 + sep // Jump to it
}
fmt.Fprintf(&r.mod, mod, r.cur+offset)
return r.mod.ReadByte()
default: default:
if escape { // Pretty useless here if r.escape { // Pretty useless here
escape = false r.escape = false
continue return r.ReadByte()
} }
dom.WriteByte(s[j])
return r.s[si], nil
} }
} }
// Re-parse the RR and send it on the current channel t
rx, err := NewRR("$ORIGIN " + o + "\n" + dom.String())
if err != nil {
return err.Error()
}
t <- &Token{RR: rx}
// Its more efficient to first built the rrlist and then parse it in
// one go! But is this a problem?
}
return ""
}
// Convert a $GENERATE modifier 0,0,d to something Printf can deal with. // Convert a $GENERATE modifier 0,0,d to something Printf can deal with.
func modToPrintf(s string) (string, int, error) { func modToPrintf(s string) (string, int, string) {
xs := strings.SplitN(s, ",", 3) // Modifier is { offset [ ,width [ ,base ] ] } - provide default
if len(xs) != 3 { // values for optional width and type, if necessary.
return "", 0, errors.New("bad modifier in $GENERATE") var offStr, widthStr, base string
switch xs := strings.Split(s, ","); len(xs) {
case 1:
offStr, widthStr, base = xs[0], "0", "d"
case 2:
offStr, widthStr, base = xs[0], xs[1], "d"
case 3:
offStr, widthStr, base = xs[0], xs[1], xs[2]
default:
return "", 0, "bad modifier in $GENERATE"
} }
// xs[0] is offset, xs[1] is width, xs[2] is base
if xs[2] != "o" && xs[2] != "d" && xs[2] != "x" && xs[2] != "X" { switch base {
return "", 0, errors.New("bad base in $GENERATE") case "o", "d", "x", "X":
default:
return "", 0, "bad base in $GENERATE"
} }
offset, err := strconv.Atoi(xs[0])
if err != nil || offset > 255 { offset, err := strconv.Atoi(offStr)
return "", 0, errors.New("bad offset in $GENERATE") if err != nil {
return "", 0, "bad offset in $GENERATE"
} }
width, err := strconv.Atoi(xs[1])
if err != nil || width > 255 { width, err := strconv.Atoi(widthStr)
return "", offset, errors.New("bad width in $GENERATE") if err != nil || width < 0 || width > 255 {
return "", 0, "bad width in $GENERATE"
} }
switch {
case width < 0: if width == 0 {
return "", offset, errors.New("bad width in $GENERATE") return "%" + base, offset, ""
case width == 0:
return "%" + xs[1] + xs[2], offset, nil
} }
return "%0" + xs[1] + xs[2], offset, nil
return "%0" + widthStr + base, offset, ""
} }

View file

@ -178,10 +178,10 @@ func equal(a, b string) bool {
ai := a[i] ai := a[i]
bi := b[i] bi := b[i]
if ai >= 'A' && ai <= 'Z' { if ai >= 'A' && ai <= 'Z' {
ai |= ('a' - 'A') ai |= 'a' - 'A'
} }
if bi >= 'A' && bi <= 'Z' { if bi >= 'A' && bi <= 'Z' {
bi |= ('a' - 'A') bi |= 'a' - 'A'
} }
if ai != bi { if ai != bi {
return false return false

44
vendor/github.com/miekg/dns/listen_go111.go generated vendored Normal file
View file

@ -0,0 +1,44 @@
// +build go1.11
// +build aix darwin dragonfly freebsd linux netbsd openbsd
package dns
import (
"context"
"net"
"syscall"
"golang.org/x/sys/unix"
)
const supportsReusePort = true
func reuseportControl(network, address string, c syscall.RawConn) error {
var opErr error
err := c.Control(func(fd uintptr) {
opErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
})
if err != nil {
return err
}
return opErr
}
func listenTCP(network, addr string, reuseport bool) (net.Listener, error) {
var lc net.ListenConfig
if reuseport {
lc.Control = reuseportControl
}
return lc.Listen(context.Background(), network, addr)
}
func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) {
var lc net.ListenConfig
if reuseport {
lc.Control = reuseportControl
}
return lc.ListenPacket(context.Background(), network, addr)
}

23
vendor/github.com/miekg/dns/listen_go_not111.go generated vendored Normal file
View file

@ -0,0 +1,23 @@
// +build !go1.11 !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd
package dns
import "net"
const supportsReusePort = false
func listenTCP(network, addr string, reuseport bool) (net.Listener, error) {
if reuseport {
// TODO(tmthrgd): return an error?
}
return net.Listen(network, addr)
}
func listenUDP(network, addr string, reuseport bool) (net.PacketConn, error) {
if reuseport {
// TODO(tmthrgd): return an error?
}
return net.ListenPacket(network, addr)
}

613
vendor/github.com/miekg/dns/msg.go generated vendored
View file

@ -9,7 +9,6 @@
package dns package dns
//go:generate go run msg_generate.go //go:generate go run msg_generate.go
//go:generate go run compress_generate.go
import ( import (
crand "crypto/rand" crand "crypto/rand"
@ -18,12 +17,35 @@ import (
"math/big" "math/big"
"math/rand" "math/rand"
"strconv" "strconv"
"strings"
"sync" "sync"
) )
const ( const (
maxCompressionOffset = 2 << 13 // We have 14 bits for the compression pointer maxCompressionOffset = 2 << 13 // We have 14 bits for the compression pointer
maxDomainNameWireOctets = 255 // See RFC 1035 section 2.3.4 maxDomainNameWireOctets = 255 // See RFC 1035 section 2.3.4
// This is the maximum number of compression pointers that should occur in a
// semantically valid message. Each label in a domain name must be at least one
// octet and is separated by a period. The root label won't be represented by a
// compression pointer to a compression pointer, hence the -2 to exclude the
// smallest valid root label.
//
// It is possible to construct a valid message that has more compression pointers
// than this, and still doesn't loop, by pointing to a previous pointer. This is
// not something a well written implementation should ever do, so we leave them
// to trip the maximum compression pointer check.
maxCompressionPointers = (maxDomainNameWireOctets+1)/2 - 2
// This is the maximum length of a domain name in presentation format. The
// maximum wire length of a domain name is 255 octets (see above), with the
// maximum label length being 63. The wire format requires one extra byte over
// the presentation format, reducing the number of octets by 1. Each label in
// the name will be separated by a single period, with each octet in the label
// expanding to at most 4 bytes (\DDD). If all other labels are of the maximum
// length, then the final label can only be 61 octets long to not exceed the
// maximum allowed wire length.
maxDomainNamePresentationLength = 61*4 + 1 + 63*4 + 1 + 63*4 + 1 + 63*4 + 1
) )
// Errors defined in this package. // Errors defined in this package.
@ -49,7 +71,6 @@ var (
ErrSig error = &Error{err: "bad signature"} // ErrSig indicates that a signature can not be cryptographically validated. ErrSig error = &Error{err: "bad signature"} // ErrSig indicates that a signature can not be cryptographically validated.
ErrSoa error = &Error{err: "no SOA"} // ErrSOA indicates that no SOA RR was seen when doing zone transfers. ErrSoa error = &Error{err: "no SOA"} // ErrSOA indicates that no SOA RR was seen when doing zone transfers.
ErrTime error = &Error{err: "bad time"} // ErrTime indicates a timing error in TSIG authentication. ErrTime error = &Error{err: "bad time"} // ErrTime indicates a timing error in TSIG authentication.
ErrTruncated error = &Error{err: "failed to unpack truncated message"} // ErrTruncated indicates that we failed to unpack a truncated message. We unpacked as much as we had so Msg can still be used, if desired.
) )
// Id by default, returns a 16 bits random number to be used as a // Id by default, returns a 16 bits random number to be used as a
@ -151,7 +172,7 @@ var RcodeToString = map[int]string{
RcodeFormatError: "FORMERR", RcodeFormatError: "FORMERR",
RcodeServerFailure: "SERVFAIL", RcodeServerFailure: "SERVFAIL",
RcodeNameError: "NXDOMAIN", RcodeNameError: "NXDOMAIN",
RcodeNotImplemented: "NOTIMPL", RcodeNotImplemented: "NOTIMP",
RcodeRefused: "REFUSED", RcodeRefused: "REFUSED",
RcodeYXDomain: "YXDOMAIN", // See RFC 2136 RcodeYXDomain: "YXDOMAIN", // See RFC 2136
RcodeYXRrset: "YXRRSET", RcodeYXRrset: "YXRRSET",
@ -169,6 +190,39 @@ var RcodeToString = map[int]string{
RcodeBadCookie: "BADCOOKIE", RcodeBadCookie: "BADCOOKIE",
} }
// compressionMap is used to allow a more efficient compression map
// to be used for internal packDomainName calls without changing the
// signature or functionality of public API.
//
// In particular, map[string]uint16 uses 25% less per-entry memory
// than does map[string]int.
type compressionMap struct {
ext map[string]int // external callers
int map[string]uint16 // internal callers
}
func (m compressionMap) valid() bool {
return m.int != nil || m.ext != nil
}
func (m compressionMap) insert(s string, pos int) {
if m.ext != nil {
m.ext[s] = pos
} else {
m.int[s] = uint16(pos)
}
}
func (m compressionMap) find(s string) (int, bool) {
if m.ext != nil {
pos, ok := m.ext[s]
return pos, ok
}
pos, ok := m.int[s]
return int(pos), ok
}
// Domain names are a sequence of counted strings // Domain names are a sequence of counted strings
// split at the dots. They end with a zero-length string. // split at the dots. They end with a zero-length string.
@ -177,143 +231,168 @@ var RcodeToString = map[int]string{
// map needs to hold a mapping between domain names and offsets // map needs to hold a mapping between domain names and offsets
// pointing into msg. // pointing into msg.
func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
off1, _, err = packDomainName(s, msg, off, compression, compress) off1, _, err = packDomainName(s, msg, off, compressionMap{ext: compression}, compress)
return return
} }
func packDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, labels int, err error) { func packDomainName(s string, msg []byte, off int, compression compressionMap, compress bool) (off1 int, labels int, err error) {
// special case if msg == nil // special case if msg == nil
lenmsg := 256 lenmsg := 256
if msg != nil { if msg != nil {
lenmsg = len(msg) lenmsg = len(msg)
} }
ls := len(s) ls := len(s)
if ls == 0 { // Ok, for instance when dealing with update RR without any rdata. if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
return off, 0, nil return off, 0, nil
} }
// If not fully qualified, error out, but only if msg == nil #ugly
switch { // If not fully qualified, error out, but only if msg != nil #ugly
case msg == nil:
if s[ls-1] != '.' { if s[ls-1] != '.' {
if msg != nil {
return lenmsg, 0, ErrFqdn
}
s += "." s += "."
ls++ ls++
} }
case msg != nil:
if s[ls-1] != '.' {
return lenmsg, 0, ErrFqdn
}
}
// Each dot ends a segment of the name. // Each dot ends a segment of the name.
// We trade each dot byte for a length byte. // We trade each dot byte for a length byte.
// Except for escaped dots (\.), which are normal dots. // Except for escaped dots (\.), which are normal dots.
// There is also a trailing zero. // There is also a trailing zero.
// Compression // Compression
nameoffset := -1
pointer := -1 pointer := -1
// Emit sequence of counted strings, chopping at dots. // Emit sequence of counted strings, chopping at dots.
begin := 0 var (
bs := []byte(s) begin int
roBs, bsFresh, escapedDot := s, true, false compBegin int
compOff int
bs []byte
wasDot bool
)
loop:
for i := 0; i < ls; i++ { for i := 0; i < ls; i++ {
if bs[i] == '\\' { var c byte
for j := i; j < ls-1; j++ { if bs == nil {
bs[j] = bs[j+1] c = s[i]
} else {
c = bs[i]
} }
ls--
switch c {
case '\\':
if off+1 > lenmsg { if off+1 > lenmsg {
return lenmsg, labels, ErrBuf return lenmsg, labels, ErrBuf
} }
// check for \DDD
if i+2 < ls && isDigit(bs[i]) && isDigit(bs[i+1]) && isDigit(bs[i+2]) { if bs == nil {
bs[i] = dddToByte(bs[i:]) bs = []byte(s)
for j := i + 1; j < ls-2; j++ {
bs[j] = bs[j+2]
}
ls -= 2
}
escapedDot = bs[i] == '.'
bsFresh = false
continue
} }
if bs[i] == '.' { // check for \DDD
if i > 0 && bs[i-1] == '.' && !escapedDot { if i+3 < ls && isDigit(bs[i+1]) && isDigit(bs[i+2]) && isDigit(bs[i+3]) {
bs[i] = dddToByte(bs[i+1:])
copy(bs[i+1:ls-3], bs[i+4:])
ls -= 3
compOff += 3
} else {
copy(bs[i:ls-1], bs[i+1:])
ls--
compOff++
}
wasDot = false
case '.':
if wasDot {
// two dots back to back is not legal // two dots back to back is not legal
return lenmsg, labels, ErrRdata return lenmsg, labels, ErrRdata
} }
if i-begin >= 1<<6 { // top two bits of length must be clear wasDot = true
labelLen := i - begin
if labelLen >= 1<<6 { // top two bits of length must be clear
return lenmsg, labels, ErrRdata return lenmsg, labels, ErrRdata
} }
// off can already (we're in a loop) be bigger than len(msg) // off can already (we're in a loop) be bigger than len(msg)
// this happens when a name isn't fully qualified // this happens when a name isn't fully qualified
if off+1 > lenmsg { if off+1+labelLen > lenmsg {
return lenmsg, labels, ErrBuf return lenmsg, labels, ErrBuf
} }
if msg != nil {
msg[off] = byte(i - begin)
}
offset := off
off++
for j := begin; j < i; j++ {
if off+1 > lenmsg {
return lenmsg, labels, ErrBuf
}
if msg != nil {
msg[off] = bs[j]
}
off++
}
if compress && !bsFresh {
roBs = string(bs)
bsFresh = true
}
// Don't try to compress '.' // Don't try to compress '.'
// We should only compress when compress it true, but we should also still pick // We should only compress when compress is true, but we should also still pick
// up names that can be used for *future* compression(s). // up names that can be used for *future* compression(s).
if compression != nil && roBs[begin:] != "." { if compression.valid() && !isRootLabel(s, bs, begin, ls) {
if p, ok := compression[roBs[begin:]]; !ok { if p, ok := compression.find(s[compBegin:]); ok {
// Only offsets smaller than this can be used.
if offset < maxCompressionOffset {
compression[roBs[begin:]] = offset
}
} else {
// The first hit is the longest matching dname // The first hit is the longest matching dname
// keep the pointer offset we get back and store // keep the pointer offset we get back and store
// the offset of the current name, because that's // the offset of the current name, because that's
// where we need to insert the pointer later // where we need to insert the pointer later
// If compress is true, we're allowed to compress this dname // If compress is true, we're allowed to compress this dname
if pointer == -1 && compress { if compress {
pointer = p // Where to point to pointer = p // Where to point to
nameoffset = offset // Where to point from break loop
break }
} else if off < maxCompressionOffset {
// Only offsets smaller than maxCompressionOffset can be used.
compression.insert(s[compBegin:], off)
} }
} }
// The following is covered by the length check above.
if msg != nil {
msg[off] = byte(labelLen)
if bs == nil {
copy(msg[off+1:], s[begin:i])
} else {
copy(msg[off+1:], bs[begin:i])
} }
}
off += 1 + labelLen
labels++ labels++
begin = i + 1 begin = i + 1
compBegin = begin + compOff
default:
wasDot = false
} }
escapedDot = false
} }
// Root label is special // Root label is special
if len(bs) == 1 && bs[0] == '.' { if isRootLabel(s, bs, 0, ls) {
return off, labels, nil return off, labels, nil
} }
// If we did compression and we find something add the pointer here // If we did compression and we find something add the pointer here
if pointer != -1 { if pointer != -1 {
// We have two bytes (14 bits) to put the pointer in // We have two bytes (14 bits) to put the pointer in
// if msg == nil, we will never do compression // if msg == nil, we will never do compression
binary.BigEndian.PutUint16(msg[nameoffset:], uint16(pointer^0xC000)) binary.BigEndian.PutUint16(msg[off:], uint16(pointer^0xC000))
off = nameoffset + 1 return off + 2, labels, nil
goto End
} }
if msg != nil && off < len(msg) {
if msg != nil && off < lenmsg {
msg[off] = 0 msg[off] = 0
} }
End:
off++ return off + 1, labels, nil
return off, labels, nil }
// isRootLabel returns whether s or bs, from off to end, is the root
// label ".".
//
// If bs is nil, s will be checked, otherwise bs will be checked.
func isRootLabel(s string, bs []byte, off, end int) bool {
if bs == nil {
return s[off:end] == "."
}
return end-off == 1 && bs[off] == '.'
} }
// Unpack a domain name. // Unpack a domain name.
@ -330,12 +409,16 @@ End:
// In theory, the pointers are only allowed to jump backward. // In theory, the pointers are only allowed to jump backward.
// We let them jump anywhere and stop jumping after a while. // We let them jump anywhere and stop jumping after a while.
// UnpackDomainName unpacks a domain name into a string. // UnpackDomainName unpacks a domain name into a string. It returns
// the name, the new offset into msg and any error that occurred.
//
// When an error is encountered, the unpacked name will be discarded
// and len(msg) will be returned as the offset.
func UnpackDomainName(msg []byte, off int) (string, int, error) { func UnpackDomainName(msg []byte, off int) (string, int, error) {
s := make([]byte, 0, 64) s := make([]byte, 0, maxDomainNamePresentationLength)
off1 := 0 off1 := 0
lenmsg := len(msg) lenmsg := len(msg)
maxLen := maxDomainNameWireOctets budget := maxDomainNameWireOctets
ptr := 0 // number of pointers followed ptr := 0 // number of pointers followed
Loop: Loop:
for { for {
@ -354,27 +437,19 @@ Loop:
if off+c > lenmsg { if off+c > lenmsg {
return "", lenmsg, ErrBuf return "", lenmsg, ErrBuf
} }
budget -= c + 1 // +1 for the label separator
if budget <= 0 {
return "", lenmsg, ErrLongDomain
}
for j := off; j < off+c; j++ { for j := off; j < off+c; j++ {
switch b := msg[j]; b { switch b := msg[j]; b {
case '.', '(', ')', ';', ' ', '@': case '.', '(', ')', ';', ' ', '@':
fallthrough fallthrough
case '"', '\\': case '"', '\\':
s = append(s, '\\', b) s = append(s, '\\', b)
// presentation-format \X escapes add an extra byte
maxLen++
default: default:
if b < 32 || b >= 127 { // unprintable, use \DDD if b < ' ' || b > '~' { // unprintable, use \DDD
var buf [3]byte s = append(s, escapeByte(b)...)
bufs := strconv.AppendInt(buf[:0], int64(b), 10)
s = append(s, '\\')
for i := 0; i < 3-len(bufs); i++ {
s = append(s, '0')
}
for _, r := range bufs {
s = append(s, r)
}
// presentation-format \DDD escapes add 3 extra bytes
maxLen += 3
} else { } else {
s = append(s, b) s = append(s, b)
} }
@ -396,7 +471,7 @@ Loop:
if ptr == 0 { if ptr == 0 {
off1 = off off1 = off
} }
if ptr++; ptr > 10 { if ptr++; ptr > maxCompressionPointers {
return "", lenmsg, &Error{err: "too many compression pointers"} return "", lenmsg, &Error{err: "too many compression pointers"}
} }
// pointer should guarantee that it advances and points forwards at least // pointer should guarantee that it advances and points forwards at least
@ -412,10 +487,7 @@ Loop:
off1 = off off1 = off
} }
if len(s) == 0 { if len(s) == 0 {
s = []byte(".") return ".", off1, nil
} else if len(s) >= maxLen {
// error if the name is too long, but don't throw it away
return string(s), lenmsg, ErrLongDomain
} }
return string(s), off1, nil return string(s), off1, nil
} }
@ -512,7 +584,7 @@ func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
off = off0 off = off0
var s string var s string
for off < len(msg) && err == nil { for off < len(msg) && err == nil {
s, off, err = unpackTxtString(msg, off) s, off, err = unpackString(msg, off)
if err == nil { if err == nil {
ss = append(ss, s) ss = append(ss, s)
} }
@ -520,43 +592,16 @@ func unpackTxt(msg []byte, off0 int) (ss []string, off int, err error) {
return return
} }
func unpackTxtString(msg []byte, offset int) (string, int, error) {
if offset+1 > len(msg) {
return "", offset, &Error{err: "overflow unpacking txt"}
}
l := int(msg[offset])
if offset+l+1 > len(msg) {
return "", offset, &Error{err: "overflow unpacking txt"}
}
s := make([]byte, 0, l)
for _, b := range msg[offset+1 : offset+1+l] {
switch b {
case '"', '\\':
s = append(s, '\\', b)
default:
if b < 32 || b > 127 { // unprintable
var buf [3]byte
bufs := strconv.AppendInt(buf[:0], int64(b), 10)
s = append(s, '\\')
for i := 0; i < 3-len(bufs); i++ {
s = append(s, '0')
}
for _, r := range bufs {
s = append(s, r)
}
} else {
s = append(s, b)
}
}
}
offset += 1 + l
return string(s), offset, nil
}
// Helpers for dealing with escaped bytes // Helpers for dealing with escaped bytes
func isDigit(b byte) bool { return b >= '0' && b <= '9' } func isDigit(b byte) bool { return b >= '0' && b <= '9' }
func dddToByte(s []byte) byte { func dddToByte(s []byte) byte {
_ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0'))
}
func dddStringToByte(s string) byte {
_ = s[2] // bounds check hint to compiler; see golang.org/issue/14808
return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0')) return byte((s[0]-'0')*100 + (s[1]-'0')*10 + (s[2] - '0'))
} }
@ -574,19 +619,33 @@ func intToBytes(i *big.Int, length int) []byte {
// PackRR packs a resource record rr into msg[off:]. // PackRR packs a resource record rr into msg[off:].
// See PackDomainName for documentation about the compression. // See PackDomainName for documentation about the compression.
func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { func PackRR(rr RR, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
if rr == nil { headerEnd, off1, err := packRR(rr, msg, off, compressionMap{ext: compression}, compress)
return len(msg), &Error{err: "nil rr"} if err == nil {
// packRR no longer sets the Rdlength field on the rr, but
// callers might be expecting it so we set it here.
rr.Header().Rdlength = uint16(off1 - headerEnd)
}
return off1, err
} }
off1, err = rr.pack(msg, off, compression, compress) func packRR(rr RR, msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error) {
if rr == nil {
return len(msg), len(msg), &Error{err: "nil rr"}
}
headerEnd, off1, err = rr.pack(msg, off, compression, compress)
if err != nil { if err != nil {
return len(msg), err return headerEnd, len(msg), err
} }
// TODO(miek): Not sure if this is needed? If removed we can remove rawmsg.go as well.
if rawSetRdlength(msg, off, off1) { rdlength := off1 - headerEnd
return off1, nil if int(uint16(rdlength)) != rdlength { // overflow
return headerEnd, len(msg), ErrRdata
} }
return off, ErrRdata
// The RDLENGTH field is the last field in the header and we set it here.
binary.BigEndian.PutUint16(msg[headerEnd-2:], uint16(rdlength))
return headerEnd, off1, nil
} }
// UnpackRR unpacks msg[off:] into an RR. // UnpackRR unpacks msg[off:] into an RR.
@ -595,6 +654,13 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
if err != nil { if err != nil {
return nil, len(msg), err return nil, len(msg), err
} }
return UnpackRRWithHeader(h, msg, off)
}
// UnpackRRWithHeader unpacks the record type specific payload given an existing
// RR_Header.
func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) {
end := off + int(h.Rdlength) end := off + int(h.Rdlength)
if fn, known := typeToUnpack[h.Rrtype]; !known { if fn, known := typeToUnpack[h.Rrtype]; !known {
@ -684,35 +750,37 @@ func (dns *Msg) Pack() (msg []byte, err error) {
return dns.PackBuffer(nil) return dns.PackBuffer(nil)
} }
// PackBuffer packs a Msg, using the given buffer buf. If buf is too small // PackBuffer packs a Msg, using the given buffer buf. If buf is too small a new buffer is allocated.
// a new buffer is allocated.
func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) { func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
// We use a similar function in tsig.go's stripTsig. // If this message can't be compressed, avoid filling the
var ( // compression map and creating garbage.
dh Header if dns.Compress && dns.isCompressible() {
compression map[string]int compression := make(map[string]uint16) // Compression pointer mappings.
) return dns.packBufferWithCompressionMap(buf, compressionMap{int: compression}, true)
if dns.Compress {
compression = make(map[string]int) // Compression pointer mappings
} }
return dns.packBufferWithCompressionMap(buf, compressionMap{}, false)
}
// packBufferWithCompressionMap packs a Msg, using the given buffer buf.
func (dns *Msg) packBufferWithCompressionMap(buf []byte, compression compressionMap, compress bool) (msg []byte, err error) {
if dns.Rcode < 0 || dns.Rcode > 0xFFF { if dns.Rcode < 0 || dns.Rcode > 0xFFF {
return nil, ErrRcode return nil, ErrRcode
} }
if dns.Rcode > 0xF {
// Regular RCODE field is 4 bits // Set extended rcode unconditionally if we have an opt, this will allow
opt := dns.IsEdns0() // reseting the extended rcode bits if they need to.
if opt == nil { if opt := dns.IsEdns0(); opt != nil {
opt.SetExtendedRcode(uint16(dns.Rcode))
} else if dns.Rcode > 0xF {
// If Rcode is an extended one and opt is nil, error out.
return nil, ErrExtendedRcode return nil, ErrExtendedRcode
} }
opt.SetExtendedRcode(uint8(dns.Rcode >> 4))
dns.Rcode &= 0xF
}
// Convert convenient Msg into wire-like Header. // Convert convenient Msg into wire-like Header.
var dh Header
dh.Id = dns.Id dh.Id = dns.Id
dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode) dh.Bits = uint16(dns.Opcode)<<11 | uint16(dns.Rcode&0xF)
if dns.Response { if dns.Response {
dh.Bits |= _QR dh.Bits |= _QR
} }
@ -738,50 +806,44 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
dh.Bits |= _CD dh.Bits |= _CD
} }
// Prepare variable sized arrays. dh.Qdcount = uint16(len(dns.Question))
question := dns.Question dh.Ancount = uint16(len(dns.Answer))
answer := dns.Answer dh.Nscount = uint16(len(dns.Ns))
ns := dns.Ns dh.Arcount = uint16(len(dns.Extra))
extra := dns.Extra
dh.Qdcount = uint16(len(question))
dh.Ancount = uint16(len(answer))
dh.Nscount = uint16(len(ns))
dh.Arcount = uint16(len(extra))
// We need the uncompressed length here, because we first pack it and then compress it. // We need the uncompressed length here, because we first pack it and then compress it.
msg = buf msg = buf
uncompressedLen := compressedLen(dns, false) uncompressedLen := msgLenWithCompressionMap(dns, nil)
if packLen := uncompressedLen + 1; len(msg) < packLen { if packLen := uncompressedLen + 1; len(msg) < packLen {
msg = make([]byte, packLen) msg = make([]byte, packLen)
} }
// Pack it in: header and then the pieces. // Pack it in: header and then the pieces.
off := 0 off := 0
off, err = dh.pack(msg, off, compression, dns.Compress) off, err = dh.pack(msg, off, compression, compress)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for i := 0; i < len(question); i++ { for _, r := range dns.Question {
off, err = question[i].pack(msg, off, compression, dns.Compress) off, err = r.pack(msg, off, compression, compress)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
for i := 0; i < len(answer); i++ { for _, r := range dns.Answer {
off, err = PackRR(answer[i], msg, off, compression, dns.Compress) _, off, err = packRR(r, msg, off, compression, compress)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
for i := 0; i < len(ns); i++ { for _, r := range dns.Ns {
off, err = PackRR(ns[i], msg, off, compression, dns.Compress) _, off, err = packRR(r, msg, off, compression, compress)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
for i := 0; i < len(extra); i++ { for _, r := range dns.Extra {
off, err = PackRR(extra[i], msg, off, compression, dns.Compress) _, off, err = packRR(r, msg, off, compression, compress)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -789,28 +851,7 @@ func (dns *Msg) PackBuffer(buf []byte) (msg []byte, err error) {
return msg[:off], nil return msg[:off], nil
} }
// Unpack unpacks a binary message to a Msg structure. func (dns *Msg) unpack(dh Header, msg []byte, off int) (err error) {
func (dns *Msg) Unpack(msg []byte) (err error) {
var (
dh Header
off int
)
if dh, off, err = unpackMsgHdr(msg, off); err != nil {
return err
}
dns.Id = dh.Id
dns.Response = (dh.Bits & _QR) != 0
dns.Opcode = int(dh.Bits>>11) & 0xF
dns.Authoritative = (dh.Bits & _AA) != 0
dns.Truncated = (dh.Bits & _TC) != 0
dns.RecursionDesired = (dh.Bits & _RD) != 0
dns.RecursionAvailable = (dh.Bits & _RA) != 0
dns.Zero = (dh.Bits & _Z) != 0
dns.AuthenticatedData = (dh.Bits & _AD) != 0
dns.CheckingDisabled = (dh.Bits & _CD) != 0
dns.Rcode = int(dh.Bits & 0xF)
// If we are at the end of the message we should return *just* the // If we are at the end of the message we should return *just* the
// header. This can still be useful to the caller. 9.9.9.9 sends these // header. This can still be useful to the caller. 9.9.9.9 sends these
// when responding with REFUSED for instance. // when responding with REFUSED for instance.
@ -829,8 +870,6 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
var q Question var q Question
q, off, err = unpackQuestion(msg, off) q, off, err = unpackQuestion(msg, off)
if err != nil { if err != nil {
// Even if Truncated is set, we only will set ErrTruncated if we
// actually got the questions
return err return err
} }
if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie! if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie!
@ -854,16 +893,29 @@ func (dns *Msg) Unpack(msg []byte) (err error) {
// The header counts might have been wrong so we need to update it // The header counts might have been wrong so we need to update it
dh.Arcount = uint16(len(dns.Extra)) dh.Arcount = uint16(len(dns.Extra))
// Set extended Rcode
if opt := dns.IsEdns0(); opt != nil {
dns.Rcode |= opt.ExtendedRcode()
}
if off != len(msg) { if off != len(msg) {
// TODO(miek) make this an error? // TODO(miek) make this an error?
// use PackOpt to let people tell how detailed the error reporting should be? // use PackOpt to let people tell how detailed the error reporting should be?
// println("dns: extra bytes in dns packet", off, "<", len(msg)) // println("dns: extra bytes in dns packet", off, "<", len(msg))
} else if dns.Truncated {
// Whether we ran into a an error or not, we want to return that it
// was truncated
err = ErrTruncated
} }
return err return err
}
// Unpack unpacks a binary message to a Msg structure.
func (dns *Msg) Unpack(msg []byte) (err error) {
dh, off, err := unpackMsgHdr(msg, 0)
if err != nil {
return err
}
dns.setHdr(dh)
return dns.unpack(dh, msg, off)
} }
// Convert a complete message to a string with dig-like output. // Convert a complete message to a string with dig-like output.
@ -909,99 +961,109 @@ func (dns *Msg) String() string {
return s return s
} }
// isCompressible returns whether the msg may be compressible.
func (dns *Msg) isCompressible() bool {
// If we only have one question, there is nothing we can ever compress.
return len(dns.Question) > 1 || len(dns.Answer) > 0 ||
len(dns.Ns) > 0 || len(dns.Extra) > 0
}
// Len returns the message length when in (un)compressed wire format. // Len returns the message length when in (un)compressed wire format.
// If dns.Compress is true compression it is taken into account. Len() // If dns.Compress is true compression it is taken into account. Len()
// is provided to be a faster way to get the size of the resulting packet, // is provided to be a faster way to get the size of the resulting packet,
// than packing it, measuring the size and discarding the buffer. // than packing it, measuring the size and discarding the buffer.
func (dns *Msg) Len() int { return compressedLen(dns, dns.Compress) } func (dns *Msg) Len() int {
// If this message can't be compressed, avoid filling the
// compressedLen returns the message length when in compressed wire format // compression map and creating garbage.
// when compress is true, otherwise the uncompressed length is returned. if dns.Compress && dns.isCompressible() {
func compressedLen(dns *Msg, compress bool) int { compression := make(map[string]struct{})
// We always return one more than needed. return msgLenWithCompressionMap(dns, compression)
l := 12 // Message header is always 12 bytes
if compress {
compression := map[string]int{}
for _, r := range dns.Question {
l += r.len()
compressionLenHelper(compression, r.Name)
} }
l += compressionLenSlice(compression, dns.Answer)
l += compressionLenSlice(compression, dns.Ns) return msgLenWithCompressionMap(dns, nil)
l += compressionLenSlice(compression, dns.Extra) }
} else {
func msgLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int {
l := 12 // Message header is always 12 bytes
for _, r := range dns.Question { for _, r := range dns.Question {
l += r.len() l += r.len(l, compression)
} }
for _, r := range dns.Answer { for _, r := range dns.Answer {
if r != nil { if r != nil {
l += r.len() l += r.len(l, compression)
} }
} }
for _, r := range dns.Ns { for _, r := range dns.Ns {
if r != nil { if r != nil {
l += r.len() l += r.len(l, compression)
} }
} }
for _, r := range dns.Extra { for _, r := range dns.Extra {
if r != nil { if r != nil {
l += r.len() l += r.len(l, compression)
}
} }
} }
return l return l
} }
func compressionLenSlice(c map[string]int, rs []RR) int { func domainNameLen(s string, off int, compression map[string]struct{}, compress bool) int {
var l int if s == "" || s == "." {
for _, r := range rs { return 1
if r == nil { }
escaped := strings.Contains(s, "\\")
if compression != nil && (compress || off < maxCompressionOffset) {
// compressionLenSearch will insert the entry into the compression
// map if it doesn't contain it.
if l, ok := compressionLenSearch(compression, s, off); ok && compress {
if escaped {
return escapedNameLen(s[:l]) + 2
}
return l + 2
}
}
if escaped {
return escapedNameLen(s) + 1
}
return len(s) + 1
}
func escapedNameLen(s string) int {
nameLen := len(s)
for i := 0; i < len(s); i++ {
if s[i] != '\\' {
continue continue
} }
l += r.len()
k, ok := compressionLenSearch(c, r.Header().Name)
if ok {
l += 1 - k
}
compressionLenHelper(c, r.Header().Name)
k, ok = compressionLenSearchType(c, r)
if ok {
l += 1 - k
}
compressionLenHelperType(c, r)
}
return l
}
// Put the parts of the name in the compression map. if i+3 < len(s) && isDigit(s[i+1]) && isDigit(s[i+2]) && isDigit(s[i+3]) {
func compressionLenHelper(c map[string]int, s string) { nameLen -= 3
pref := "" i += 3
lbs := Split(s) } else {
for j := len(lbs) - 1; j >= 0; j-- { nameLen--
pref = s[lbs[j]:] i++
if _, ok := c[pref]; !ok {
c[pref] = len(pref)
}
} }
} }
// Look for each part in the compression map and returns its length, return nameLen
// keep on searching so we get the longest match.
func compressionLenSearch(c map[string]int, s string) (int, bool) {
off := 0
end := false
if s == "" { // don't bork on bogus data
return 0, false
} }
for {
func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, bool) {
for off, end := 0, false; !end; off, end = NextLabel(s, off) {
if _, ok := c[s[off:]]; ok { if _, ok := c[s[off:]]; ok {
return len(s[off:]), true return off, true
} }
if end {
break if msgOff+off < maxCompressionOffset {
c[s[off:]] = struct{}{}
} }
off, end = NextLabel(s, off)
} }
return 0, false return 0, false
} }
@ -1009,7 +1071,7 @@ func compressionLenSearch(c map[string]int, s string) (int, bool) {
func Copy(r RR) RR { r1 := r.copy(); return r1 } func Copy(r RR) RR { r1 := r.copy(); return r1 }
// Len returns the length (in octets) of the uncompressed RR in wire format. // Len returns the length (in octets) of the uncompressed RR in wire format.
func Len(r RR) int { return r.len() } func Len(r RR) int { return r.len(0, nil) }
// Copy returns a new *Msg which is a deep-copy of dns. // Copy returns a new *Msg which is a deep-copy of dns.
func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) } func (dns *Msg) Copy() *Msg { return dns.CopyTo(new(Msg)) }
@ -1057,8 +1119,8 @@ func (dns *Msg) CopyTo(r1 *Msg) *Msg {
return r1 return r1
} }
func (q *Question) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { func (q *Question) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
off, err := PackDomainName(q.Name, msg, off, compression, compress) off, _, err := packDomainName(q.Name, msg, off, compression, compress)
if err != nil { if err != nil {
return off, err return off, err
} }
@ -1099,7 +1161,7 @@ func unpackQuestion(msg []byte, off int) (Question, int, error) {
return q, off, err return q, off, err
} }
func (dh *Header) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) { func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
off, err := packUint16(dh.Id, msg, off) off, err := packUint16(dh.Id, msg, off)
if err != nil { if err != nil {
return off, err return off, err
@ -1152,3 +1214,18 @@ func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
dh.Arcount, off, err = unpackUint16(msg, off) dh.Arcount, off, err = unpackUint16(msg, off)
return dh, off, err return dh, off, err
} }
// setHdr set the header in the dns using the binary data in dh.
func (dns *Msg) setHdr(dh Header) {
dns.Id = dh.Id
dns.Response = dh.Bits&_QR != 0
dns.Opcode = int(dh.Bits>>11) & 0xF
dns.Authoritative = dh.Bits&_AA != 0
dns.Truncated = dh.Bits&_TC != 0
dns.RecursionDesired = dh.Bits&_RD != 0
dns.RecursionAvailable = dh.Bits&_RA != 0
dns.Zero = dh.Bits&_Z != 0 // _Z covers the zero bit, which should be zero; not sure why we set it to the opposite.
dns.AuthenticatedData = dh.Bits&_AD != 0
dns.CheckingDisabled = dh.Bits&_CD != 0
dns.Rcode = int(dh.Bits & 0xF)
}

View file

@ -80,18 +80,17 @@ func main() {
o := scope.Lookup(name) o := scope.Lookup(name)
st, _ := getTypeStruct(o.Type(), scope) st, _ := getTypeStruct(o.Type(), scope)
fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name) fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (int, int, error) {\n", name)
fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress) fmt.Fprint(b, `headerEnd, off, err := rr.Hdr.pack(msg, off, compression, compress)
if err != nil { if err != nil {
return off, err return headerEnd, off, err
} }
headerEnd := off
`) `)
for i := 1; i < st.NumFields(); i++ { for i := 1; i < st.NumFields(); i++ {
o := func(s string) { o := func(s string) {
fmt.Fprintf(b, s, st.Field(i).Name()) fmt.Fprintf(b, s, st.Field(i).Name())
fmt.Fprint(b, `if err != nil { fmt.Fprint(b, `if err != nil {
return off, err return headerEnd, off, err
} }
`) `)
} }
@ -106,7 +105,7 @@ return off, err
case `dns:"nsec"`: case `dns:"nsec"`:
o("off, err = packDataNsec(rr.%s, msg, off)\n") o("off, err = packDataNsec(rr.%s, msg, off)\n")
case `dns:"domain-name"`: case `dns:"domain-name"`:
o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n") o("off, err = packDataDomainNames(rr.%s, msg, off, compression, false)\n")
default: default:
log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
} }
@ -116,9 +115,9 @@ return off, err
switch { switch {
case st.Tag(i) == `dns:"-"`: // ignored case st.Tag(i) == `dns:"-"`: // ignored
case st.Tag(i) == `dns:"cdomain-name"`: case st.Tag(i) == `dns:"cdomain-name"`:
o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n") o("off, _, err = packDomainName(rr.%s, msg, off, compression, compress)\n")
case st.Tag(i) == `dns:"domain-name"`: case st.Tag(i) == `dns:"domain-name"`:
o("off, err = PackDomainName(rr.%s, msg, off, compression, false)\n") o("off, _, err = packDomainName(rr.%s, msg, off, compression, false)\n")
case st.Tag(i) == `dns:"a"`: case st.Tag(i) == `dns:"a"`:
o("off, err = packDataA(rr.%s, msg, off)\n") o("off, err = packDataA(rr.%s, msg, off)\n")
case st.Tag(i) == `dns:"aaaa"`: case st.Tag(i) == `dns:"aaaa"`:
@ -145,7 +144,7 @@ return off, err
if rr.%s != "-" { if rr.%s != "-" {
off, err = packStringHex(rr.%s, msg, off) off, err = packStringHex(rr.%s, msg, off)
if err != nil { if err != nil {
return off, err return headerEnd, off, err
} }
} }
`, field, field) `, field, field)
@ -176,9 +175,7 @@ if rr.%s != "-" {
log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
} }
} }
// We have packed everything, only now we know the rdlength of this RR fmt.Fprintln(b, "return headerEnd, off, nil }\n")
fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)")
fmt.Fprintln(b, "return off, nil }\n")
} }
fmt.Fprint(b, "// unpack*() functions\n\n") fmt.Fprint(b, "// unpack*() functions\n\n")

View file

@ -6,7 +6,7 @@ import (
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"net" "net"
"strconv" "strings"
) )
// helper functions called from the generated zmsg.go // helper functions called from the generated zmsg.go
@ -101,32 +101,32 @@ func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte,
// pack packs an RR header, returning the offset to the end of the header. // pack packs an RR header, returning the offset to the end of the header.
// See PackDomainName for documentation about the compression. // See PackDomainName for documentation about the compression.
func (hdr RR_Header) pack(msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { func (hdr RR_Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, int, error) {
if off == len(msg) { if off == len(msg) {
return off, nil return off, off, nil
} }
off, err = PackDomainName(hdr.Name, msg, off, compression, compress) off, _, err := packDomainName(hdr.Name, msg, off, compression, compress)
if err != nil { if err != nil {
return len(msg), err return off, len(msg), err
} }
off, err = packUint16(hdr.Rrtype, msg, off) off, err = packUint16(hdr.Rrtype, msg, off)
if err != nil { if err != nil {
return len(msg), err return off, len(msg), err
} }
off, err = packUint16(hdr.Class, msg, off) off, err = packUint16(hdr.Class, msg, off)
if err != nil { if err != nil {
return len(msg), err return off, len(msg), err
} }
off, err = packUint32(hdr.Ttl, msg, off) off, err = packUint32(hdr.Ttl, msg, off)
if err != nil { if err != nil {
return len(msg), err return off, len(msg), err
} }
off, err = packUint16(hdr.Rdlength, msg, off) off, err = packUint16(0, msg, off) // The RDLENGTH field will be set later in packRR.
if err != nil { if err != nil {
return len(msg), err return off, len(msg), err
} }
return off, nil return off, off, nil
} }
// helper helper functions. // helper helper functions.
@ -141,20 +141,24 @@ func truncateMsgFromRdlength(msg []byte, off int, rdlength uint16) (truncmsg []b
return msg[:lenrd], nil return msg[:lenrd], nil
} }
var base32HexNoPadEncoding = base32.HexEncoding.WithPadding(base32.NoPadding)
func fromBase32(s []byte) (buf []byte, err error) { func fromBase32(s []byte) (buf []byte, err error) {
for i, b := range s { for i, b := range s {
if b >= 'a' && b <= 'z' { if b >= 'a' && b <= 'z' {
s[i] = b - 32 s[i] = b - 32
} }
} }
buflen := base32.HexEncoding.DecodedLen(len(s)) buflen := base32HexNoPadEncoding.DecodedLen(len(s))
buf = make([]byte, buflen) buf = make([]byte, buflen)
n, err := base32.HexEncoding.Decode(buf, s) n, err := base32HexNoPadEncoding.Decode(buf, s)
buf = buf[:n] buf = buf[:n]
return return
} }
func toBase32(b []byte) string { return base32.HexEncoding.EncodeToString(b) } func toBase32(b []byte) string {
return base32HexNoPadEncoding.EncodeToString(b)
}
func fromBase64(s []byte) (buf []byte, err error) { func fromBase64(s []byte) (buf []byte, err error) {
buflen := base64.StdEncoding.DecodedLen(len(s)) buflen := base64.StdEncoding.DecodedLen(len(s))
@ -219,8 +223,8 @@ func unpackUint48(msg []byte, off int) (i uint64, off1 int, err error) {
return 0, len(msg), &Error{err: "overflow unpacking uint64 as uint48"} return 0, len(msg), &Error{err: "overflow unpacking uint64 as uint48"}
} }
// Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes) // Used in TSIG where the last 48 bits are occupied, so for now, assume a uint48 (6 bytes)
i = (uint64(uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 | i = uint64(msg[off])<<40 | uint64(msg[off+1])<<32 | uint64(msg[off+2])<<24 | uint64(msg[off+3])<<16 |
uint64(msg[off+4])<<8 | uint64(msg[off+5]))) uint64(msg[off+4])<<8 | uint64(msg[off+5])
off += 6 off += 6
return i, off, nil return i, off, nil
} }
@ -263,29 +267,21 @@ func unpackString(msg []byte, off int) (string, int, error) {
if off+l+1 > len(msg) { if off+l+1 > len(msg) {
return "", off, &Error{err: "overflow unpacking txt"} return "", off, &Error{err: "overflow unpacking txt"}
} }
s := make([]byte, 0, l) var s strings.Builder
s.Grow(l)
for _, b := range msg[off+1 : off+1+l] { for _, b := range msg[off+1 : off+1+l] {
switch b { switch {
case '"', '\\': case b == '"' || b == '\\':
s = append(s, '\\', b) s.WriteByte('\\')
s.WriteByte(b)
case b < ' ' || b > '~': // unprintable
s.WriteString(escapeByte(b))
default: default:
if b < 32 || b > 127 { // unprintable s.WriteByte(b)
var buf [3]byte
bufs := strconv.AppendInt(buf[:0], int64(b), 10)
s = append(s, '\\')
for i := 0; i < 3-len(bufs); i++ {
s = append(s, '0')
}
for _, r := range bufs {
s = append(s, r)
}
} else {
s = append(s, b)
}
} }
} }
off += 1 + l off += 1 + l
return string(s), off, nil return s.String(), off, nil
} }
func packString(s string, msg []byte, off int) (int, error) { func packString(s string, msg []byte, off int) (int, error) {
@ -359,7 +355,7 @@ func packStringHex(s string, msg []byte, off int) (int, error) {
if err != nil { if err != nil {
return len(msg), err return len(msg), err
} }
if off+(len(h)) > len(msg) { if off+len(h) > len(msg) {
return len(msg), &Error{err: "overflow packing hex"} return len(msg), &Error{err: "overflow packing hex"}
} }
copy(msg[off:off+len(h)], h) copy(msg[off:off+len(h)], h)
@ -599,7 +595,7 @@ func packDataNsec(bitmap []uint16, msg []byte, off int) (int, error) {
// Setting the octets length // Setting the octets length
msg[off+1] = byte(length) msg[off+1] = byte(length)
// Setting the bit value for the type in the right octet // Setting the bit value for the type in the right octet
msg[off+1+int(length)] |= byte(1 << (7 - (t % 8))) msg[off+1+int(length)] |= byte(1 << (7 - t%8))
lastwindow, lastlength = window, length lastwindow, lastlength = window, length
} }
off += int(lastlength) + 2 off += int(lastlength) + 2
@ -625,10 +621,10 @@ func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
return servers, off, nil return servers, off, nil
} }
func packDataDomainNames(names []string, msg []byte, off int, compression map[string]int, compress bool) (int, error) { func packDataDomainNames(names []string, msg []byte, off int, compression compressionMap, compress bool) (int, error) {
var err error var err error
for j := 0; j < len(names); j++ { for j := 0; j < len(names); j++ {
off, err = PackDomainName(names[j], msg, off, compression, false && compress) off, _, err = packDomainName(names[j], msg, off, compression, compress)
if err != nil { if err != nil {
return len(msg), err return len(msg), err
} }

47
vendor/github.com/miekg/dns/nsecx.go generated vendored
View file

@ -2,49 +2,44 @@ package dns
import ( import (
"crypto/sha1" "crypto/sha1"
"hash" "encoding/hex"
"strings" "strings"
) )
type saltWireFmt struct {
Salt string `dns:"size-hex"`
}
// HashName hashes a string (label) according to RFC 5155. It returns the hashed string in uppercase. // HashName hashes a string (label) according to RFC 5155. It returns the hashed string in uppercase.
func HashName(label string, ha uint8, iter uint16, salt string) string { func HashName(label string, ha uint8, iter uint16, salt string) string {
saltwire := new(saltWireFmt) if ha != SHA1 {
saltwire.Salt = salt return ""
wire := make([]byte, DefaultMsgSize) }
n, err := packSaltWire(saltwire, wire)
wireSalt := make([]byte, hex.DecodedLen(len(salt)))
n, err := packStringHex(salt, wireSalt, 0)
if err != nil { if err != nil {
return "" return ""
} }
wire = wire[:n] wireSalt = wireSalt[:n]
name := make([]byte, 255) name := make([]byte, 255)
off, err := PackDomainName(strings.ToLower(label), name, 0, nil, false) off, err := PackDomainName(strings.ToLower(label), name, 0, nil, false)
if err != nil { if err != nil {
return "" return ""
} }
name = name[:off] name = name[:off]
var s hash.Hash
switch ha {
case SHA1:
s = sha1.New()
default:
return ""
}
s := sha1.New()
// k = 0 // k = 0
s.Write(name) s.Write(name)
s.Write(wire) s.Write(wireSalt)
nsec3 := s.Sum(nil) nsec3 := s.Sum(nil)
// k > 0 // k > 0
for k := uint16(0); k < iter; k++ { for k := uint16(0); k < iter; k++ {
s.Reset() s.Reset()
s.Write(nsec3) s.Write(nsec3)
s.Write(wire) s.Write(wireSalt)
nsec3 = s.Sum(nsec3[:0]) nsec3 = s.Sum(nsec3[:0])
} }
return toBase32(nsec3) return toBase32(nsec3)
} }
@ -63,8 +58,10 @@ func (rr *NSEC3) Cover(name string) bool {
} }
nextHash := rr.NextDomain nextHash := rr.NextDomain
if ownerHash == nextHash { // empty interval
return false // if empty interval found, try cover wildcard hashes so nameHash shouldn't match with ownerHash
if ownerHash == nextHash && nameHash != ownerHash { // empty interval
return true
} }
if ownerHash > nextHash { // end of zone if ownerHash > nextHash { // end of zone
if nameHash > ownerHash { // covered since there is nothing after ownerHash if nameHash > ownerHash { // covered since there is nothing after ownerHash
@ -96,11 +93,3 @@ func (rr *NSEC3) Match(name string) bool {
} }
return false return false
} }
func packSaltWire(sw *saltWireFmt, msg []byte) (int, error) {
off, err := packStringHex(sw.Salt, msg, 0)
if err != nil {
return off, err
}
return off, nil
}

View file

@ -52,12 +52,16 @@ func (r *PrivateRR) Header() *RR_Header { return &r.Hdr }
func (r *PrivateRR) String() string { return r.Hdr.String() + r.Data.String() } func (r *PrivateRR) String() string { return r.Hdr.String() + r.Data.String() }
// Private len and copy parts to satisfy RR interface. // Private len and copy parts to satisfy RR interface.
func (r *PrivateRR) len() int { return r.Hdr.len() + r.Data.Len() } func (r *PrivateRR) len(off int, compression map[string]struct{}) int {
l := r.Hdr.len(off, compression)
l += r.Data.Len()
return l
}
func (r *PrivateRR) copy() RR { func (r *PrivateRR) copy() RR {
// make new RR like this: // make new RR like this:
rr := mkPrivateRR(r.Hdr.Rrtype) rr := mkPrivateRR(r.Hdr.Rrtype)
newh := r.Hdr.copyHeader() rr.Hdr = r.Hdr
rr.Hdr = *newh
err := r.Data.Copy(rr.Data) err := r.Data.Copy(rr.Data)
if err != nil { if err != nil {
@ -65,19 +69,18 @@ func (r *PrivateRR) copy() RR {
} }
return rr return rr
} }
func (r *PrivateRR) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {
off, err := r.Hdr.pack(msg, off, compression, compress) func (r *PrivateRR) pack(msg []byte, off int, compression compressionMap, compress bool) (int, int, error) {
headerEnd, off, err := r.Hdr.pack(msg, off, compression, compress)
if err != nil { if err != nil {
return off, err return off, off, err
} }
headerEnd := off
n, err := r.Data.Pack(msg[off:]) n, err := r.Data.Pack(msg[off:])
if err != nil { if err != nil {
return len(msg), err return headerEnd, len(msg), err
} }
off += n off += n
r.Header().Rdlength = uint16(off - headerEnd) return headerEnd, off, nil
return off, nil
} }
// PrivateHandle registers a private resource record type. It requires // PrivateHandle registers a private resource record type. It requires
@ -106,7 +109,7 @@ func PrivateHandle(rtypestr string, rtype uint16, generator func() PrivateRdata)
return rr, off, err return rr, off, err
} }
setPrivateRR := func(h RR_Header, c chan lex, o, f string) (RR, *ParseError, string) { setPrivateRR := func(h RR_Header, c *zlexer, o, f string) (RR, *ParseError, string) {
rr := mkPrivateRR(h.Rrtype) rr := mkPrivateRR(h.Rrtype)
rr.Hdr = h rr.Hdr = h
@ -116,7 +119,7 @@ func PrivateHandle(rtypestr string, rtype uint16, generator func() PrivateRdata)
for { for {
// TODO(miek): we could also be returning _QUOTE, this might or might not // TODO(miek): we could also be returning _QUOTE, this might or might not
// be an issue (basically parsing TXT becomes hard) // be an issue (basically parsing TXT becomes hard)
switch l = <-c; l.value { switch l, _ = c.Next(); l.value {
case zNewline, zEOF: case zNewline, zEOF:
break Fetch break Fetch
case zString: case zString:
@ -135,7 +138,7 @@ func PrivateHandle(rtypestr string, rtype uint16, generator func() PrivateRdata)
typeToparserFunc[rtype] = parserFunc{setPrivateRR, true} typeToparserFunc[rtype] = parserFunc{setPrivateRR, true}
} }
// PrivateHandleRemove removes defenitions required to support private RR type. // PrivateHandleRemove removes definitions required to support private RR type.
func PrivateHandleRemove(rtype uint16) { func PrivateHandleRemove(rtype uint16) {
rtypestr, ok := TypeToString[rtype] rtypestr, ok := TypeToString[rtype]
if ok { if ok {
@ -145,5 +148,4 @@ func PrivateHandleRemove(rtype uint16) {
delete(StringToType, rtypestr) delete(StringToType, rtypestr)
delete(typeToUnpack, rtype) delete(typeToUnpack, rtype)
} }
return
} }

View file

@ -1,49 +0,0 @@
package dns
import "encoding/binary"
// rawSetRdlength sets the rdlength in the header of
// the RR. The offset 'off' must be positioned at the
// start of the header of the RR, 'end' must be the
// end of the RR.
func rawSetRdlength(msg []byte, off, end int) bool {
l := len(msg)
Loop:
for {
if off+1 > l {
return false
}
c := int(msg[off])
off++
switch c & 0xC0 {
case 0x00:
if c == 0x00 {
// End of the domainname
break Loop
}
if off+c > l {
return false
}
off += c
case 0xC0:
// pointer, next byte included, ends domainname
off++
break Loop
}
}
// The domainname has been seen, we at the start of the fixed part in the header.
// Type is 2 bytes, class is 2 bytes, ttl 4 and then 2 bytes for the length.
off += 2 + 2 + 4
if off+2 > l {
return false
}
//off+1 is the end of the header, 'end' is the end of the rr
//so 'end' - 'off+2' is the length of the rdata
rdatalen := end - (off + 2)
if rdatalen > 0xFFFF {
return false
}
binary.BigEndian.PutUint16(msg[off:], uint16(rdatalen))
return true
}

View file

@ -12,6 +12,11 @@ var StringToOpcode = reverseInt(OpcodeToString)
// StringToRcode is a map of rcodes to strings. // StringToRcode is a map of rcodes to strings.
var StringToRcode = reverseInt(RcodeToString) var StringToRcode = reverseInt(RcodeToString)
func init() {
// Preserve previous NOTIMP typo, see github.com/miekg/dns/issues/733.
StringToRcode["NOTIMPL"] = RcodeNotImplemented
}
// Reverse a map // Reverse a map
func reverseInt8(m map[uint8]string) map[string]uint8 { func reverseInt8(m map[uint8]string) map[string]uint8 {
n := make(map[string]uint8, len(m)) n := make(map[string]uint8, len(m))

View file

@ -5,6 +5,7 @@ package dns
// rrs. // rrs.
// m is used to store the RRs temporary. If it is nil a new map will be allocated. // m is used to store the RRs temporary. If it is nil a new map will be allocated.
func Dedup(rrs []RR, m map[string]RR) []RR { func Dedup(rrs []RR, m map[string]RR) []RR {
if m == nil { if m == nil {
m = make(map[string]RR) m = make(map[string]RR)
} }

990
vendor/github.com/miekg/dns/scan.go generated vendored

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1,56 +0,0 @@
package dns
// Implement a simple scanner, return a byte stream from an io reader.
import (
"bufio"
"context"
"io"
"text/scanner"
)
type scan struct {
src *bufio.Reader
position scanner.Position
eof bool // Have we just seen a eof
ctx context.Context
}
func scanInit(r io.Reader) (*scan, context.CancelFunc) {
s := new(scan)
s.src = bufio.NewReader(r)
s.position.Line = 1
ctx, cancel := context.WithCancel(context.Background())
s.ctx = ctx
return s, cancel
}
// tokenText returns the next byte from the input
func (s *scan) tokenText() (byte, error) {
c, err := s.src.ReadByte()
if err != nil {
return c, err
}
select {
case <-s.ctx.Done():
return c, context.Canceled
default:
break
}
// delay the newline handling until the next token is delivered,
// fixes off-by-one errors when reporting a parse error.
if s.eof == true {
s.position.Line++
s.position.Column = 0
s.eof = false
}
if c == '\n' {
s.eof = true
return c, nil
}
s.position.Column++
return c, nil
}

147
vendor/github.com/miekg/dns/serve_mux.go generated vendored Normal file
View file

@ -0,0 +1,147 @@
package dns
import (
"strings"
"sync"
)
// ServeMux is an DNS request multiplexer. It matches the zone name of
// each incoming request against a list of registered patterns add calls
// the handler for the pattern that most closely matches the zone name.
//
// ServeMux is DNSSEC aware, meaning that queries for the DS record are
// redirected to the parent zone (if that is also registered), otherwise
// the child gets the query.
//
// ServeMux is also safe for concurrent access from multiple goroutines.
//
// The zero ServeMux is empty and ready for use.
type ServeMux struct {
z map[string]Handler
m sync.RWMutex
}
// NewServeMux allocates and returns a new ServeMux.
func NewServeMux() *ServeMux {
return new(ServeMux)
}
// DefaultServeMux is the default ServeMux used by Serve.
var DefaultServeMux = NewServeMux()
func (mux *ServeMux) match(q string, t uint16) Handler {
mux.m.RLock()
defer mux.m.RUnlock()
if mux.z == nil {
return nil
}
var handler Handler
// TODO(tmthrgd): Once https://go-review.googlesource.com/c/go/+/137575
// lands in a go release, replace the following with strings.ToLower.
var sb strings.Builder
for i := 0; i < len(q); i++ {
c := q[i]
if !(c >= 'A' && c <= 'Z') {
continue
}
sb.Grow(len(q))
sb.WriteString(q[:i])
for ; i < len(q); i++ {
c := q[i]
if c >= 'A' && c <= 'Z' {
c += 'a' - 'A'
}
sb.WriteByte(c)
}
q = sb.String()
break
}
for off, end := 0, false; !end; off, end = NextLabel(q, off) {
if h, ok := mux.z[q[off:]]; ok {
if t != TypeDS {
return h
}
// Continue for DS to see if we have a parent too, if so delegate to the parent
handler = h
}
}
// Wildcard match, if we have found nothing try the root zone as a last resort.
if h, ok := mux.z["."]; ok {
return h
}
return handler
}
// Handle adds a handler to the ServeMux for pattern.
func (mux *ServeMux) Handle(pattern string, handler Handler) {
if pattern == "" {
panic("dns: invalid pattern " + pattern)
}
mux.m.Lock()
if mux.z == nil {
mux.z = make(map[string]Handler)
}
mux.z[Fqdn(pattern)] = handler
mux.m.Unlock()
}
// HandleFunc adds a handler function to the ServeMux for pattern.
func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
mux.Handle(pattern, HandlerFunc(handler))
}
// HandleRemove deregisters the handler specific for pattern from the ServeMux.
func (mux *ServeMux) HandleRemove(pattern string) {
if pattern == "" {
panic("dns: invalid pattern " + pattern)
}
mux.m.Lock()
delete(mux.z, Fqdn(pattern))
mux.m.Unlock()
}
// ServeDNS dispatches the request to the handler whose pattern most
// closely matches the request message.
//
// ServeDNS is DNSSEC aware, meaning that queries for the DS record
// are redirected to the parent zone (if that is also registered),
// otherwise the child gets the query.
//
// If no handler is found, or there is no question, a standard SERVFAIL
// message is returned
func (mux *ServeMux) ServeDNS(w ResponseWriter, req *Msg) {
var h Handler
if len(req.Question) >= 1 { // allow more than one question
h = mux.match(req.Question[0].Name, req.Question[0].Qtype)
}
if h != nil {
h.ServeDNS(w, req)
} else {
HandleFailed(w, req)
}
}
// Handle registers the handler with the given pattern
// in the DefaultServeMux. The documentation for
// ServeMux explains how patterns are matched.
func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
// HandleRemove deregisters the handle with the given pattern
// in the DefaultServeMux.
func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }
// HandleFunc registers the handler function with the given pattern
// in the DefaultServeMux.
func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
DefaultServeMux.HandleFunc(pattern, handler)
}

695
vendor/github.com/miekg/dns/server.go generated vendored
View file

@ -4,22 +4,54 @@ package dns
import ( import (
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/binary" "encoding/binary"
"errors"
"io" "io"
"net" "net"
"strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
// Maximum number of TCP queries before we close the socket. // Default maximum number of TCP queries before we close the socket.
const maxTCPQueries = 128 const maxTCPQueries = 128
// The maximum number of idle workers.
//
// This controls the maximum number of workers that are allowed to stay
// idle waiting for incoming requests before being torn down.
//
// If this limit is reached, the server will just keep spawning new
// workers (goroutines) for each incoming request. In this case, each
// worker will only be used for a single request.
const maxIdleWorkersCount = 10000
// The maximum length of time a worker may idle for before being destroyed.
const idleWorkerTimeout = 10 * time.Second
// aLongTimeAgo is a non-zero time, far in the past, used for
// immediate cancelation of network operations.
var aLongTimeAgo = time.Unix(1, 0)
// Handler is implemented by any value that implements ServeDNS. // Handler is implemented by any value that implements ServeDNS.
type Handler interface { type Handler interface {
ServeDNS(w ResponseWriter, r *Msg) ServeDNS(w ResponseWriter, r *Msg)
} }
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as DNS handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
// Handler object that calls f.
type HandlerFunc func(ResponseWriter, *Msg)
// ServeDNS calls f(w, r).
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
f(w, r)
}
// A ResponseWriter interface is used by an DNS handler to // A ResponseWriter interface is used by an DNS handler to
// construct an DNS response. // construct an DNS response.
type ResponseWriter interface { type ResponseWriter interface {
@ -42,46 +74,25 @@ type ResponseWriter interface {
Hijack() Hijack()
} }
// A ConnectionStater interface is used by a DNS Handler to access TLS connection state
// when available.
type ConnectionStater interface {
ConnectionState() *tls.ConnectionState
}
type response struct { type response struct {
msg []byte
closed bool // connection has been closed
hijacked bool // connection has been hijacked by handler hijacked bool // connection has been hijacked by handler
tsigStatus error
tsigTimersOnly bool tsigTimersOnly bool
tsigStatus error
tsigRequestMAC string tsigRequestMAC string
tsigSecret map[string]string // the tsig secrets tsigSecret map[string]string // the tsig secrets
udp *net.UDPConn // i/o connection if UDP was used udp *net.UDPConn // i/o connection if UDP was used
tcp net.Conn // i/o connection if TCP was used tcp net.Conn // i/o connection if TCP was used
udpSession *SessionUDP // oob data to get egress interface right udpSession *SessionUDP // oob data to get egress interface right
remoteAddr net.Addr // address of the client
writer Writer // writer to output the raw DNS bits writer Writer // writer to output the raw DNS bits
} wg *sync.WaitGroup // for gracefull shutdown
// ServeMux is an DNS request multiplexer. It matches the
// zone name of each incoming request against a list of
// registered patterns add calls the handler for the pattern
// that most closely matches the zone name. ServeMux is DNSSEC aware, meaning
// that queries for the DS record are redirected to the parent zone (if that
// is also registered), otherwise the child gets the query.
// ServeMux is also safe for concurrent access from multiple goroutines.
type ServeMux struct {
z map[string]Handler
m *sync.RWMutex
}
// NewServeMux allocates and returns a new ServeMux.
func NewServeMux() *ServeMux { return &ServeMux{z: make(map[string]Handler), m: new(sync.RWMutex)} }
// DefaultServeMux is the default ServeMux used by Serve.
var DefaultServeMux = NewServeMux()
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as DNS handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
// Handler object that calls f.
type HandlerFunc func(ResponseWriter, *Msg)
// ServeDNS calls f(w, r).
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
f(w, r)
} }
// HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets. // HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
@ -92,8 +103,6 @@ func HandleFailed(w ResponseWriter, r *Msg) {
w.WriteMsg(m) w.WriteMsg(m)
} }
func failedHandler() Handler { return HandlerFunc(HandleFailed) }
// ListenAndServe Starts a server on address and network specified Invoke handler // ListenAndServe Starts a server on address and network specified Invoke handler
// for incoming queries. // for incoming queries.
func ListenAndServe(addr string, network string, handler Handler) error { func ListenAndServe(addr string, network string, handler Handler) error {
@ -132,99 +141,6 @@ func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
return server.ActivateAndServe() return server.ActivateAndServe()
} }
func (mux *ServeMux) match(q string, t uint16) Handler {
mux.m.RLock()
defer mux.m.RUnlock()
var handler Handler
b := make([]byte, len(q)) // worst case, one label of length q
off := 0
end := false
for {
l := len(q[off:])
for i := 0; i < l; i++ {
b[i] = q[off+i]
if b[i] >= 'A' && b[i] <= 'Z' {
b[i] |= ('a' - 'A')
}
}
if h, ok := mux.z[string(b[:l])]; ok { // causes garbage, might want to change the map key
if t != TypeDS {
return h
}
// Continue for DS to see if we have a parent too, if so delegeate to the parent
handler = h
}
off, end = NextLabel(q, off)
if end {
break
}
}
// Wildcard match, if we have found nothing try the root zone as a last resort.
if h, ok := mux.z["."]; ok {
return h
}
return handler
}
// Handle adds a handler to the ServeMux for pattern.
func (mux *ServeMux) Handle(pattern string, handler Handler) {
if pattern == "" {
panic("dns: invalid pattern " + pattern)
}
mux.m.Lock()
mux.z[Fqdn(pattern)] = handler
mux.m.Unlock()
}
// HandleFunc adds a handler function to the ServeMux for pattern.
func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
mux.Handle(pattern, HandlerFunc(handler))
}
// HandleRemove deregistrars the handler specific for pattern from the ServeMux.
func (mux *ServeMux) HandleRemove(pattern string) {
if pattern == "" {
panic("dns: invalid pattern " + pattern)
}
mux.m.Lock()
delete(mux.z, Fqdn(pattern))
mux.m.Unlock()
}
// ServeDNS dispatches the request to the handler whose
// pattern most closely matches the request message. If DefaultServeMux
// is used the correct thing for DS queries is done: a possible parent
// is sought.
// If no handler is found a standard SERVFAIL message is returned
// If the request message does not have exactly one question in the
// question section a SERVFAIL is returned, unlesss Unsafe is true.
func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
var h Handler
if len(request.Question) < 1 { // allow more than one question
h = failedHandler()
} else {
if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil {
h = failedHandler()
}
}
h.ServeDNS(w, request)
}
// Handle registers the handler with the given pattern
// in the DefaultServeMux. The documentation for
// ServeMux explains how patterns are matched.
func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
// HandleRemove deregisters the handle with the given pattern
// in the DefaultServeMux.
func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }
// HandleFunc registers the handler function with the given pattern
// in the DefaultServeMux.
func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
DefaultServeMux.HandleFunc(pattern, handler)
}
// Writer writes raw DNS messages; each call to Write should send an entire message. // Writer writes raw DNS messages; each call to Write should send an entire message.
type Writer interface { type Writer interface {
io.Writer io.Writer
@ -287,87 +203,170 @@ type Server struct {
IdleTimeout func() time.Duration IdleTimeout func() time.Duration
// Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2). // Secret(s) for Tsig map[<zonename>]<base64 secret>. The zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2).
TsigSecret map[string]string TsigSecret map[string]string
// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
// the handler. It will specifically not check if the query has the QR bit not set.
Unsafe bool
// If NotifyStartedFunc is set it is called once the server has started listening. // If NotifyStartedFunc is set it is called once the server has started listening.
NotifyStartedFunc func() NotifyStartedFunc func()
// DecorateReader is optional, allows customization of the process that reads raw DNS messages. // DecorateReader is optional, allows customization of the process that reads raw DNS messages.
DecorateReader DecorateReader DecorateReader DecorateReader
// DecorateWriter is optional, allows customization of the process that writes raw DNS messages. // DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
DecorateWriter DecorateWriter DecorateWriter DecorateWriter
// Maximum number of TCP queries before we close the socket. Default is maxTCPQueries (unlimited if -1).
MaxTCPQueries int
// Whether to set the SO_REUSEPORT socket option, allowing multiple listeners to be bound to a single address.
// It is only supported on go1.11+ and when using ListenAndServe.
ReusePort bool
// AcceptMsgFunc will check the incoming message and will reject it early in the process.
// By default DefaultMsgAcceptFunc will be used.
MsgAcceptFunc MsgAcceptFunc
// UDP packet or TCP connection queue
queue chan *response
// Workers count
workersCount int32
// Shutdown handling // Shutdown handling
lock sync.RWMutex lock sync.RWMutex
started bool started bool
shutdown chan struct{}
conns map[net.Conn]struct{}
// A pool for UDP message buffers.
udpPool sync.Pool
}
func (srv *Server) isStarted() bool {
srv.lock.RLock()
started := srv.started
srv.lock.RUnlock()
return started
}
func (srv *Server) worker(w *response) {
srv.serve(w)
for {
count := atomic.LoadInt32(&srv.workersCount)
if count > maxIdleWorkersCount {
return
}
if atomic.CompareAndSwapInt32(&srv.workersCount, count, count+1) {
break
}
}
defer atomic.AddInt32(&srv.workersCount, -1)
inUse := false
timeout := time.NewTimer(idleWorkerTimeout)
defer timeout.Stop()
LOOP:
for {
select {
case w, ok := <-srv.queue:
if !ok {
break LOOP
}
inUse = true
srv.serve(w)
case <-timeout.C:
if !inUse {
break LOOP
}
inUse = false
timeout.Reset(idleWorkerTimeout)
}
}
}
func (srv *Server) spawnWorker(w *response) {
select {
case srv.queue <- w:
default:
go srv.worker(w)
}
}
func makeUDPBuffer(size int) func() interface{} {
return func() interface{} {
return make([]byte, size)
}
}
func (srv *Server) init() {
srv.queue = make(chan *response)
srv.shutdown = make(chan struct{})
srv.conns = make(map[net.Conn]struct{})
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
if srv.MsgAcceptFunc == nil {
srv.MsgAcceptFunc = defaultMsgAcceptFunc
}
srv.udpPool.New = makeUDPBuffer(srv.UDPSize)
}
func unlockOnce(l sync.Locker) func() {
var once sync.Once
return func() { once.Do(l.Unlock) }
} }
// ListenAndServe starts a nameserver on the configured address in *Server. // ListenAndServe starts a nameserver on the configured address in *Server.
func (srv *Server) ListenAndServe() error { func (srv *Server) ListenAndServe() error {
unlock := unlockOnce(&srv.lock)
srv.lock.Lock() srv.lock.Lock()
defer srv.lock.Unlock() defer unlock()
if srv.started { if srv.started {
return &Error{err: "server already started"} return &Error{err: "server already started"}
} }
addr := srv.Addr addr := srv.Addr
if addr == "" { if addr == "" {
addr = ":domain" addr = ":domain"
} }
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize srv.init()
} defer close(srv.queue)
switch srv.Net { switch srv.Net {
case "tcp", "tcp4", "tcp6": case "tcp", "tcp4", "tcp6":
a, err := net.ResolveTCPAddr(srv.Net, addr) l, err := listenTCP(srv.Net, addr, srv.ReusePort)
if err != nil {
return err
}
l, err := net.ListenTCP(srv.Net, a)
if err != nil { if err != nil {
return err return err
} }
srv.Listener = l srv.Listener = l
srv.started = true srv.started = true
srv.lock.Unlock() unlock()
err = srv.serveTCP(l) return srv.serveTCP(l)
srv.lock.Lock() // to satisfy the defer at the top
return err
case "tcp-tls", "tcp4-tls", "tcp6-tls": case "tcp-tls", "tcp4-tls", "tcp6-tls":
network := "tcp" if srv.TLSConfig == nil || (len(srv.TLSConfig.Certificates) == 0 && srv.TLSConfig.GetCertificate == nil) {
if srv.Net == "tcp4-tls" { return errors.New("dns: neither Certificates nor GetCertificate set in Config")
network = "tcp4"
} else if srv.Net == "tcp6-tls" {
network = "tcp6"
} }
network := strings.TrimSuffix(srv.Net, "-tls")
l, err := tls.Listen(network, addr, srv.TLSConfig) l, err := listenTCP(network, addr, srv.ReusePort)
if err != nil { if err != nil {
return err return err
} }
l = tls.NewListener(l, srv.TLSConfig)
srv.Listener = l srv.Listener = l
srv.started = true srv.started = true
srv.lock.Unlock() unlock()
err = srv.serveTCP(l) return srv.serveTCP(l)
srv.lock.Lock() // to satisfy the defer at the top
return err
case "udp", "udp4", "udp6": case "udp", "udp4", "udp6":
a, err := net.ResolveUDPAddr(srv.Net, addr) l, err := listenUDP(srv.Net, addr, srv.ReusePort)
if err != nil { if err != nil {
return err return err
} }
l, err := net.ListenUDP(srv.Net, a) u := l.(*net.UDPConn)
if err != nil { if e := setUDPSocketOptions(u); e != nil {
return err
}
if e := setUDPSocketOptions(l); e != nil {
return e return e
} }
srv.PacketConn = l srv.PacketConn = l
srv.started = true srv.started = true
srv.lock.Unlock() unlock()
err = srv.serveUDP(l) return srv.serveUDP(u)
srv.lock.Lock() // to satisfy the defer at the top
return err
} }
return &Error{err: "bad network"} return &Error{err: "bad network"}
} }
@ -375,17 +374,20 @@ func (srv *Server) ListenAndServe() error {
// ActivateAndServe starts a nameserver with the PacketConn or Listener // ActivateAndServe starts a nameserver with the PacketConn or Listener
// configured in *Server. Its main use is to start a server from systemd. // configured in *Server. Its main use is to start a server from systemd.
func (srv *Server) ActivateAndServe() error { func (srv *Server) ActivateAndServe() error {
unlock := unlockOnce(&srv.lock)
srv.lock.Lock() srv.lock.Lock()
defer srv.lock.Unlock() defer unlock()
if srv.started { if srv.started {
return &Error{err: "server already started"} return &Error{err: "server already started"}
} }
srv.init()
defer close(srv.queue)
pConn := srv.PacketConn pConn := srv.PacketConn
l := srv.Listener l := srv.Listener
if pConn != nil { if pConn != nil {
if srv.UDPSize == 0 {
srv.UDPSize = MinMsgSize
}
// Check PacketConn interface's type is valid and value // Check PacketConn interface's type is valid and value
// is not nil // is not nil
if t, ok := pConn.(*net.UDPConn); ok && t != nil { if t, ok := pConn.(*net.UDPConn); ok && t != nil {
@ -393,18 +395,14 @@ func (srv *Server) ActivateAndServe() error {
return e return e
} }
srv.started = true srv.started = true
srv.lock.Unlock() unlock()
e := srv.serveUDP(t) return srv.serveUDP(t)
srv.lock.Lock() // to satisfy the defer at the top
return e
} }
} }
if l != nil { if l != nil {
srv.started = true srv.started = true
srv.lock.Unlock() unlock()
e := srv.serveTCP(l) return srv.serveTCP(l)
srv.lock.Lock() // to satisfy the defer at the top
return e
} }
return &Error{err: "bad listeners"} return &Error{err: "bad listeners"}
} }
@ -412,23 +410,57 @@ func (srv *Server) ActivateAndServe() error {
// Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and // Shutdown shuts down a server. After a call to Shutdown, ListenAndServe and
// ActivateAndServe will return. // ActivateAndServe will return.
func (srv *Server) Shutdown() error { func (srv *Server) Shutdown() error {
return srv.ShutdownContext(context.Background())
}
// ShutdownContext shuts down a server. After a call to ShutdownContext,
// ListenAndServe and ActivateAndServe will return.
//
// A context.Context may be passed to limit how long to wait for connections
// to terminate.
func (srv *Server) ShutdownContext(ctx context.Context) error {
srv.lock.Lock() srv.lock.Lock()
if !srv.started { if !srv.started {
srv.lock.Unlock() srv.lock.Unlock()
return &Error{err: "server not started"} return &Error{err: "server not started"}
} }
srv.started = false srv.started = false
if srv.PacketConn != nil {
srv.PacketConn.SetReadDeadline(aLongTimeAgo) // Unblock reads
}
if srv.Listener != nil {
srv.Listener.Close()
}
for rw := range srv.conns {
rw.SetReadDeadline(aLongTimeAgo) // Unblock reads
}
srv.lock.Unlock() srv.lock.Unlock()
if testShutdownNotify != nil {
testShutdownNotify.Broadcast()
}
var ctxErr error
select {
case <-srv.shutdown:
case <-ctx.Done():
ctxErr = ctx.Err()
}
if srv.PacketConn != nil { if srv.PacketConn != nil {
srv.PacketConn.Close() srv.PacketConn.Close()
} }
if srv.Listener != nil {
srv.Listener.Close() return ctxErr
}
return nil
} }
var testShutdownNotify *sync.Cond
// getReadTimeout is a helper func to use system timeout if server did not intend to change it. // getReadTimeout is a helper func to use system timeout if server did not intend to change it.
func (srv *Server) getReadTimeout() time.Duration { func (srv *Server) getReadTimeout() time.Duration {
rtimeout := dnsTimeout rtimeout := dnsTimeout
@ -439,7 +471,6 @@ func (srv *Server) getReadTimeout() time.Duration {
} }
// serveTCP starts a TCP listener for the server. // serveTCP starts a TCP listener for the server.
// Each request is handled in a separate goroutine.
func (srv *Server) serveTCP(l net.Listener) error { func (srv *Server) serveTCP(l net.Listener) error {
defer l.Close() defer l.Close()
@ -447,44 +478,39 @@ func (srv *Server) serveTCP(l net.Listener) error {
srv.NotifyStartedFunc() srv.NotifyStartedFunc()
} }
reader := Reader(&defaultReader{srv}) var wg sync.WaitGroup
if srv.DecorateReader != nil { defer func() {
reader = srv.DecorateReader(reader) wg.Wait()
} close(srv.shutdown)
}()
handler := srv.Handler for srv.isStarted() {
if handler == nil {
handler = DefaultServeMux
}
rtimeout := srv.getReadTimeout()
// deadline is not used here
for {
rw, err := l.Accept() rw, err := l.Accept()
srv.lock.RLock() if err != nil {
if !srv.started { if !srv.isStarted() {
srv.lock.RUnlock()
return nil return nil
} }
srv.lock.RUnlock()
if err != nil {
if neterr, ok := err.(net.Error); ok && neterr.Temporary() { if neterr, ok := err.(net.Error); ok && neterr.Temporary() {
continue continue
} }
return err return err
} }
go func() { srv.lock.Lock()
m, err := reader.ReadTCP(rw, rtimeout) // Track the connection to allow unblocking reads on shutdown.
if err != nil { srv.conns[rw] = struct{}{}
rw.Close() srv.lock.Unlock()
return wg.Add(1)
} srv.spawnWorker(&response{
srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw) tsigSecret: srv.TsigSecret,
}() tcp: rw,
wg: &wg,
})
} }
return nil
} }
// serveUDP starts a UDP listener for the server. // serveUDP starts a UDP listener for the server.
// Each request is handled in a separate goroutine.
func (srv *Server) serveUDP(l *net.UDPConn) error { func (srv *Server) serveUDP(l *net.UDPConn) error {
defer l.Close() defer l.Close()
@ -497,107 +523,182 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
reader = srv.DecorateReader(reader) reader = srv.DecorateReader(reader)
} }
handler := srv.Handler var wg sync.WaitGroup
if handler == nil { defer func() {
handler = DefaultServeMux wg.Wait()
} close(srv.shutdown)
}()
rtimeout := srv.getReadTimeout() rtimeout := srv.getReadTimeout()
// deadline is not used here // deadline is not used here
for { for srv.isStarted() {
m, s, err := reader.ReadUDP(l, rtimeout) m, s, err := reader.ReadUDP(l, rtimeout)
srv.lock.RLock() if err != nil {
if !srv.started { if !srv.isStarted() {
srv.lock.RUnlock()
return nil return nil
} }
srv.lock.RUnlock()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Temporary() { if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
continue continue
} }
return err return err
} }
if len(m) < headerSize { if len(m) < headerSize {
if cap(m) == srv.UDPSize {
srv.udpPool.Put(m[:srv.UDPSize])
}
continue continue
} }
go srv.serve(s.RemoteAddr(), handler, m, l, s, nil) wg.Add(1)
} srv.spawnWorker(&response{
msg: m,
tsigSecret: srv.TsigSecret,
udp: l,
udpSession: s,
wg: &wg,
})
} }
// Serve a new connection. return nil
func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) { }
w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
func (srv *Server) serve(w *response) {
if srv.DecorateWriter != nil { if srv.DecorateWriter != nil {
w.writer = srv.DecorateWriter(w) w.writer = srv.DecorateWriter(w)
} else { } else {
w.writer = w w.writer = w
} }
q := 0 // counter for the amount of TCP queries we get if w.udp != nil {
// serve UDP
srv.serveDNS(w)
w.wg.Done()
return
}
defer func() {
if !w.hijacked {
w.Close()
}
srv.lock.Lock()
delete(srv.conns, w.tcp)
srv.lock.Unlock()
w.wg.Done()
}()
reader := Reader(&defaultReader{srv}) reader := Reader(&defaultReader{srv})
if srv.DecorateReader != nil { if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader) reader = srv.DecorateReader(reader)
} }
Redo:
req := new(Msg) idleTimeout := tcpIdleTimeout
err := req.Unpack(m) if srv.IdleTimeout != nil {
if err != nil { // Send a FormatError back idleTimeout = srv.IdleTimeout()
x := new(Msg)
x.SetRcodeFormatError(req)
w.WriteMsg(x)
goto Exit
} }
if !srv.Unsafe && req.Response {
goto Exit timeout := srv.getReadTimeout()
limit := srv.MaxTCPQueries
if limit == 0 {
limit = maxTCPQueries
}
for q := 0; (q < limit || limit == -1) && srv.isStarted(); q++ {
var err error
w.msg, err = reader.ReadTCP(w.tcp, timeout)
if err != nil {
// TODO(tmthrgd): handle error
break
}
srv.serveDNS(w)
if w.tcp == nil {
break // Close() was called
}
if w.hijacked {
break // client will call Close() themselves
}
// The first read uses the read timeout, the rest use the
// idle timeout.
timeout = idleTimeout
}
}
func (srv *Server) disposeBuffer(w *response) {
if w.udp != nil && cap(w.msg) == srv.UDPSize {
srv.udpPool.Put(w.msg[:srv.UDPSize])
}
w.msg = nil
}
func (srv *Server) serveDNS(w *response) {
dh, off, err := unpackMsgHdr(w.msg, 0)
if err != nil {
// Let client hang, they are sending crap; any reply can be used to amplify.
return
}
req := new(Msg)
req.setHdr(dh)
switch srv.MsgAcceptFunc(dh) {
case MsgAccept:
case MsgIgnore:
return
case MsgReject:
req.SetRcodeFormatError(req)
// Are we allowed to delete any OPT records here?
req.Ns, req.Answer, req.Extra = nil, nil, nil
w.WriteMsg(req)
srv.disposeBuffer(w)
return
}
if err := req.unpack(dh, w.msg, off); err != nil {
req.SetRcodeFormatError(req)
req.Ns, req.Answer, req.Extra = nil, nil, nil
w.WriteMsg(req)
srv.disposeBuffer(w)
return
} }
w.tsigStatus = nil w.tsigStatus = nil
if w.tsigSecret != nil { if w.tsigSecret != nil {
if t := req.IsTsig(); t != nil { if t := req.IsTsig(); t != nil {
secret := t.Hdr.Name if secret, ok := w.tsigSecret[t.Hdr.Name]; ok {
if _, ok := w.tsigSecret[secret]; !ok { w.tsigStatus = TsigVerify(w.msg, secret, "", false)
w.tsigStatus = ErrKeyAlg } else {
w.tsigStatus = ErrSecret
} }
w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
w.tsigTimersOnly = false w.tsigTimersOnly = false
w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
} }
} }
h.ServeDNS(w, req) // Writes back to the client
Exit: srv.disposeBuffer(w)
if w.tcp == nil {
return handler := srv.Handler
} if handler == nil {
// TODO(miek): make this number configurable? handler = DefaultServeMux
if q > maxTCPQueries { // close socket after this many queries
w.Close()
return
} }
if w.hijacked { handler.ServeDNS(w, req) // Writes back to the client
return // client calls Close()
}
if u != nil { // UDP, "close" and return
w.Close()
return
}
idleTimeout := tcpIdleTimeout
if srv.IdleTimeout != nil {
idleTimeout = srv.IdleTimeout()
}
m, err = reader.ReadTCP(w.tcp, idleTimeout)
if err == nil {
q++
goto Redo
}
w.Close()
return
} }
func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
// If we race with ShutdownContext, the read deadline may
// have been set in the distant past to unblock the read
// below. We must not override it, otherwise we may block
// ShutdownContext.
srv.lock.RLock()
if srv.started {
conn.SetReadDeadline(time.Now().Add(timeout)) conn.SetReadDeadline(time.Now().Add(timeout))
}
srv.lock.RUnlock()
l := make([]byte, 2) l := make([]byte, 2)
n, err := conn.Read(l) n, err := conn.Read(l)
if err != nil || n != 2 { if err != nil || n != 2 {
@ -632,10 +733,17 @@ func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
} }
func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
srv.lock.RLock()
if srv.started {
// See the comment in readTCP above.
conn.SetReadDeadline(time.Now().Add(timeout)) conn.SetReadDeadline(time.Now().Add(timeout))
m := make([]byte, srv.UDPSize) }
srv.lock.RUnlock()
m := srv.udpPool.Get().([]byte)
n, s, err := ReadFromSessionUDP(conn, m) n, s, err := ReadFromSessionUDP(conn, m)
if err != nil { if err != nil {
srv.udpPool.Put(m)
return nil, nil, err return nil, nil, err
} }
m = m[:n] m = m[:n]
@ -644,6 +752,10 @@ func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *S
// WriteMsg implements the ResponseWriter.WriteMsg method. // WriteMsg implements the ResponseWriter.WriteMsg method.
func (w *response) WriteMsg(m *Msg) (err error) { func (w *response) WriteMsg(m *Msg) (err error) {
if w.closed {
return &Error{err: "WriteMsg called after Close"}
}
var data []byte var data []byte
if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check) if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
if t := m.IsTsig(); t != nil { if t := m.IsTsig(); t != nil {
@ -665,6 +777,10 @@ func (w *response) WriteMsg(m *Msg) (err error) {
// Write implements the ResponseWriter.Write method. // Write implements the ResponseWriter.Write method.
func (w *response) Write(m []byte) (int, error) { func (w *response) Write(m []byte) (int, error) {
if w.closed {
return 0, &Error{err: "Write called after Close"}
}
switch { switch {
case w.udp != nil: case w.udp != nil:
n, err := WriteToSessionUDP(w.udp, m, w.udpSession) n, err := WriteToSessionUDP(w.udp, m, w.udpSession)
@ -683,20 +799,34 @@ func (w *response) Write(m []byte) (int, error) {
n, err := io.Copy(w.tcp, bytes.NewReader(m)) n, err := io.Copy(w.tcp, bytes.NewReader(m))
return int(n), err return int(n), err
default:
panic("dns: internal error: udp and tcp both nil")
} }
panic("not reached")
} }
// LocalAddr implements the ResponseWriter.LocalAddr method. // LocalAddr implements the ResponseWriter.LocalAddr method.
func (w *response) LocalAddr() net.Addr { func (w *response) LocalAddr() net.Addr {
if w.tcp != nil { switch {
return w.tcp.LocalAddr() case w.udp != nil:
}
return w.udp.LocalAddr() return w.udp.LocalAddr()
case w.tcp != nil:
return w.tcp.LocalAddr()
default:
panic("dns: internal error: udp and tcp both nil")
}
} }
// RemoteAddr implements the ResponseWriter.RemoteAddr method. // RemoteAddr implements the ResponseWriter.RemoteAddr method.
func (w *response) RemoteAddr() net.Addr { return w.remoteAddr } func (w *response) RemoteAddr() net.Addr {
switch {
case w.udpSession != nil:
return w.udpSession.RemoteAddr()
case w.tcp != nil:
return w.tcp.RemoteAddr()
default:
panic("dns: internal error: udpSession and tcp both nil")
}
}
// TsigStatus implements the ResponseWriter.TsigStatus method. // TsigStatus implements the ResponseWriter.TsigStatus method.
func (w *response) TsigStatus() error { return w.tsigStatus } func (w *response) TsigStatus() error { return w.tsigStatus }
@ -709,11 +839,30 @@ func (w *response) Hijack() { w.hijacked = true }
// Close implements the ResponseWriter.Close method // Close implements the ResponseWriter.Close method
func (w *response) Close() error { func (w *response) Close() error {
if w.closed {
return &Error{err: "connection already closed"}
}
w.closed = true
switch {
case w.udp != nil:
// Can't close the udp conn, as that is actually the listener. // Can't close the udp conn, as that is actually the listener.
if w.tcp != nil { return nil
e := w.tcp.Close() case w.tcp != nil:
w.tcp = nil return w.tcp.Close()
return e default:
panic("dns: internal error: udp and tcp both nil")
}
}
// ConnectionState() implements the ConnectionStater.ConnectionState() interface.
func (w *response) ConnectionState() *tls.ConnectionState {
type tlsConnectionStater interface {
ConnectionState() tls.ConnectionState
}
if v, ok := w.tcp.(tlsConnectionStater); ok {
t := v.ConnectionState()
return &t
} }
return nil return nil
} }

View file

@ -29,7 +29,7 @@ func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) {
rr.TypeCovered = 0 rr.TypeCovered = 0
rr.Labels = 0 rr.Labels = 0
buf := make([]byte, m.Len()+rr.len()) buf := make([]byte, m.Len()+Len(rr))
mbuf, err := m.PackBuffer(buf) mbuf, err := m.PackBuffer(buf)
if err != nil { if err != nil {
return nil, err return nil, err
@ -127,8 +127,7 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error {
if offset+1 >= buflen { if offset+1 >= buflen {
continue continue
} }
var rdlen uint16 rdlen := binary.BigEndian.Uint16(buf[offset:])
rdlen = binary.BigEndian.Uint16(buf[offset:])
offset += 2 offset += 2
offset += int(rdlen) offset += int(rdlen)
} }
@ -168,7 +167,7 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error {
} }
// If key has come from the DNS name compression might // If key has come from the DNS name compression might
// have mangled the case of the name // have mangled the case of the name
if strings.ToLower(signername) != strings.ToLower(k.Header().Name) { if !strings.EqualFold(signername, k.Header().Name) {
return &Error{err: "signer name doesn't match key name"} return &Error{err: "signer name doesn't match key name"}
} }
sigend := offset sigend := offset

View file

@ -133,7 +133,7 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
t.Algorithm = rr.Algorithm t.Algorithm = rr.Algorithm
t.OrigId = m.Id t.OrigId = m.Id
tbuf := make([]byte, t.len()) tbuf := make([]byte, Len(t))
if off, err := PackRR(t, tbuf, 0, nil, false); err == nil { if off, err := PackRR(t, tbuf, 0, nil, false); err == nil {
tbuf = tbuf[:off] // reset to actual size used tbuf = tbuf[:off] // reset to actual size used
} else { } else {

215
vendor/github.com/miekg/dns/types.go generated vendored
View file

@ -218,8 +218,10 @@ type Question struct {
Qclass uint16 Qclass uint16
} }
func (q *Question) len() int { func (q *Question) len(off int, compression map[string]struct{}) int {
return len(q.Name) + 1 + 2 + 2 l := domainNameLen(q.Name, off, compression, true)
l += 2 + 2
return l
} }
func (q *Question) String() (s string) { func (q *Question) String() (s string) {
@ -330,7 +332,7 @@ func (rr *MX) String() string {
type AFSDB struct { type AFSDB struct {
Hdr RR_Header Hdr RR_Header
Subtype uint16 Subtype uint16
Hostname string `dns:"cdomain-name"` Hostname string `dns:"domain-name"`
} }
func (rr *AFSDB) String() string { func (rr *AFSDB) String() string {
@ -351,7 +353,7 @@ func (rr *X25) String() string {
type RT struct { type RT struct {
Hdr RR_Header Hdr RR_Header
Preference uint16 Preference uint16
Host string `dns:"cdomain-name"` Host string `dns:"domain-name"` // RFC 3597 prohibits compressing records not defined in RFC 1035.
} }
func (rr *RT) String() string { func (rr *RT) String() string {
@ -419,128 +421,154 @@ type TXT struct {
func (rr *TXT) String() string { return rr.Hdr.String() + sprintTxt(rr.Txt) } func (rr *TXT) String() string { return rr.Hdr.String() + sprintTxt(rr.Txt) }
func sprintName(s string) string { func sprintName(s string) string {
src := []byte(s) var dst strings.Builder
dst := make([]byte, 0, len(src)) dst.Grow(len(s))
for i := 0; i < len(src); { for i := 0; i < len(s); {
if i+1 < len(src) && src[i] == '\\' && src[i+1] == '.' { if i+1 < len(s) && s[i] == '\\' && s[i+1] == '.' {
dst = append(dst, src[i:i+2]...) dst.WriteString(s[i : i+2])
i += 2 i += 2
} else { continue
b, n := nextByte(src, i) }
if n == 0 {
b, n := nextByte(s, i)
switch {
case n == 0:
i++ // dangling back slash i++ // dangling back slash
} else if b == '.' { case b == '.':
dst = append(dst, b) dst.WriteByte('.')
} else { default:
dst = appendDomainNameByte(dst, b) writeDomainNameByte(&dst, b)
} }
i += n i += n
} }
} return dst.String()
return string(dst)
} }
func sprintTxtOctet(s string) string { func sprintTxtOctet(s string) string {
src := []byte(s) var dst strings.Builder
dst := make([]byte, 0, len(src)) dst.Grow(2 + len(s))
dst = append(dst, '"') dst.WriteByte('"')
for i := 0; i < len(src); { for i := 0; i < len(s); {
if i+1 < len(src) && src[i] == '\\' && src[i+1] == '.' { if i+1 < len(s) && s[i] == '\\' && s[i+1] == '.' {
dst = append(dst, src[i:i+2]...) dst.WriteString(s[i : i+2])
i += 2 i += 2
} else { continue
b, n := nextByte(src, i)
if n == 0 {
i++ // dangling back slash
} else if b == '.' {
dst = append(dst, b)
} else {
if b < ' ' || b > '~' {
dst = appendByte(dst, b)
} else {
dst = append(dst, b)
} }
b, n := nextByte(s, i)
switch {
case n == 0:
i++ // dangling back slash
case b == '.':
dst.WriteByte('.')
case b < ' ' || b > '~':
dst.WriteString(escapeByte(b))
default:
dst.WriteByte(b)
} }
i += n i += n
} }
} dst.WriteByte('"')
dst = append(dst, '"') return dst.String()
return string(dst)
} }
func sprintTxt(txt []string) string { func sprintTxt(txt []string) string {
var out []byte var out strings.Builder
for i, s := range txt { for i, s := range txt {
out.Grow(3 + len(s))
if i > 0 { if i > 0 {
out = append(out, ` "`...) out.WriteString(` "`)
} else { } else {
out = append(out, '"') out.WriteByte('"')
} }
bs := []byte(s) for j := 0; j < len(s); {
for j := 0; j < len(bs); { b, n := nextByte(s, j)
b, n := nextByte(bs, j)
if n == 0 { if n == 0 {
break break
} }
out = appendTXTStringByte(out, b) writeTXTStringByte(&out, b)
j += n j += n
} }
out = append(out, '"') out.WriteByte('"')
} }
return string(out) return out.String()
} }
func appendDomainNameByte(s []byte, b byte) []byte { func writeDomainNameByte(s *strings.Builder, b byte) {
switch b { switch b {
case '.', ' ', '\'', '@', ';', '(', ')': // additional chars to escape case '.', ' ', '\'', '@', ';', '(', ')': // additional chars to escape
return append(s, '\\', b) s.WriteByte('\\')
s.WriteByte(b)
default:
writeTXTStringByte(s, b)
} }
return appendTXTStringByte(s, b)
} }
func appendTXTStringByte(s []byte, b byte) []byte { func writeTXTStringByte(s *strings.Builder, b byte) {
switch b { switch {
case '"', '\\': case b == '"' || b == '\\':
return append(s, '\\', b) s.WriteByte('\\')
s.WriteByte(b)
case b < ' ' || b > '~':
s.WriteString(escapeByte(b))
default:
s.WriteByte(b)
} }
if b < ' ' || b > '~' {
return appendByte(s, b)
}
return append(s, b)
} }
func appendByte(s []byte, b byte) []byte { const (
var buf [3]byte escapedByteSmall = "" +
bufs := strconv.AppendInt(buf[:0], int64(b), 10) `\000\001\002\003\004\005\006\007\008\009` +
s = append(s, '\\') `\010\011\012\013\014\015\016\017\018\019` +
for i := 0; i < 3-len(bufs); i++ { `\020\021\022\023\024\025\026\027\028\029` +
s = append(s, '0') `\030\031`
} escapedByteLarge = `\127\128\129` +
for _, r := range bufs { `\130\131\132\133\134\135\136\137\138\139` +
s = append(s, r) `\140\141\142\143\144\145\146\147\148\149` +
} `\150\151\152\153\154\155\156\157\158\159` +
return s `\160\161\162\163\164\165\166\167\168\169` +
`\170\171\172\173\174\175\176\177\178\179` +
`\180\181\182\183\184\185\186\187\188\189` +
`\190\191\192\193\194\195\196\197\198\199` +
`\200\201\202\203\204\205\206\207\208\209` +
`\210\211\212\213\214\215\216\217\218\219` +
`\220\221\222\223\224\225\226\227\228\229` +
`\230\231\232\233\234\235\236\237\238\239` +
`\240\241\242\243\244\245\246\247\248\249` +
`\250\251\252\253\254\255`
)
// escapeByte returns the \DDD escaping of b which must
// satisfy b < ' ' || b > '~'.
func escapeByte(b byte) string {
if b < ' ' {
return escapedByteSmall[b*4 : b*4+4]
} }
func nextByte(b []byte, offset int) (byte, int) { b -= '~' + 1
if offset >= len(b) { // The cast here is needed as b*4 may overflow byte.
return escapedByteLarge[int(b)*4 : int(b)*4+4]
}
func nextByte(s string, offset int) (byte, int) {
if offset >= len(s) {
return 0, 0 return 0, 0
} }
if b[offset] != '\\' { if s[offset] != '\\' {
// not an escape sequence // not an escape sequence
return b[offset], 1 return s[offset], 1
} }
switch len(b) - offset { switch len(s) - offset {
case 1: // dangling escape case 1: // dangling escape
return 0, 0 return 0, 0
case 2, 3: // too short to be \ddd case 2, 3: // too short to be \ddd
default: // maybe \ddd default: // maybe \ddd
if isDigit(b[offset+1]) && isDigit(b[offset+2]) && isDigit(b[offset+3]) { if isDigit(s[offset+1]) && isDigit(s[offset+2]) && isDigit(s[offset+3]) {
return dddToByte(b[offset+1:]), 4 return dddStringToByte(s[offset+1:]), 4
} }
} }
// not \ddd, just an RFC 1035 "quoted" character // not \ddd, just an RFC 1035 "quoted" character
return b[offset+1], 2 return s[offset+1], 2
} }
// SPF RR. See RFC 4408, Section 3.1.1. // SPF RR. See RFC 4408, Section 3.1.1.
@ -728,7 +756,7 @@ func (rr *LOC) String() string {
lat = lat % LOC_DEGREES lat = lat % LOC_DEGREES
m := lat / LOC_HOURS m := lat / LOC_HOURS
lat = lat % LOC_HOURS lat = lat % LOC_HOURS
s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, (float64(lat) / 1000), ns) s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, float64(lat)/1000, ns)
lon := rr.Longitude lon := rr.Longitude
ew := "E" ew := "E"
@ -742,7 +770,7 @@ func (rr *LOC) String() string {
lon = lon % LOC_DEGREES lon = lon % LOC_DEGREES
m = lon / LOC_HOURS m = lon / LOC_HOURS
lon = lon % LOC_HOURS lon = lon % LOC_HOURS
s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, (float64(lon) / 1000), ew) s += fmt.Sprintf("%02d %02d %0.3f %s ", h, m, float64(lon)/1000, ew)
var alt = float64(rr.Altitude) / 100 var alt = float64(rr.Altitude) / 100
alt -= LOC_ALTITUDEBASE alt -= LOC_ALTITUDEBASE
@ -752,9 +780,9 @@ func (rr *LOC) String() string {
s += fmt.Sprintf("%.0fm ", alt) s += fmt.Sprintf("%.0fm ", alt)
} }
s += cmToM((rr.Size&0xf0)>>4, rr.Size&0x0f) + "m " s += cmToM(rr.Size&0xf0>>4, rr.Size&0x0f) + "m "
s += cmToM((rr.HorizPre&0xf0)>>4, rr.HorizPre&0x0f) + "m " s += cmToM(rr.HorizPre&0xf0>>4, rr.HorizPre&0x0f) + "m "
s += cmToM((rr.VertPre&0xf0)>>4, rr.VertPre&0x0f) + "m" s += cmToM(rr.VertPre&0xf0>>4, rr.VertPre&0x0f) + "m"
return s return s
} }
@ -807,8 +835,9 @@ func (rr *NSEC) String() string {
return s return s
} }
func (rr *NSEC) len() int { func (rr *NSEC) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() + len(rr.NextDomain) + 1 l := rr.Hdr.len(off, compression)
l += domainNameLen(rr.NextDomain, off+l, compression, false)
lastwindow := uint32(2 ^ 32 + 1) lastwindow := uint32(2 ^ 32 + 1)
for _, t := range rr.TypeBitMap { for _, t := range rr.TypeBitMap {
window := t / 256 window := t / 256
@ -972,8 +1001,9 @@ func (rr *NSEC3) String() string {
return s return s
} }
func (rr *NSEC3) len() int { func (rr *NSEC3) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() + 6 + len(rr.Salt)/2 + 1 + len(rr.NextDomain) + 1 l := rr.Hdr.len(off, compression)
l += 6 + len(rr.Salt)/2 + 1 + len(rr.NextDomain) + 1
lastwindow := uint32(2 ^ 32 + 1) lastwindow := uint32(2 ^ 32 + 1)
for _, t := range rr.TypeBitMap { for _, t := range rr.TypeBitMap {
window := t / 256 window := t / 256
@ -1289,8 +1319,9 @@ func (rr *CSYNC) String() string {
return s return s
} }
func (rr *CSYNC) len() int { func (rr *CSYNC) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() + 4 + 2 l := rr.Hdr.len(off, compression)
l += 4 + 2
lastwindow := uint32(2 ^ 32 + 1) lastwindow := uint32(2 ^ 32 + 1)
for _, t := range rr.TypeBitMap { for _, t := range rr.TypeBitMap {
window := t / 256 window := t / 256
@ -1306,11 +1337,11 @@ func (rr *CSYNC) len() int {
// string representation used when printing the record. // string representation used when printing the record.
// It takes serial arithmetic (RFC 1982) into account. // It takes serial arithmetic (RFC 1982) into account.
func TimeToString(t uint32) string { func TimeToString(t uint32) string {
mod := ((int64(t) - time.Now().Unix()) / year68) - 1 mod := (int64(t)-time.Now().Unix())/year68 - 1
if mod < 0 { if mod < 0 {
mod = 0 mod = 0
} }
ti := time.Unix(int64(t)-(mod*year68), 0).UTC() ti := time.Unix(int64(t)-mod*year68, 0).UTC()
return ti.Format("20060102150405") return ti.Format("20060102150405")
} }
@ -1322,11 +1353,11 @@ func StringToTime(s string) (uint32, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
mod := (t.Unix() / year68) - 1 mod := t.Unix()/year68 - 1
if mod < 0 { if mod < 0 {
mod = 0 mod = 0
} }
return uint32(t.Unix() - (mod * year68)), nil return uint32(t.Unix() - mod*year68), nil
} }
// saltToString converts a NSECX salt to uppercase and returns "-" when it is empty. // saltToString converts a NSECX salt to uppercase and returns "-" when it is empty.

View file

@ -153,8 +153,8 @@ func main() {
if isEmbedded { if isEmbedded {
continue continue
} }
fmt.Fprintf(b, "func (rr *%s) len() int {\n", name) fmt.Fprintf(b, "func (rr *%s) len(off int, compression map[string]struct{}) int {\n", name)
fmt.Fprintf(b, "l := rr.Hdr.len()\n") fmt.Fprintf(b, "l := rr.Hdr.len(off, compression)\n")
for i := 1; i < st.NumFields(); i++ { for i := 1; i < st.NumFields(); i++ {
o := func(s string) { fmt.Fprintf(b, s, st.Field(i).Name()) } o := func(s string) { fmt.Fprintf(b, s, st.Field(i).Name()) }
@ -162,7 +162,11 @@ func main() {
switch st.Tag(i) { switch st.Tag(i) {
case `dns:"-"`: case `dns:"-"`:
// ignored // ignored
case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`: case `dns:"cdomain-name"`:
o("for _, x := range rr.%s { l += domainNameLen(x, off+l, compression, true) }\n")
case `dns:"domain-name"`:
o("for _, x := range rr.%s { l += domainNameLen(x, off+l, compression, false) }\n")
case `dns:"txt"`:
o("for _, x := range rr.%s { l += len(x) + 1 }\n") o("for _, x := range rr.%s { l += len(x) + 1 }\n")
default: default:
log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
@ -173,8 +177,10 @@ func main() {
switch { switch {
case st.Tag(i) == `dns:"-"`: case st.Tag(i) == `dns:"-"`:
// ignored // ignored
case st.Tag(i) == `dns:"cdomain-name"`, st.Tag(i) == `dns:"domain-name"`: case st.Tag(i) == `dns:"cdomain-name"`:
o("l += len(rr.%s) + 1\n") o("l += domainNameLen(rr.%s, off+l, compression, true)\n")
case st.Tag(i) == `dns:"domain-name"`:
o("l += domainNameLen(rr.%s, off+l, compression, false)\n")
case st.Tag(i) == `dns:"octet"`: case st.Tag(i) == `dns:"octet"`:
o("l += len(rr.%s)\n") o("l += len(rr.%s)\n")
case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): case strings.HasPrefix(st.Tag(i), `dns:"size-base64`):
@ -226,7 +232,7 @@ func main() {
continue continue
} }
fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name) fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name)
fields := []string{"*rr.Hdr.copyHeader()"} fields := []string{"rr.Hdr"}
for i := 1; i < st.NumFields(); i++ { for i := 1; i < st.NumFields(); i++ {
f := st.Field(i).Name() f := st.Field(i).Name()
if sl, ok := st.Field(i).Type().(*types.Slice); ok { if sl, ok := st.Field(i).Type().(*types.Slice); ok {

31
vendor/github.com/miekg/dns/udp.go generated vendored
View file

@ -9,6 +9,22 @@ import (
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
) )
// This is the required size of the OOB buffer to pass to ReadMsgUDP.
var udpOOBSize = func() int {
// We can't know whether we'll get an IPv4 control message or an
// IPv6 control message ahead of time. To get around this, we size
// the buffer equal to the largest of the two.
oob4 := ipv4.NewControlMessage(ipv4.FlagDst | ipv4.FlagInterface)
oob6 := ipv6.NewControlMessage(ipv6.FlagDst | ipv6.FlagInterface)
if len(oob4) > len(oob6) {
return len(oob4)
}
return len(oob6)
}()
// SessionUDP holds the remote address and the associated // SessionUDP holds the remote address and the associated
// out-of-band data. // out-of-band data.
type SessionUDP struct { type SessionUDP struct {
@ -22,7 +38,7 @@ func (s *SessionUDP) RemoteAddr() net.Addr { return s.raddr }
// ReadFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a // ReadFromSessionUDP acts just like net.UDPConn.ReadFrom(), but returns a session object instead of a
// net.UDPAddr. // net.UDPAddr.
func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) { func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) {
oob := make([]byte, 40) oob := make([]byte, udpOOBSize)
n, oobn, _, raddr, err := conn.ReadMsgUDP(b, oob) n, oobn, _, raddr, err := conn.ReadMsgUDP(b, oob)
if err != nil { if err != nil {
return n, nil, err return n, nil, err
@ -53,18 +69,15 @@ func parseDstFromOOB(oob []byte) net.IP {
// Start with IPv6 and then fallback to IPv4 // Start with IPv6 and then fallback to IPv4
// TODO(fastest963): Figure out a way to prefer one or the other. Looking at // TODO(fastest963): Figure out a way to prefer one or the other. Looking at
// the lvl of the header for a 0 or 41 isn't cross-platform. // the lvl of the header for a 0 or 41 isn't cross-platform.
var dst net.IP
cm6 := new(ipv6.ControlMessage) cm6 := new(ipv6.ControlMessage)
if cm6.Parse(oob) == nil { if cm6.Parse(oob) == nil && cm6.Dst != nil {
dst = cm6.Dst return cm6.Dst
} }
if dst == nil {
cm4 := new(ipv4.ControlMessage) cm4 := new(ipv4.ControlMessage)
if cm4.Parse(oob) == nil { if cm4.Parse(oob) == nil && cm4.Dst != nil {
dst = cm4.Dst return cm4.Dst
} }
} return nil
return dst
} }
// correctSource takes oob data and returns new oob data with the Src equal to the Dst // correctSource takes oob data and returns new oob data with the Src equal to the Dst

View file

@ -3,7 +3,7 @@ package dns
import "fmt" import "fmt"
// Version is current version of this library. // Version is current version of this library.
var Version = V{1, 0, 4} var Version = V{1, 1, 1}
// V holds the version of this library. // V holds the version of this library.
type V struct { type V struct {

View file

@ -1,118 +0,0 @@
// Code generated by "go run compress_generate.go"; DO NOT EDIT.
package dns
func compressionLenHelperType(c map[string]int, r RR) {
switch x := r.(type) {
case *AFSDB:
compressionLenHelper(c, x.Hostname)
case *CNAME:
compressionLenHelper(c, x.Target)
case *DNAME:
compressionLenHelper(c, x.Target)
case *HIP:
for i := range x.RendezvousServers {
compressionLenHelper(c, x.RendezvousServers[i])
}
case *KX:
compressionLenHelper(c, x.Exchanger)
case *LP:
compressionLenHelper(c, x.Fqdn)
case *MB:
compressionLenHelper(c, x.Mb)
case *MD:
compressionLenHelper(c, x.Md)
case *MF:
compressionLenHelper(c, x.Mf)
case *MG:
compressionLenHelper(c, x.Mg)
case *MINFO:
compressionLenHelper(c, x.Rmail)
compressionLenHelper(c, x.Email)
case *MR:
compressionLenHelper(c, x.Mr)
case *MX:
compressionLenHelper(c, x.Mx)
case *NAPTR:
compressionLenHelper(c, x.Replacement)
case *NS:
compressionLenHelper(c, x.Ns)
case *NSAPPTR:
compressionLenHelper(c, x.Ptr)
case *NSEC:
compressionLenHelper(c, x.NextDomain)
case *PTR:
compressionLenHelper(c, x.Ptr)
case *PX:
compressionLenHelper(c, x.Map822)
compressionLenHelper(c, x.Mapx400)
case *RP:
compressionLenHelper(c, x.Mbox)
compressionLenHelper(c, x.Txt)
case *RRSIG:
compressionLenHelper(c, x.SignerName)
case *RT:
compressionLenHelper(c, x.Host)
case *SIG:
compressionLenHelper(c, x.SignerName)
case *SOA:
compressionLenHelper(c, x.Ns)
compressionLenHelper(c, x.Mbox)
case *SRV:
compressionLenHelper(c, x.Target)
case *TALINK:
compressionLenHelper(c, x.PreviousName)
compressionLenHelper(c, x.NextName)
case *TKEY:
compressionLenHelper(c, x.Algorithm)
case *TSIG:
compressionLenHelper(c, x.Algorithm)
}
}
func compressionLenSearchType(c map[string]int, r RR) (int, bool) {
switch x := r.(type) {
case *AFSDB:
k1, ok1 := compressionLenSearch(c, x.Hostname)
return k1, ok1
case *CNAME:
k1, ok1 := compressionLenSearch(c, x.Target)
return k1, ok1
case *MB:
k1, ok1 := compressionLenSearch(c, x.Mb)
return k1, ok1
case *MD:
k1, ok1 := compressionLenSearch(c, x.Md)
return k1, ok1
case *MF:
k1, ok1 := compressionLenSearch(c, x.Mf)
return k1, ok1
case *MG:
k1, ok1 := compressionLenSearch(c, x.Mg)
return k1, ok1
case *MINFO:
k1, ok1 := compressionLenSearch(c, x.Rmail)
k2, ok2 := compressionLenSearch(c, x.Email)
return k1 + k2, ok1 && ok2
case *MR:
k1, ok1 := compressionLenSearch(c, x.Mr)
return k1, ok1
case *MX:
k1, ok1 := compressionLenSearch(c, x.Mx)
return k1, ok1
case *NS:
k1, ok1 := compressionLenSearch(c, x.Ns)
return k1, ok1
case *PTR:
k1, ok1 := compressionLenSearch(c, x.Ptr)
return k1, ok1
case *RT:
k1, ok1 := compressionLenSearch(c, x.Host)
return k1, ok1
case *SOA:
k1, ok1 := compressionLenSearch(c, x.Ns)
k2, ok2 := compressionLenSearch(c, x.Mbox)
return k1 + k2, ok1 && ok2
}
return 0, false
}

943
vendor/github.com/miekg/dns/zduplicate.go generated vendored Normal file
View file

@ -0,0 +1,943 @@
// Code generated by "go run duplicate_generate.go"; DO NOT EDIT.
package dns
// isDuplicateRdata calls the rdata specific functions
func isDuplicateRdata(r1, r2 RR) bool {
switch r1.Header().Rrtype {
case TypeA:
return isDuplicateA(r1.(*A), r2.(*A))
case TypeAAAA:
return isDuplicateAAAA(r1.(*AAAA), r2.(*AAAA))
case TypeAFSDB:
return isDuplicateAFSDB(r1.(*AFSDB), r2.(*AFSDB))
case TypeAVC:
return isDuplicateAVC(r1.(*AVC), r2.(*AVC))
case TypeCAA:
return isDuplicateCAA(r1.(*CAA), r2.(*CAA))
case TypeCERT:
return isDuplicateCERT(r1.(*CERT), r2.(*CERT))
case TypeCNAME:
return isDuplicateCNAME(r1.(*CNAME), r2.(*CNAME))
case TypeCSYNC:
return isDuplicateCSYNC(r1.(*CSYNC), r2.(*CSYNC))
case TypeDHCID:
return isDuplicateDHCID(r1.(*DHCID), r2.(*DHCID))
case TypeDNAME:
return isDuplicateDNAME(r1.(*DNAME), r2.(*DNAME))
case TypeDNSKEY:
return isDuplicateDNSKEY(r1.(*DNSKEY), r2.(*DNSKEY))
case TypeDS:
return isDuplicateDS(r1.(*DS), r2.(*DS))
case TypeEID:
return isDuplicateEID(r1.(*EID), r2.(*EID))
case TypeEUI48:
return isDuplicateEUI48(r1.(*EUI48), r2.(*EUI48))
case TypeEUI64:
return isDuplicateEUI64(r1.(*EUI64), r2.(*EUI64))
case TypeGID:
return isDuplicateGID(r1.(*GID), r2.(*GID))
case TypeGPOS:
return isDuplicateGPOS(r1.(*GPOS), r2.(*GPOS))
case TypeHINFO:
return isDuplicateHINFO(r1.(*HINFO), r2.(*HINFO))
case TypeHIP:
return isDuplicateHIP(r1.(*HIP), r2.(*HIP))
case TypeKX:
return isDuplicateKX(r1.(*KX), r2.(*KX))
case TypeL32:
return isDuplicateL32(r1.(*L32), r2.(*L32))
case TypeL64:
return isDuplicateL64(r1.(*L64), r2.(*L64))
case TypeLOC:
return isDuplicateLOC(r1.(*LOC), r2.(*LOC))
case TypeLP:
return isDuplicateLP(r1.(*LP), r2.(*LP))
case TypeMB:
return isDuplicateMB(r1.(*MB), r2.(*MB))
case TypeMD:
return isDuplicateMD(r1.(*MD), r2.(*MD))
case TypeMF:
return isDuplicateMF(r1.(*MF), r2.(*MF))
case TypeMG:
return isDuplicateMG(r1.(*MG), r2.(*MG))
case TypeMINFO:
return isDuplicateMINFO(r1.(*MINFO), r2.(*MINFO))
case TypeMR:
return isDuplicateMR(r1.(*MR), r2.(*MR))
case TypeMX:
return isDuplicateMX(r1.(*MX), r2.(*MX))
case TypeNAPTR:
return isDuplicateNAPTR(r1.(*NAPTR), r2.(*NAPTR))
case TypeNID:
return isDuplicateNID(r1.(*NID), r2.(*NID))
case TypeNIMLOC:
return isDuplicateNIMLOC(r1.(*NIMLOC), r2.(*NIMLOC))
case TypeNINFO:
return isDuplicateNINFO(r1.(*NINFO), r2.(*NINFO))
case TypeNS:
return isDuplicateNS(r1.(*NS), r2.(*NS))
case TypeNSAPPTR:
return isDuplicateNSAPPTR(r1.(*NSAPPTR), r2.(*NSAPPTR))
case TypeNSEC:
return isDuplicateNSEC(r1.(*NSEC), r2.(*NSEC))
case TypeNSEC3:
return isDuplicateNSEC3(r1.(*NSEC3), r2.(*NSEC3))
case TypeNSEC3PARAM:
return isDuplicateNSEC3PARAM(r1.(*NSEC3PARAM), r2.(*NSEC3PARAM))
case TypeOPENPGPKEY:
return isDuplicateOPENPGPKEY(r1.(*OPENPGPKEY), r2.(*OPENPGPKEY))
case TypePTR:
return isDuplicatePTR(r1.(*PTR), r2.(*PTR))
case TypePX:
return isDuplicatePX(r1.(*PX), r2.(*PX))
case TypeRKEY:
return isDuplicateRKEY(r1.(*RKEY), r2.(*RKEY))
case TypeRP:
return isDuplicateRP(r1.(*RP), r2.(*RP))
case TypeRRSIG:
return isDuplicateRRSIG(r1.(*RRSIG), r2.(*RRSIG))
case TypeRT:
return isDuplicateRT(r1.(*RT), r2.(*RT))
case TypeSMIMEA:
return isDuplicateSMIMEA(r1.(*SMIMEA), r2.(*SMIMEA))
case TypeSOA:
return isDuplicateSOA(r1.(*SOA), r2.(*SOA))
case TypeSPF:
return isDuplicateSPF(r1.(*SPF), r2.(*SPF))
case TypeSRV:
return isDuplicateSRV(r1.(*SRV), r2.(*SRV))
case TypeSSHFP:
return isDuplicateSSHFP(r1.(*SSHFP), r2.(*SSHFP))
case TypeTA:
return isDuplicateTA(r1.(*TA), r2.(*TA))
case TypeTALINK:
return isDuplicateTALINK(r1.(*TALINK), r2.(*TALINK))
case TypeTKEY:
return isDuplicateTKEY(r1.(*TKEY), r2.(*TKEY))
case TypeTLSA:
return isDuplicateTLSA(r1.(*TLSA), r2.(*TLSA))
case TypeTSIG:
return isDuplicateTSIG(r1.(*TSIG), r2.(*TSIG))
case TypeTXT:
return isDuplicateTXT(r1.(*TXT), r2.(*TXT))
case TypeUID:
return isDuplicateUID(r1.(*UID), r2.(*UID))
case TypeUINFO:
return isDuplicateUINFO(r1.(*UINFO), r2.(*UINFO))
case TypeURI:
return isDuplicateURI(r1.(*URI), r2.(*URI))
case TypeX25:
return isDuplicateX25(r1.(*X25), r2.(*X25))
}
return false
}
// isDuplicate() functions
func isDuplicateA(r1, r2 *A) bool {
if len(r1.A) != len(r2.A) {
return false
}
for i := 0; i < len(r1.A); i++ {
if r1.A[i] != r2.A[i] {
return false
}
}
return true
}
func isDuplicateAAAA(r1, r2 *AAAA) bool {
if len(r1.AAAA) != len(r2.AAAA) {
return false
}
for i := 0; i < len(r1.AAAA); i++ {
if r1.AAAA[i] != r2.AAAA[i] {
return false
}
}
return true
}
func isDuplicateAFSDB(r1, r2 *AFSDB) bool {
if r1.Subtype != r2.Subtype {
return false
}
if !isDulicateName(r1.Hostname, r2.Hostname) {
return false
}
return true
}
func isDuplicateAVC(r1, r2 *AVC) bool {
if len(r1.Txt) != len(r2.Txt) {
return false
}
for i := 0; i < len(r1.Txt); i++ {
if r1.Txt[i] != r2.Txt[i] {
return false
}
}
return true
}
func isDuplicateCAA(r1, r2 *CAA) bool {
if r1.Flag != r2.Flag {
return false
}
if r1.Tag != r2.Tag {
return false
}
if r1.Value != r2.Value {
return false
}
return true
}
func isDuplicateCERT(r1, r2 *CERT) bool {
if r1.Type != r2.Type {
return false
}
if r1.KeyTag != r2.KeyTag {
return false
}
if r1.Algorithm != r2.Algorithm {
return false
}
if r1.Certificate != r2.Certificate {
return false
}
return true
}
func isDuplicateCNAME(r1, r2 *CNAME) bool {
if !isDulicateName(r1.Target, r2.Target) {
return false
}
return true
}
func isDuplicateCSYNC(r1, r2 *CSYNC) bool {
if r1.Serial != r2.Serial {
return false
}
if r1.Flags != r2.Flags {
return false
}
if len(r1.TypeBitMap) != len(r2.TypeBitMap) {
return false
}
for i := 0; i < len(r1.TypeBitMap); i++ {
if r1.TypeBitMap[i] != r2.TypeBitMap[i] {
return false
}
}
return true
}
func isDuplicateDHCID(r1, r2 *DHCID) bool {
if r1.Digest != r2.Digest {
return false
}
return true
}
func isDuplicateDNAME(r1, r2 *DNAME) bool {
if !isDulicateName(r1.Target, r2.Target) {
return false
}
return true
}
func isDuplicateDNSKEY(r1, r2 *DNSKEY) bool {
if r1.Flags != r2.Flags {
return false
}
if r1.Protocol != r2.Protocol {
return false
}
if r1.Algorithm != r2.Algorithm {
return false
}
if r1.PublicKey != r2.PublicKey {
return false
}
return true
}
func isDuplicateDS(r1, r2 *DS) bool {
if r1.KeyTag != r2.KeyTag {
return false
}
if r1.Algorithm != r2.Algorithm {
return false
}
if r1.DigestType != r2.DigestType {
return false
}
if r1.Digest != r2.Digest {
return false
}
return true
}
func isDuplicateEID(r1, r2 *EID) bool {
if r1.Endpoint != r2.Endpoint {
return false
}
return true
}
func isDuplicateEUI48(r1, r2 *EUI48) bool {
if r1.Address != r2.Address {
return false
}
return true
}
func isDuplicateEUI64(r1, r2 *EUI64) bool {
if r1.Address != r2.Address {
return false
}
return true
}
func isDuplicateGID(r1, r2 *GID) bool {
if r1.Gid != r2.Gid {
return false
}
return true
}
func isDuplicateGPOS(r1, r2 *GPOS) bool {
if r1.Longitude != r2.Longitude {
return false
}
if r1.Latitude != r2.Latitude {
return false
}
if r1.Altitude != r2.Altitude {
return false
}
return true
}
func isDuplicateHINFO(r1, r2 *HINFO) bool {
if r1.Cpu != r2.Cpu {
return false
}
if r1.Os != r2.Os {
return false
}
return true
}
func isDuplicateHIP(r1, r2 *HIP) bool {
if r1.HitLength != r2.HitLength {
return false
}
if r1.PublicKeyAlgorithm != r2.PublicKeyAlgorithm {
return false
}
if r1.PublicKeyLength != r2.PublicKeyLength {
return false
}
if r1.Hit != r2.Hit {
return false
}
if r1.PublicKey != r2.PublicKey {
return false
}
if len(r1.RendezvousServers) != len(r2.RendezvousServers) {
return false
}
for i := 0; i < len(r1.RendezvousServers); i++ {
if !isDulicateName(r1.RendezvousServers[i], r2.RendezvousServers[i]) {
return false
}
}
return true
}
func isDuplicateKX(r1, r2 *KX) bool {
if r1.Preference != r2.Preference {
return false
}
if !isDulicateName(r1.Exchanger, r2.Exchanger) {
return false
}
return true
}
func isDuplicateL32(r1, r2 *L32) bool {
if r1.Preference != r2.Preference {
return false
}
if len(r1.Locator32) != len(r2.Locator32) {
return false
}
for i := 0; i < len(r1.Locator32); i++ {
if r1.Locator32[i] != r2.Locator32[i] {
return false
}
}
return true
}
func isDuplicateL64(r1, r2 *L64) bool {
if r1.Preference != r2.Preference {
return false
}
if r1.Locator64 != r2.Locator64 {
return false
}
return true
}
func isDuplicateLOC(r1, r2 *LOC) bool {
if r1.Version != r2.Version {
return false
}
if r1.Size != r2.Size {
return false
}
if r1.HorizPre != r2.HorizPre {
return false
}
if r1.VertPre != r2.VertPre {
return false
}
if r1.Latitude != r2.Latitude {
return false
}
if r1.Longitude != r2.Longitude {
return false
}
if r1.Altitude != r2.Altitude {
return false
}
return true
}
func isDuplicateLP(r1, r2 *LP) bool {
if r1.Preference != r2.Preference {
return false
}
if !isDulicateName(r1.Fqdn, r2.Fqdn) {
return false
}
return true
}
func isDuplicateMB(r1, r2 *MB) bool {
if !isDulicateName(r1.Mb, r2.Mb) {
return false
}
return true
}
func isDuplicateMD(r1, r2 *MD) bool {
if !isDulicateName(r1.Md, r2.Md) {
return false
}
return true
}
func isDuplicateMF(r1, r2 *MF) bool {
if !isDulicateName(r1.Mf, r2.Mf) {
return false
}
return true
}
func isDuplicateMG(r1, r2 *MG) bool {
if !isDulicateName(r1.Mg, r2.Mg) {
return false
}
return true
}
func isDuplicateMINFO(r1, r2 *MINFO) bool {
if !isDulicateName(r1.Rmail, r2.Rmail) {
return false
}
if !isDulicateName(r1.Email, r2.Email) {
return false
}
return true
}
func isDuplicateMR(r1, r2 *MR) bool {
if !isDulicateName(r1.Mr, r2.Mr) {
return false
}
return true
}
func isDuplicateMX(r1, r2 *MX) bool {
if r1.Preference != r2.Preference {
return false
}
if !isDulicateName(r1.Mx, r2.Mx) {
return false
}
return true
}
func isDuplicateNAPTR(r1, r2 *NAPTR) bool {
if r1.Order != r2.Order {
return false
}
if r1.Preference != r2.Preference {
return false
}
if r1.Flags != r2.Flags {
return false
}
if r1.Service != r2.Service {
return false
}
if r1.Regexp != r2.Regexp {
return false
}
if !isDulicateName(r1.Replacement, r2.Replacement) {
return false
}
return true
}
func isDuplicateNID(r1, r2 *NID) bool {
if r1.Preference != r2.Preference {
return false
}
if r1.NodeID != r2.NodeID {
return false
}
return true
}
func isDuplicateNIMLOC(r1, r2 *NIMLOC) bool {
if r1.Locator != r2.Locator {
return false
}
return true
}
func isDuplicateNINFO(r1, r2 *NINFO) bool {
if len(r1.ZSData) != len(r2.ZSData) {
return false
}
for i := 0; i < len(r1.ZSData); i++ {
if r1.ZSData[i] != r2.ZSData[i] {
return false
}
}
return true
}
func isDuplicateNS(r1, r2 *NS) bool {
if !isDulicateName(r1.Ns, r2.Ns) {
return false
}
return true
}
func isDuplicateNSAPPTR(r1, r2 *NSAPPTR) bool {
if !isDulicateName(r1.Ptr, r2.Ptr) {
return false
}
return true
}
func isDuplicateNSEC(r1, r2 *NSEC) bool {
if !isDulicateName(r1.NextDomain, r2.NextDomain) {
return false
}
if len(r1.TypeBitMap) != len(r2.TypeBitMap) {
return false
}
for i := 0; i < len(r1.TypeBitMap); i++ {
if r1.TypeBitMap[i] != r2.TypeBitMap[i] {
return false
}
}
return true
}
func isDuplicateNSEC3(r1, r2 *NSEC3) bool {
if r1.Hash != r2.Hash {
return false
}
if r1.Flags != r2.Flags {
return false
}
if r1.Iterations != r2.Iterations {
return false
}
if r1.SaltLength != r2.SaltLength {
return false
}
if r1.Salt != r2.Salt {
return false
}
if r1.HashLength != r2.HashLength {
return false
}
if r1.NextDomain != r2.NextDomain {
return false
}
if len(r1.TypeBitMap) != len(r2.TypeBitMap) {
return false
}
for i := 0; i < len(r1.TypeBitMap); i++ {
if r1.TypeBitMap[i] != r2.TypeBitMap[i] {
return false
}
}
return true
}
func isDuplicateNSEC3PARAM(r1, r2 *NSEC3PARAM) bool {
if r1.Hash != r2.Hash {
return false
}
if r1.Flags != r2.Flags {
return false
}
if r1.Iterations != r2.Iterations {
return false
}
if r1.SaltLength != r2.SaltLength {
return false
}
if r1.Salt != r2.Salt {
return false
}
return true
}
func isDuplicateOPENPGPKEY(r1, r2 *OPENPGPKEY) bool {
if r1.PublicKey != r2.PublicKey {
return false
}
return true
}
func isDuplicatePTR(r1, r2 *PTR) bool {
if !isDulicateName(r1.Ptr, r2.Ptr) {
return false
}
return true
}
func isDuplicatePX(r1, r2 *PX) bool {
if r1.Preference != r2.Preference {
return false
}
if !isDulicateName(r1.Map822, r2.Map822) {
return false
}
if !isDulicateName(r1.Mapx400, r2.Mapx400) {
return false
}
return true
}
func isDuplicateRKEY(r1, r2 *RKEY) bool {
if r1.Flags != r2.Flags {
return false
}
if r1.Protocol != r2.Protocol {
return false
}
if r1.Algorithm != r2.Algorithm {
return false
}
if r1.PublicKey != r2.PublicKey {
return false
}
return true
}
func isDuplicateRP(r1, r2 *RP) bool {
if !isDulicateName(r1.Mbox, r2.Mbox) {
return false
}
if !isDulicateName(r1.Txt, r2.Txt) {
return false
}
return true
}
func isDuplicateRRSIG(r1, r2 *RRSIG) bool {
if r1.TypeCovered != r2.TypeCovered {
return false
}
if r1.Algorithm != r2.Algorithm {
return false
}
if r1.Labels != r2.Labels {
return false
}
if r1.OrigTtl != r2.OrigTtl {
return false
}
if r1.Expiration != r2.Expiration {
return false
}
if r1.Inception != r2.Inception {
return false
}
if r1.KeyTag != r2.KeyTag {
return false
}
if !isDulicateName(r1.SignerName, r2.SignerName) {
return false
}
if r1.Signature != r2.Signature {
return false
}
return true
}
func isDuplicateRT(r1, r2 *RT) bool {
if r1.Preference != r2.Preference {
return false
}
if !isDulicateName(r1.Host, r2.Host) {
return false
}
return true
}
func isDuplicateSMIMEA(r1, r2 *SMIMEA) bool {
if r1.Usage != r2.Usage {
return false
}
if r1.Selector != r2.Selector {
return false
}
if r1.MatchingType != r2.MatchingType {
return false
}
if r1.Certificate != r2.Certificate {
return false
}
return true
}
func isDuplicateSOA(r1, r2 *SOA) bool {
if !isDulicateName(r1.Ns, r2.Ns) {
return false
}
if !isDulicateName(r1.Mbox, r2.Mbox) {
return false
}
if r1.Serial != r2.Serial {
return false
}
if r1.Refresh != r2.Refresh {
return false
}
if r1.Retry != r2.Retry {
return false
}
if r1.Expire != r2.Expire {
return false
}
if r1.Minttl != r2.Minttl {
return false
}
return true
}
func isDuplicateSPF(r1, r2 *SPF) bool {
if len(r1.Txt) != len(r2.Txt) {
return false
}
for i := 0; i < len(r1.Txt); i++ {
if r1.Txt[i] != r2.Txt[i] {
return false
}
}
return true
}
func isDuplicateSRV(r1, r2 *SRV) bool {
if r1.Priority != r2.Priority {
return false
}
if r1.Weight != r2.Weight {
return false
}
if r1.Port != r2.Port {
return false
}
if !isDulicateName(r1.Target, r2.Target) {
return false
}
return true
}
func isDuplicateSSHFP(r1, r2 *SSHFP) bool {
if r1.Algorithm != r2.Algorithm {
return false
}
if r1.Type != r2.Type {
return false
}
if r1.FingerPrint != r2.FingerPrint {
return false
}
return true
}
func isDuplicateTA(r1, r2 *TA) bool {
if r1.KeyTag != r2.KeyTag {
return false
}
if r1.Algorithm != r2.Algorithm {
return false
}
if r1.DigestType != r2.DigestType {
return false
}
if r1.Digest != r2.Digest {
return false
}
return true
}
func isDuplicateTALINK(r1, r2 *TALINK) bool {
if !isDulicateName(r1.PreviousName, r2.PreviousName) {
return false
}
if !isDulicateName(r1.NextName, r2.NextName) {
return false
}
return true
}
func isDuplicateTKEY(r1, r2 *TKEY) bool {
if !isDulicateName(r1.Algorithm, r2.Algorithm) {
return false
}
if r1.Inception != r2.Inception {
return false
}
if r1.Expiration != r2.Expiration {
return false
}
if r1.Mode != r2.Mode {
return false
}
if r1.Error != r2.Error {
return false
}
if r1.KeySize != r2.KeySize {
return false
}
if r1.Key != r2.Key {
return false
}
if r1.OtherLen != r2.OtherLen {
return false
}
if r1.OtherData != r2.OtherData {
return false
}
return true
}
func isDuplicateTLSA(r1, r2 *TLSA) bool {
if r1.Usage != r2.Usage {
return false
}
if r1.Selector != r2.Selector {
return false
}
if r1.MatchingType != r2.MatchingType {
return false
}
if r1.Certificate != r2.Certificate {
return false
}
return true
}
func isDuplicateTSIG(r1, r2 *TSIG) bool {
if !isDulicateName(r1.Algorithm, r2.Algorithm) {
return false
}
if r1.TimeSigned != r2.TimeSigned {
return false
}
if r1.Fudge != r2.Fudge {
return false
}
if r1.MACSize != r2.MACSize {
return false
}
if r1.MAC != r2.MAC {
return false
}
if r1.OrigId != r2.OrigId {
return false
}
if r1.Error != r2.Error {
return false
}
if r1.OtherLen != r2.OtherLen {
return false
}
if r1.OtherData != r2.OtherData {
return false
}
return true
}
func isDuplicateTXT(r1, r2 *TXT) bool {
if len(r1.Txt) != len(r2.Txt) {
return false
}
for i := 0; i < len(r1.Txt); i++ {
if r1.Txt[i] != r2.Txt[i] {
return false
}
}
return true
}
func isDuplicateUID(r1, r2 *UID) bool {
if r1.Uid != r2.Uid {
return false
}
return true
}
func isDuplicateUINFO(r1, r2 *UINFO) bool {
if r1.Uinfo != r2.Uinfo {
return false
}
return true
}
func isDuplicateURI(r1, r2 *URI) bool {
if r1.Priority != r2.Priority {
return false
}
if r1.Weight != r2.Weight {
return false
}
if r1.Target != r2.Target {
return false
}
return true
}
func isDuplicateX25(r1, r2 *X25) bool {
if r1.PSDNAddress != r2.PSDNAddress {
return false
}
return true
}

1156
vendor/github.com/miekg/dns/zmsg.go generated vendored

File diff suppressed because it is too large Load diff

436
vendor/github.com/miekg/dns/ztypes.go generated vendored
View file

@ -236,144 +236,144 @@ func (rr *URI) Header() *RR_Header { return &rr.Hdr }
func (rr *X25) Header() *RR_Header { return &rr.Hdr } func (rr *X25) Header() *RR_Header { return &rr.Hdr }
// len() functions // len() functions
func (rr *A) len() int { func (rr *A) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += net.IPv4len // A l += net.IPv4len // A
return l return l
} }
func (rr *AAAA) len() int { func (rr *AAAA) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += net.IPv6len // AAAA l += net.IPv6len // AAAA
return l return l
} }
func (rr *AFSDB) len() int { func (rr *AFSDB) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Subtype l += 2 // Subtype
l += len(rr.Hostname) + 1 l += domainNameLen(rr.Hostname, off+l, compression, false)
return l return l
} }
func (rr *ANY) len() int { func (rr *ANY) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
return l return l
} }
func (rr *AVC) len() int { func (rr *AVC) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
for _, x := range rr.Txt { for _, x := range rr.Txt {
l += len(x) + 1 l += len(x) + 1
} }
return l return l
} }
func (rr *CAA) len() int { func (rr *CAA) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l++ // Flag l++ // Flag
l += len(rr.Tag) + 1 l += len(rr.Tag) + 1
l += len(rr.Value) l += len(rr.Value)
return l return l
} }
func (rr *CERT) len() int { func (rr *CERT) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Type l += 2 // Type
l += 2 // KeyTag l += 2 // KeyTag
l++ // Algorithm l++ // Algorithm
l += base64.StdEncoding.DecodedLen(len(rr.Certificate)) l += base64.StdEncoding.DecodedLen(len(rr.Certificate))
return l return l
} }
func (rr *CNAME) len() int { func (rr *CNAME) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Target) + 1 l += domainNameLen(rr.Target, off+l, compression, true)
return l return l
} }
func (rr *DHCID) len() int { func (rr *DHCID) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += base64.StdEncoding.DecodedLen(len(rr.Digest)) l += base64.StdEncoding.DecodedLen(len(rr.Digest))
return l return l
} }
func (rr *DNAME) len() int { func (rr *DNAME) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Target) + 1 l += domainNameLen(rr.Target, off+l, compression, false)
return l return l
} }
func (rr *DNSKEY) len() int { func (rr *DNSKEY) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Flags l += 2 // Flags
l++ // Protocol l++ // Protocol
l++ // Algorithm l++ // Algorithm
l += base64.StdEncoding.DecodedLen(len(rr.PublicKey)) l += base64.StdEncoding.DecodedLen(len(rr.PublicKey))
return l return l
} }
func (rr *DS) len() int { func (rr *DS) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // KeyTag l += 2 // KeyTag
l++ // Algorithm l++ // Algorithm
l++ // DigestType l++ // DigestType
l += len(rr.Digest)/2 + 1 l += len(rr.Digest)/2 + 1
return l return l
} }
func (rr *EID) len() int { func (rr *EID) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Endpoint)/2 + 1 l += len(rr.Endpoint)/2 + 1
return l return l
} }
func (rr *EUI48) len() int { func (rr *EUI48) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 6 // Address l += 6 // Address
return l return l
} }
func (rr *EUI64) len() int { func (rr *EUI64) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 8 // Address l += 8 // Address
return l return l
} }
func (rr *GID) len() int { func (rr *GID) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 4 // Gid l += 4 // Gid
return l return l
} }
func (rr *GPOS) len() int { func (rr *GPOS) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Longitude) + 1 l += len(rr.Longitude) + 1
l += len(rr.Latitude) + 1 l += len(rr.Latitude) + 1
l += len(rr.Altitude) + 1 l += len(rr.Altitude) + 1
return l return l
} }
func (rr *HINFO) len() int { func (rr *HINFO) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Cpu) + 1 l += len(rr.Cpu) + 1
l += len(rr.Os) + 1 l += len(rr.Os) + 1
return l return l
} }
func (rr *HIP) len() int { func (rr *HIP) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l++ // HitLength l++ // HitLength
l++ // PublicKeyAlgorithm l++ // PublicKeyAlgorithm
l += 2 // PublicKeyLength l += 2 // PublicKeyLength
l += len(rr.Hit) / 2 l += len(rr.Hit) / 2
l += base64.StdEncoding.DecodedLen(len(rr.PublicKey)) l += base64.StdEncoding.DecodedLen(len(rr.PublicKey))
for _, x := range rr.RendezvousServers { for _, x := range rr.RendezvousServers {
l += len(x) + 1 l += domainNameLen(x, off+l, compression, false)
} }
return l return l
} }
func (rr *KX) len() int { func (rr *KX) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Preference l += 2 // Preference
l += len(rr.Exchanger) + 1 l += domainNameLen(rr.Exchanger, off+l, compression, false)
return l return l
} }
func (rr *L32) len() int { func (rr *L32) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Preference l += 2 // Preference
l += net.IPv4len // Locator32 l += net.IPv4len // Locator32
return l return l
} }
func (rr *L64) len() int { func (rr *L64) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Preference l += 2 // Preference
l += 8 // Locator64 l += 8 // Locator64
return l return l
} }
func (rr *LOC) len() int { func (rr *LOC) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l++ // Version l++ // Version
l++ // Size l++ // Size
l++ // HorizPre l++ // HorizPre
@ -383,89 +383,89 @@ func (rr *LOC) len() int {
l += 4 // Altitude l += 4 // Altitude
return l return l
} }
func (rr *LP) len() int { func (rr *LP) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Preference l += 2 // Preference
l += len(rr.Fqdn) + 1 l += domainNameLen(rr.Fqdn, off+l, compression, false)
return l return l
} }
func (rr *MB) len() int { func (rr *MB) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Mb) + 1 l += domainNameLen(rr.Mb, off+l, compression, true)
return l return l
} }
func (rr *MD) len() int { func (rr *MD) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Md) + 1 l += domainNameLen(rr.Md, off+l, compression, true)
return l return l
} }
func (rr *MF) len() int { func (rr *MF) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Mf) + 1 l += domainNameLen(rr.Mf, off+l, compression, true)
return l return l
} }
func (rr *MG) len() int { func (rr *MG) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Mg) + 1 l += domainNameLen(rr.Mg, off+l, compression, true)
return l return l
} }
func (rr *MINFO) len() int { func (rr *MINFO) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Rmail) + 1 l += domainNameLen(rr.Rmail, off+l, compression, true)
l += len(rr.Email) + 1 l += domainNameLen(rr.Email, off+l, compression, true)
return l return l
} }
func (rr *MR) len() int { func (rr *MR) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Mr) + 1 l += domainNameLen(rr.Mr, off+l, compression, true)
return l return l
} }
func (rr *MX) len() int { func (rr *MX) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Preference l += 2 // Preference
l += len(rr.Mx) + 1 l += domainNameLen(rr.Mx, off+l, compression, true)
return l return l
} }
func (rr *NAPTR) len() int { func (rr *NAPTR) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Order l += 2 // Order
l += 2 // Preference l += 2 // Preference
l += len(rr.Flags) + 1 l += len(rr.Flags) + 1
l += len(rr.Service) + 1 l += len(rr.Service) + 1
l += len(rr.Regexp) + 1 l += len(rr.Regexp) + 1
l += len(rr.Replacement) + 1 l += domainNameLen(rr.Replacement, off+l, compression, false)
return l return l
} }
func (rr *NID) len() int { func (rr *NID) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Preference l += 2 // Preference
l += 8 // NodeID l += 8 // NodeID
return l return l
} }
func (rr *NIMLOC) len() int { func (rr *NIMLOC) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Locator)/2 + 1 l += len(rr.Locator)/2 + 1
return l return l
} }
func (rr *NINFO) len() int { func (rr *NINFO) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
for _, x := range rr.ZSData { for _, x := range rr.ZSData {
l += len(x) + 1 l += len(x) + 1
} }
return l return l
} }
func (rr *NS) len() int { func (rr *NS) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Ns) + 1 l += domainNameLen(rr.Ns, off+l, compression, true)
return l return l
} }
func (rr *NSAPPTR) len() int { func (rr *NSAPPTR) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Ptr) + 1 l += domainNameLen(rr.Ptr, off+l, compression, false)
return l return l
} }
func (rr *NSEC3PARAM) len() int { func (rr *NSEC3PARAM) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l++ // Hash l++ // Hash
l++ // Flags l++ // Flags
l += 2 // Iterations l += 2 // Iterations
@ -473,44 +473,44 @@ func (rr *NSEC3PARAM) len() int {
l += len(rr.Salt) / 2 l += len(rr.Salt) / 2
return l return l
} }
func (rr *OPENPGPKEY) len() int { func (rr *OPENPGPKEY) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += base64.StdEncoding.DecodedLen(len(rr.PublicKey)) l += base64.StdEncoding.DecodedLen(len(rr.PublicKey))
return l return l
} }
func (rr *PTR) len() int { func (rr *PTR) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Ptr) + 1 l += domainNameLen(rr.Ptr, off+l, compression, true)
return l return l
} }
func (rr *PX) len() int { func (rr *PX) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Preference l += 2 // Preference
l += len(rr.Map822) + 1 l += domainNameLen(rr.Map822, off+l, compression, false)
l += len(rr.Mapx400) + 1 l += domainNameLen(rr.Mapx400, off+l, compression, false)
return l return l
} }
func (rr *RFC3597) len() int { func (rr *RFC3597) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Rdata)/2 + 1 l += len(rr.Rdata)/2 + 1
return l return l
} }
func (rr *RKEY) len() int { func (rr *RKEY) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Flags l += 2 // Flags
l++ // Protocol l++ // Protocol
l++ // Algorithm l++ // Algorithm
l += base64.StdEncoding.DecodedLen(len(rr.PublicKey)) l += base64.StdEncoding.DecodedLen(len(rr.PublicKey))
return l return l
} }
func (rr *RP) len() int { func (rr *RP) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Mbox) + 1 l += domainNameLen(rr.Mbox, off+l, compression, false)
l += len(rr.Txt) + 1 l += domainNameLen(rr.Txt, off+l, compression, false)
return l return l
} }
func (rr *RRSIG) len() int { func (rr *RRSIG) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // TypeCovered l += 2 // TypeCovered
l++ // Algorithm l++ // Algorithm
l++ // Labels l++ // Labels
@ -518,28 +518,28 @@ func (rr *RRSIG) len() int {
l += 4 // Expiration l += 4 // Expiration
l += 4 // Inception l += 4 // Inception
l += 2 // KeyTag l += 2 // KeyTag
l += len(rr.SignerName) + 1 l += domainNameLen(rr.SignerName, off+l, compression, false)
l += base64.StdEncoding.DecodedLen(len(rr.Signature)) l += base64.StdEncoding.DecodedLen(len(rr.Signature))
return l return l
} }
func (rr *RT) len() int { func (rr *RT) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Preference l += 2 // Preference
l += len(rr.Host) + 1 l += domainNameLen(rr.Host, off+l, compression, false)
return l return l
} }
func (rr *SMIMEA) len() int { func (rr *SMIMEA) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l++ // Usage l++ // Usage
l++ // Selector l++ // Selector
l++ // MatchingType l++ // MatchingType
l += len(rr.Certificate)/2 + 1 l += len(rr.Certificate)/2 + 1
return l return l
} }
func (rr *SOA) len() int { func (rr *SOA) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Ns) + 1 l += domainNameLen(rr.Ns, off+l, compression, true)
l += len(rr.Mbox) + 1 l += domainNameLen(rr.Mbox, off+l, compression, true)
l += 4 // Serial l += 4 // Serial
l += 4 // Refresh l += 4 // Refresh
l += 4 // Retry l += 4 // Retry
@ -547,45 +547,45 @@ func (rr *SOA) len() int {
l += 4 // Minttl l += 4 // Minttl
return l return l
} }
func (rr *SPF) len() int { func (rr *SPF) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
for _, x := range rr.Txt { for _, x := range rr.Txt {
l += len(x) + 1 l += len(x) + 1
} }
return l return l
} }
func (rr *SRV) len() int { func (rr *SRV) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Priority l += 2 // Priority
l += 2 // Weight l += 2 // Weight
l += 2 // Port l += 2 // Port
l += len(rr.Target) + 1 l += domainNameLen(rr.Target, off+l, compression, false)
return l return l
} }
func (rr *SSHFP) len() int { func (rr *SSHFP) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l++ // Algorithm l++ // Algorithm
l++ // Type l++ // Type
l += len(rr.FingerPrint)/2 + 1 l += len(rr.FingerPrint)/2 + 1
return l return l
} }
func (rr *TA) len() int { func (rr *TA) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // KeyTag l += 2 // KeyTag
l++ // Algorithm l++ // Algorithm
l++ // DigestType l++ // DigestType
l += len(rr.Digest)/2 + 1 l += len(rr.Digest)/2 + 1
return l return l
} }
func (rr *TALINK) len() int { func (rr *TALINK) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.PreviousName) + 1 l += domainNameLen(rr.PreviousName, off+l, compression, false)
l += len(rr.NextName) + 1 l += domainNameLen(rr.NextName, off+l, compression, false)
return l return l
} }
func (rr *TKEY) len() int { func (rr *TKEY) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Algorithm) + 1 l += domainNameLen(rr.Algorithm, off+l, compression, false)
l += 4 // Inception l += 4 // Inception
l += 4 // Expiration l += 4 // Expiration
l += 2 // Mode l += 2 // Mode
@ -596,17 +596,17 @@ func (rr *TKEY) len() int {
l += len(rr.OtherData) / 2 l += len(rr.OtherData) / 2
return l return l
} }
func (rr *TLSA) len() int { func (rr *TLSA) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l++ // Usage l++ // Usage
l++ // Selector l++ // Selector
l++ // MatchingType l++ // MatchingType
l += len(rr.Certificate)/2 + 1 l += len(rr.Certificate)/2 + 1
return l return l
} }
func (rr *TSIG) len() int { func (rr *TSIG) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Algorithm) + 1 l += domainNameLen(rr.Algorithm, off+l, compression, false)
l += 6 // TimeSigned l += 6 // TimeSigned
l += 2 // Fudge l += 2 // Fudge
l += 2 // MACSize l += 2 // MACSize
@ -617,247 +617,247 @@ func (rr *TSIG) len() int {
l += len(rr.OtherData) / 2 l += len(rr.OtherData) / 2
return l return l
} }
func (rr *TXT) len() int { func (rr *TXT) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
for _, x := range rr.Txt { for _, x := range rr.Txt {
l += len(x) + 1 l += len(x) + 1
} }
return l return l
} }
func (rr *UID) len() int { func (rr *UID) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 4 // Uid l += 4 // Uid
return l return l
} }
func (rr *UINFO) len() int { func (rr *UINFO) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.Uinfo) + 1 l += len(rr.Uinfo) + 1
return l return l
} }
func (rr *URI) len() int { func (rr *URI) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += 2 // Priority l += 2 // Priority
l += 2 // Weight l += 2 // Weight
l += len(rr.Target) l += len(rr.Target)
return l return l
} }
func (rr *X25) len() int { func (rr *X25) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len() l := rr.Hdr.len(off, compression)
l += len(rr.PSDNAddress) + 1 l += len(rr.PSDNAddress) + 1
return l return l
} }
// copy() functions // copy() functions
func (rr *A) copy() RR { func (rr *A) copy() RR {
return &A{*rr.Hdr.copyHeader(), copyIP(rr.A)} return &A{rr.Hdr, copyIP(rr.A)}
} }
func (rr *AAAA) copy() RR { func (rr *AAAA) copy() RR {
return &AAAA{*rr.Hdr.copyHeader(), copyIP(rr.AAAA)} return &AAAA{rr.Hdr, copyIP(rr.AAAA)}
} }
func (rr *AFSDB) copy() RR { func (rr *AFSDB) copy() RR {
return &AFSDB{*rr.Hdr.copyHeader(), rr.Subtype, rr.Hostname} return &AFSDB{rr.Hdr, rr.Subtype, rr.Hostname}
} }
func (rr *ANY) copy() RR { func (rr *ANY) copy() RR {
return &ANY{*rr.Hdr.copyHeader()} return &ANY{rr.Hdr}
} }
func (rr *AVC) copy() RR { func (rr *AVC) copy() RR {
Txt := make([]string, len(rr.Txt)) Txt := make([]string, len(rr.Txt))
copy(Txt, rr.Txt) copy(Txt, rr.Txt)
return &AVC{*rr.Hdr.copyHeader(), Txt} return &AVC{rr.Hdr, Txt}
} }
func (rr *CAA) copy() RR { func (rr *CAA) copy() RR {
return &CAA{*rr.Hdr.copyHeader(), rr.Flag, rr.Tag, rr.Value} return &CAA{rr.Hdr, rr.Flag, rr.Tag, rr.Value}
} }
func (rr *CERT) copy() RR { func (rr *CERT) copy() RR {
return &CERT{*rr.Hdr.copyHeader(), rr.Type, rr.KeyTag, rr.Algorithm, rr.Certificate} return &CERT{rr.Hdr, rr.Type, rr.KeyTag, rr.Algorithm, rr.Certificate}
} }
func (rr *CNAME) copy() RR { func (rr *CNAME) copy() RR {
return &CNAME{*rr.Hdr.copyHeader(), rr.Target} return &CNAME{rr.Hdr, rr.Target}
} }
func (rr *CSYNC) copy() RR { func (rr *CSYNC) copy() RR {
TypeBitMap := make([]uint16, len(rr.TypeBitMap)) TypeBitMap := make([]uint16, len(rr.TypeBitMap))
copy(TypeBitMap, rr.TypeBitMap) copy(TypeBitMap, rr.TypeBitMap)
return &CSYNC{*rr.Hdr.copyHeader(), rr.Serial, rr.Flags, TypeBitMap} return &CSYNC{rr.Hdr, rr.Serial, rr.Flags, TypeBitMap}
} }
func (rr *DHCID) copy() RR { func (rr *DHCID) copy() RR {
return &DHCID{*rr.Hdr.copyHeader(), rr.Digest} return &DHCID{rr.Hdr, rr.Digest}
} }
func (rr *DNAME) copy() RR { func (rr *DNAME) copy() RR {
return &DNAME{*rr.Hdr.copyHeader(), rr.Target} return &DNAME{rr.Hdr, rr.Target}
} }
func (rr *DNSKEY) copy() RR { func (rr *DNSKEY) copy() RR {
return &DNSKEY{*rr.Hdr.copyHeader(), rr.Flags, rr.Protocol, rr.Algorithm, rr.PublicKey} return &DNSKEY{rr.Hdr, rr.Flags, rr.Protocol, rr.Algorithm, rr.PublicKey}
} }
func (rr *DS) copy() RR { func (rr *DS) copy() RR {
return &DS{*rr.Hdr.copyHeader(), rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest} return &DS{rr.Hdr, rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
} }
func (rr *EID) copy() RR { func (rr *EID) copy() RR {
return &EID{*rr.Hdr.copyHeader(), rr.Endpoint} return &EID{rr.Hdr, rr.Endpoint}
} }
func (rr *EUI48) copy() RR { func (rr *EUI48) copy() RR {
return &EUI48{*rr.Hdr.copyHeader(), rr.Address} return &EUI48{rr.Hdr, rr.Address}
} }
func (rr *EUI64) copy() RR { func (rr *EUI64) copy() RR {
return &EUI64{*rr.Hdr.copyHeader(), rr.Address} return &EUI64{rr.Hdr, rr.Address}
} }
func (rr *GID) copy() RR { func (rr *GID) copy() RR {
return &GID{*rr.Hdr.copyHeader(), rr.Gid} return &GID{rr.Hdr, rr.Gid}
} }
func (rr *GPOS) copy() RR { func (rr *GPOS) copy() RR {
return &GPOS{*rr.Hdr.copyHeader(), rr.Longitude, rr.Latitude, rr.Altitude} return &GPOS{rr.Hdr, rr.Longitude, rr.Latitude, rr.Altitude}
} }
func (rr *HINFO) copy() RR { func (rr *HINFO) copy() RR {
return &HINFO{*rr.Hdr.copyHeader(), rr.Cpu, rr.Os} return &HINFO{rr.Hdr, rr.Cpu, rr.Os}
} }
func (rr *HIP) copy() RR { func (rr *HIP) copy() RR {
RendezvousServers := make([]string, len(rr.RendezvousServers)) RendezvousServers := make([]string, len(rr.RendezvousServers))
copy(RendezvousServers, rr.RendezvousServers) copy(RendezvousServers, rr.RendezvousServers)
return &HIP{*rr.Hdr.copyHeader(), rr.HitLength, rr.PublicKeyAlgorithm, rr.PublicKeyLength, rr.Hit, rr.PublicKey, RendezvousServers} return &HIP{rr.Hdr, rr.HitLength, rr.PublicKeyAlgorithm, rr.PublicKeyLength, rr.Hit, rr.PublicKey, RendezvousServers}
} }
func (rr *KX) copy() RR { func (rr *KX) copy() RR {
return &KX{*rr.Hdr.copyHeader(), rr.Preference, rr.Exchanger} return &KX{rr.Hdr, rr.Preference, rr.Exchanger}
} }
func (rr *L32) copy() RR { func (rr *L32) copy() RR {
return &L32{*rr.Hdr.copyHeader(), rr.Preference, copyIP(rr.Locator32)} return &L32{rr.Hdr, rr.Preference, copyIP(rr.Locator32)}
} }
func (rr *L64) copy() RR { func (rr *L64) copy() RR {
return &L64{*rr.Hdr.copyHeader(), rr.Preference, rr.Locator64} return &L64{rr.Hdr, rr.Preference, rr.Locator64}
} }
func (rr *LOC) copy() RR { func (rr *LOC) copy() RR {
return &LOC{*rr.Hdr.copyHeader(), rr.Version, rr.Size, rr.HorizPre, rr.VertPre, rr.Latitude, rr.Longitude, rr.Altitude} return &LOC{rr.Hdr, rr.Version, rr.Size, rr.HorizPre, rr.VertPre, rr.Latitude, rr.Longitude, rr.Altitude}
} }
func (rr *LP) copy() RR { func (rr *LP) copy() RR {
return &LP{*rr.Hdr.copyHeader(), rr.Preference, rr.Fqdn} return &LP{rr.Hdr, rr.Preference, rr.Fqdn}
} }
func (rr *MB) copy() RR { func (rr *MB) copy() RR {
return &MB{*rr.Hdr.copyHeader(), rr.Mb} return &MB{rr.Hdr, rr.Mb}
} }
func (rr *MD) copy() RR { func (rr *MD) copy() RR {
return &MD{*rr.Hdr.copyHeader(), rr.Md} return &MD{rr.Hdr, rr.Md}
} }
func (rr *MF) copy() RR { func (rr *MF) copy() RR {
return &MF{*rr.Hdr.copyHeader(), rr.Mf} return &MF{rr.Hdr, rr.Mf}
} }
func (rr *MG) copy() RR { func (rr *MG) copy() RR {
return &MG{*rr.Hdr.copyHeader(), rr.Mg} return &MG{rr.Hdr, rr.Mg}
} }
func (rr *MINFO) copy() RR { func (rr *MINFO) copy() RR {
return &MINFO{*rr.Hdr.copyHeader(), rr.Rmail, rr.Email} return &MINFO{rr.Hdr, rr.Rmail, rr.Email}
} }
func (rr *MR) copy() RR { func (rr *MR) copy() RR {
return &MR{*rr.Hdr.copyHeader(), rr.Mr} return &MR{rr.Hdr, rr.Mr}
} }
func (rr *MX) copy() RR { func (rr *MX) copy() RR {
return &MX{*rr.Hdr.copyHeader(), rr.Preference, rr.Mx} return &MX{rr.Hdr, rr.Preference, rr.Mx}
} }
func (rr *NAPTR) copy() RR { func (rr *NAPTR) copy() RR {
return &NAPTR{*rr.Hdr.copyHeader(), rr.Order, rr.Preference, rr.Flags, rr.Service, rr.Regexp, rr.Replacement} return &NAPTR{rr.Hdr, rr.Order, rr.Preference, rr.Flags, rr.Service, rr.Regexp, rr.Replacement}
} }
func (rr *NID) copy() RR { func (rr *NID) copy() RR {
return &NID{*rr.Hdr.copyHeader(), rr.Preference, rr.NodeID} return &NID{rr.Hdr, rr.Preference, rr.NodeID}
} }
func (rr *NIMLOC) copy() RR { func (rr *NIMLOC) copy() RR {
return &NIMLOC{*rr.Hdr.copyHeader(), rr.Locator} return &NIMLOC{rr.Hdr, rr.Locator}
} }
func (rr *NINFO) copy() RR { func (rr *NINFO) copy() RR {
ZSData := make([]string, len(rr.ZSData)) ZSData := make([]string, len(rr.ZSData))
copy(ZSData, rr.ZSData) copy(ZSData, rr.ZSData)
return &NINFO{*rr.Hdr.copyHeader(), ZSData} return &NINFO{rr.Hdr, ZSData}
} }
func (rr *NS) copy() RR { func (rr *NS) copy() RR {
return &NS{*rr.Hdr.copyHeader(), rr.Ns} return &NS{rr.Hdr, rr.Ns}
} }
func (rr *NSAPPTR) copy() RR { func (rr *NSAPPTR) copy() RR {
return &NSAPPTR{*rr.Hdr.copyHeader(), rr.Ptr} return &NSAPPTR{rr.Hdr, rr.Ptr}
} }
func (rr *NSEC) copy() RR { func (rr *NSEC) copy() RR {
TypeBitMap := make([]uint16, len(rr.TypeBitMap)) TypeBitMap := make([]uint16, len(rr.TypeBitMap))
copy(TypeBitMap, rr.TypeBitMap) copy(TypeBitMap, rr.TypeBitMap)
return &NSEC{*rr.Hdr.copyHeader(), rr.NextDomain, TypeBitMap} return &NSEC{rr.Hdr, rr.NextDomain, TypeBitMap}
} }
func (rr *NSEC3) copy() RR { func (rr *NSEC3) copy() RR {
TypeBitMap := make([]uint16, len(rr.TypeBitMap)) TypeBitMap := make([]uint16, len(rr.TypeBitMap))
copy(TypeBitMap, rr.TypeBitMap) copy(TypeBitMap, rr.TypeBitMap)
return &NSEC3{*rr.Hdr.copyHeader(), rr.Hash, rr.Flags, rr.Iterations, rr.SaltLength, rr.Salt, rr.HashLength, rr.NextDomain, TypeBitMap} return &NSEC3{rr.Hdr, rr.Hash, rr.Flags, rr.Iterations, rr.SaltLength, rr.Salt, rr.HashLength, rr.NextDomain, TypeBitMap}
} }
func (rr *NSEC3PARAM) copy() RR { func (rr *NSEC3PARAM) copy() RR {
return &NSEC3PARAM{*rr.Hdr.copyHeader(), rr.Hash, rr.Flags, rr.Iterations, rr.SaltLength, rr.Salt} return &NSEC3PARAM{rr.Hdr, rr.Hash, rr.Flags, rr.Iterations, rr.SaltLength, rr.Salt}
} }
func (rr *OPENPGPKEY) copy() RR { func (rr *OPENPGPKEY) copy() RR {
return &OPENPGPKEY{*rr.Hdr.copyHeader(), rr.PublicKey} return &OPENPGPKEY{rr.Hdr, rr.PublicKey}
} }
func (rr *OPT) copy() RR { func (rr *OPT) copy() RR {
Option := make([]EDNS0, len(rr.Option)) Option := make([]EDNS0, len(rr.Option))
copy(Option, rr.Option) copy(Option, rr.Option)
return &OPT{*rr.Hdr.copyHeader(), Option} return &OPT{rr.Hdr, Option}
} }
func (rr *PTR) copy() RR { func (rr *PTR) copy() RR {
return &PTR{*rr.Hdr.copyHeader(), rr.Ptr} return &PTR{rr.Hdr, rr.Ptr}
} }
func (rr *PX) copy() RR { func (rr *PX) copy() RR {
return &PX{*rr.Hdr.copyHeader(), rr.Preference, rr.Map822, rr.Mapx400} return &PX{rr.Hdr, rr.Preference, rr.Map822, rr.Mapx400}
} }
func (rr *RFC3597) copy() RR { func (rr *RFC3597) copy() RR {
return &RFC3597{*rr.Hdr.copyHeader(), rr.Rdata} return &RFC3597{rr.Hdr, rr.Rdata}
} }
func (rr *RKEY) copy() RR { func (rr *RKEY) copy() RR {
return &RKEY{*rr.Hdr.copyHeader(), rr.Flags, rr.Protocol, rr.Algorithm, rr.PublicKey} return &RKEY{rr.Hdr, rr.Flags, rr.Protocol, rr.Algorithm, rr.PublicKey}
} }
func (rr *RP) copy() RR { func (rr *RP) copy() RR {
return &RP{*rr.Hdr.copyHeader(), rr.Mbox, rr.Txt} return &RP{rr.Hdr, rr.Mbox, rr.Txt}
} }
func (rr *RRSIG) copy() RR { func (rr *RRSIG) copy() RR {
return &RRSIG{*rr.Hdr.copyHeader(), rr.TypeCovered, rr.Algorithm, rr.Labels, rr.OrigTtl, rr.Expiration, rr.Inception, rr.KeyTag, rr.SignerName, rr.Signature} return &RRSIG{rr.Hdr, rr.TypeCovered, rr.Algorithm, rr.Labels, rr.OrigTtl, rr.Expiration, rr.Inception, rr.KeyTag, rr.SignerName, rr.Signature}
} }
func (rr *RT) copy() RR { func (rr *RT) copy() RR {
return &RT{*rr.Hdr.copyHeader(), rr.Preference, rr.Host} return &RT{rr.Hdr, rr.Preference, rr.Host}
} }
func (rr *SMIMEA) copy() RR { func (rr *SMIMEA) copy() RR {
return &SMIMEA{*rr.Hdr.copyHeader(), rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate} return &SMIMEA{rr.Hdr, rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate}
} }
func (rr *SOA) copy() RR { func (rr *SOA) copy() RR {
return &SOA{*rr.Hdr.copyHeader(), rr.Ns, rr.Mbox, rr.Serial, rr.Refresh, rr.Retry, rr.Expire, rr.Minttl} return &SOA{rr.Hdr, rr.Ns, rr.Mbox, rr.Serial, rr.Refresh, rr.Retry, rr.Expire, rr.Minttl}
} }
func (rr *SPF) copy() RR { func (rr *SPF) copy() RR {
Txt := make([]string, len(rr.Txt)) Txt := make([]string, len(rr.Txt))
copy(Txt, rr.Txt) copy(Txt, rr.Txt)
return &SPF{*rr.Hdr.copyHeader(), Txt} return &SPF{rr.Hdr, Txt}
} }
func (rr *SRV) copy() RR { func (rr *SRV) copy() RR {
return &SRV{*rr.Hdr.copyHeader(), rr.Priority, rr.Weight, rr.Port, rr.Target} return &SRV{rr.Hdr, rr.Priority, rr.Weight, rr.Port, rr.Target}
} }
func (rr *SSHFP) copy() RR { func (rr *SSHFP) copy() RR {
return &SSHFP{*rr.Hdr.copyHeader(), rr.Algorithm, rr.Type, rr.FingerPrint} return &SSHFP{rr.Hdr, rr.Algorithm, rr.Type, rr.FingerPrint}
} }
func (rr *TA) copy() RR { func (rr *TA) copy() RR {
return &TA{*rr.Hdr.copyHeader(), rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest} return &TA{rr.Hdr, rr.KeyTag, rr.Algorithm, rr.DigestType, rr.Digest}
} }
func (rr *TALINK) copy() RR { func (rr *TALINK) copy() RR {
return &TALINK{*rr.Hdr.copyHeader(), rr.PreviousName, rr.NextName} return &TALINK{rr.Hdr, rr.PreviousName, rr.NextName}
} }
func (rr *TKEY) copy() RR { func (rr *TKEY) copy() RR {
return &TKEY{*rr.Hdr.copyHeader(), rr.Algorithm, rr.Inception, rr.Expiration, rr.Mode, rr.Error, rr.KeySize, rr.Key, rr.OtherLen, rr.OtherData} return &TKEY{rr.Hdr, rr.Algorithm, rr.Inception, rr.Expiration, rr.Mode, rr.Error, rr.KeySize, rr.Key, rr.OtherLen, rr.OtherData}
} }
func (rr *TLSA) copy() RR { func (rr *TLSA) copy() RR {
return &TLSA{*rr.Hdr.copyHeader(), rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate} return &TLSA{rr.Hdr, rr.Usage, rr.Selector, rr.MatchingType, rr.Certificate}
} }
func (rr *TSIG) copy() RR { func (rr *TSIG) copy() RR {
return &TSIG{*rr.Hdr.copyHeader(), rr.Algorithm, rr.TimeSigned, rr.Fudge, rr.MACSize, rr.MAC, rr.OrigId, rr.Error, rr.OtherLen, rr.OtherData} return &TSIG{rr.Hdr, rr.Algorithm, rr.TimeSigned, rr.Fudge, rr.MACSize, rr.MAC, rr.OrigId, rr.Error, rr.OtherLen, rr.OtherData}
} }
func (rr *TXT) copy() RR { func (rr *TXT) copy() RR {
Txt := make([]string, len(rr.Txt)) Txt := make([]string, len(rr.Txt))
copy(Txt, rr.Txt) copy(Txt, rr.Txt)
return &TXT{*rr.Hdr.copyHeader(), Txt} return &TXT{rr.Hdr, Txt}
} }
func (rr *UID) copy() RR { func (rr *UID) copy() RR {
return &UID{*rr.Hdr.copyHeader(), rr.Uid} return &UID{rr.Hdr, rr.Uid}
} }
func (rr *UINFO) copy() RR { func (rr *UINFO) copy() RR {
return &UINFO{*rr.Hdr.copyHeader(), rr.Uinfo} return &UINFO{rr.Hdr, rr.Uinfo}
} }
func (rr *URI) copy() RR { func (rr *URI) copy() RR {
return &URI{*rr.Hdr.copyHeader(), rr.Priority, rr.Weight, rr.Target} return &URI{rr.Hdr, rr.Priority, rr.Weight, rr.Target}
} }
func (rr *X25) copy() RR { func (rr *X25) copy() RR {
return &X25{*rr.Hdr.copyHeader(), rr.PSDNAddress} return &X25{rr.Hdr, rr.PSDNAddress}
} }

69
vendor/github.com/xenolf/lego/acme/api/account.go generated vendored Normal file
View file

@ -0,0 +1,69 @@
package api
import (
"encoding/base64"
"errors"
"fmt"
"github.com/xenolf/lego/acme"
)
type AccountService service
// New Creates a new account.
func (a *AccountService) New(req acme.Account) (acme.ExtendedAccount, error) {
var account acme.Account
resp, err := a.core.post(a.core.GetDirectory().NewAccountURL, req, &account)
location := getLocation(resp)
if len(location) > 0 {
a.core.jws.SetKid(location)
}
if err != nil {
return acme.ExtendedAccount{Location: location}, err
}
return acme.ExtendedAccount{Account: account, Location: location}, nil
}
// NewEAB Creates a new account with an External Account Binding.
func (a *AccountService) NewEAB(accMsg acme.Account, kid string, hmacEncoded string) (acme.ExtendedAccount, error) {
hmac, err := base64.RawURLEncoding.DecodeString(hmacEncoded)
if err != nil {
return acme.ExtendedAccount{}, fmt.Errorf("acme: could not decode hmac key: %v", err)
}
eabJWS, err := a.core.signEABContent(a.core.GetDirectory().NewAccountURL, kid, hmac)
if err != nil {
return acme.ExtendedAccount{}, fmt.Errorf("acme: error signing eab content: %v", err)
}
accMsg.ExternalAccountBinding = eabJWS
return a.New(accMsg)
}
// Get Retrieves an account.
func (a *AccountService) Get(accountURL string) (acme.Account, error) {
if len(accountURL) == 0 {
return acme.Account{}, errors.New("account[get]: empty URL")
}
var account acme.Account
_, err := a.core.post(accountURL, acme.Account{}, &account)
if err != nil {
return acme.Account{}, err
}
return account, nil
}
// Deactivate Deactivates an account.
func (a *AccountService) Deactivate(accountURL string) error {
if len(accountURL) == 0 {
return errors.New("account[deactivate]: empty URL")
}
req := acme.Account{Status: acme.StatusDeactivated}
_, err := a.core.post(accountURL, req, nil)
return err
}

151
vendor/github.com/xenolf/lego/acme/api/api.go generated vendored Normal file
View file

@ -0,0 +1,151 @@
package api
import (
"bytes"
"crypto"
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acme/api/internal/nonces"
"github.com/xenolf/lego/acme/api/internal/secure"
"github.com/xenolf/lego/acme/api/internal/sender"
"github.com/xenolf/lego/log"
)
// Core ACME/LE core API.
type Core struct {
doer *sender.Doer
nonceManager *nonces.Manager
jws *secure.JWS
directory acme.Directory
HTTPClient *http.Client
common service // Reuse a single struct instead of allocating one for each service on the heap.
Accounts *AccountService
Authorizations *AuthorizationService
Certificates *CertificateService
Challenges *ChallengeService
Orders *OrderService
}
// New Creates a new Core.
func New(httpClient *http.Client, userAgent string, caDirURL, kid string, privateKey crypto.PrivateKey) (*Core, error) {
doer := sender.NewDoer(httpClient, userAgent)
dir, err := getDirectory(doer, caDirURL)
if err != nil {
return nil, err
}
nonceManager := nonces.NewManager(doer, dir.NewNonceURL)
jws := secure.NewJWS(privateKey, kid, nonceManager)
c := &Core{doer: doer, nonceManager: nonceManager, jws: jws, directory: dir}
c.common.core = c
c.Accounts = (*AccountService)(&c.common)
c.Authorizations = (*AuthorizationService)(&c.common)
c.Certificates = (*CertificateService)(&c.common)
c.Challenges = (*ChallengeService)(&c.common)
c.Orders = (*OrderService)(&c.common)
return c, nil
}
// post performs an HTTP POST request and parses the response body as JSON,
// into the provided respBody object.
func (a *Core) post(uri string, reqBody, response interface{}) (*http.Response, error) {
content, err := json.Marshal(reqBody)
if err != nil {
return nil, errors.New("failed to marshal message")
}
return a.retrievablePost(uri, content, response, 0)
}
// postAsGet performs an HTTP POST ("POST-as-GET") request.
// https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-6.3
func (a *Core) postAsGet(uri string, response interface{}) (*http.Response, error) {
return a.retrievablePost(uri, []byte{}, response, 0)
}
func (a *Core) retrievablePost(uri string, content []byte, response interface{}, retry int) (*http.Response, error) {
resp, err := a.signedPost(uri, content, response)
if err != nil {
// during tests, 5 retries allow to support ~50% of bad nonce.
if retry >= 5 {
log.Infof("too many retry on a nonce error, retry count: %d", retry)
return resp, err
}
switch err.(type) {
// Retry once if the nonce was invalidated
case *acme.NonceError:
log.Infof("nonce error retry: %s", err)
resp, err = a.retrievablePost(uri, content, response, retry+1)
if err != nil {
return resp, err
}
default:
return resp, err
}
}
return resp, nil
}
func (a *Core) signedPost(uri string, content []byte, response interface{}) (*http.Response, error) {
signedContent, err := a.jws.SignContent(uri, content)
if err != nil {
return nil, fmt.Errorf("failed to post JWS message -> failed to sign content -> %v", err)
}
signedBody := bytes.NewBuffer([]byte(signedContent.FullSerialize()))
resp, err := a.doer.Post(uri, signedBody, "application/jose+json", response)
// nonceErr is ignored to keep the root error.
nonce, nonceErr := nonces.GetFromResponse(resp)
if nonceErr == nil {
a.nonceManager.Push(nonce)
}
return resp, err
}
func (a *Core) signEABContent(newAccountURL, kid string, hmac []byte) ([]byte, error) {
eabJWS, err := a.jws.SignEABContent(newAccountURL, kid, hmac)
if err != nil {
return nil, err
}
return []byte(eabJWS.FullSerialize()), nil
}
// GetKeyAuthorization Gets the key authorization
func (a *Core) GetKeyAuthorization(token string) (string, error) {
return a.jws.GetKeyAuthorization(token)
}
func (a *Core) GetDirectory() acme.Directory {
return a.directory
}
func getDirectory(do *sender.Doer, caDirURL string) (acme.Directory, error) {
var dir acme.Directory
if _, err := do.Get(caDirURL, &dir); err != nil {
return dir, fmt.Errorf("get directory at '%s': %v", caDirURL, err)
}
if dir.NewAccountURL == "" {
return dir, errors.New("directory missing new registration URL")
}
if dir.NewOrderURL == "" {
return dir, errors.New("directory missing new order URL")
}
return dir, nil
}

View file

@ -0,0 +1,34 @@
package api
import (
"errors"
"github.com/xenolf/lego/acme"
)
type AuthorizationService service
// Get Gets an authorization.
func (c *AuthorizationService) Get(authzURL string) (acme.Authorization, error) {
if len(authzURL) == 0 {
return acme.Authorization{}, errors.New("authorization[get]: empty URL")
}
var authz acme.Authorization
_, err := c.core.postAsGet(authzURL, &authz)
if err != nil {
return acme.Authorization{}, err
}
return authz, nil
}
// Deactivate Deactivates an authorization.
func (c *AuthorizationService) Deactivate(authzURL string) error {
if len(authzURL) == 0 {
return errors.New("authorization[deactivate]: empty URL")
}
var disabledAuth acme.Authorization
_, err := c.core.post(authzURL, acme.Authorization{Status: acme.StatusDeactivated}, &disabledAuth)
return err
}

99
vendor/github.com/xenolf/lego/acme/api/certificate.go generated vendored Normal file
View file

@ -0,0 +1,99 @@
package api
import (
"crypto/x509"
"encoding/pem"
"errors"
"io/ioutil"
"net/http"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/certcrypto"
"github.com/xenolf/lego/log"
)
// maxBodySize is the maximum size of body that we will read.
const maxBodySize = 1024 * 1024
type CertificateService service
// Get Returns the certificate and the issuer certificate.
// 'bundle' is only applied if the issuer is provided by the 'up' link.
func (c *CertificateService) Get(certURL string, bundle bool) ([]byte, []byte, error) {
cert, up, err := c.get(certURL)
if err != nil {
return nil, nil, err
}
// Get issuerCert from bundled response from Let's Encrypt
// See https://community.letsencrypt.org/t/acme-v2-no-up-link-in-response/64962
_, issuer := pem.Decode(cert)
if issuer != nil {
return cert, issuer, nil
}
issuer, err = c.getIssuerFromLink(up)
if err != nil {
// If we fail to acquire the issuer cert, return the issued certificate - do not fail.
log.Warnf("acme: Could not bundle issuer certificate [%s]: %v", certURL, err)
} else if len(issuer) > 0 {
// If bundle is true, we want to return a certificate bundle.
// To do this, we append the issuer cert to the issued cert.
if bundle {
cert = append(cert, issuer...)
}
}
return cert, issuer, nil
}
// Revoke Revokes a certificate.
func (c *CertificateService) Revoke(req acme.RevokeCertMessage) error {
_, err := c.core.post(c.core.GetDirectory().RevokeCertURL, req, nil)
return err
}
// get Returns the certificate and the "up" link.
func (c *CertificateService) get(certURL string) ([]byte, string, error) {
if len(certURL) == 0 {
return nil, "", errors.New("certificate[get]: empty URL")
}
resp, err := c.core.postAsGet(certURL, nil)
if err != nil {
return nil, "", err
}
cert, err := ioutil.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize))
if err != nil {
return nil, "", err
}
// The issuer certificate link may be supplied via an "up" link
// in the response headers of a new certificate.
// See https://tools.ietf.org/html/draft-ietf-acme-acme-12#section-7.4.2
up := getLink(resp.Header, "up")
return cert, up, err
}
// getIssuerFromLink requests the issuer certificate
func (c *CertificateService) getIssuerFromLink(up string) ([]byte, error) {
if len(up) == 0 {
return nil, nil
}
log.Infof("acme: Requesting issuer cert from %s", up)
cert, _, err := c.get(up)
if err != nil {
return nil, err
}
_, err = x509.ParseCertificate(cert)
if err != nil {
return nil, err
}
return certcrypto.PEMEncode(certcrypto.DERCertificateBytes(cert)), nil
}

45
vendor/github.com/xenolf/lego/acme/api/challenge.go generated vendored Normal file
View file

@ -0,0 +1,45 @@
package api
import (
"errors"
"github.com/xenolf/lego/acme"
)
type ChallengeService service
// New Creates a challenge.
func (c *ChallengeService) New(chlgURL string) (acme.ExtendedChallenge, error) {
if len(chlgURL) == 0 {
return acme.ExtendedChallenge{}, errors.New("challenge[new]: empty URL")
}
// Challenge initiation is done by sending a JWS payload containing the trivial JSON object `{}`.
// We use an empty struct instance as the postJSON payload here to achieve this result.
var chlng acme.ExtendedChallenge
resp, err := c.core.post(chlgURL, struct{}{}, &chlng)
if err != nil {
return acme.ExtendedChallenge{}, err
}
chlng.AuthorizationURL = getLink(resp.Header, "up")
chlng.RetryAfter = getRetryAfter(resp)
return chlng, nil
}
// Get Gets a challenge.
func (c *ChallengeService) Get(chlgURL string) (acme.ExtendedChallenge, error) {
if len(chlgURL) == 0 {
return acme.ExtendedChallenge{}, errors.New("challenge[get]: empty URL")
}
var chlng acme.ExtendedChallenge
resp, err := c.core.postAsGet(chlgURL, &chlng)
if err != nil {
return acme.ExtendedChallenge{}, err
}
chlng.AuthorizationURL = getLink(resp.Header, "up")
chlng.RetryAfter = getRetryAfter(resp)
return chlng, nil
}

View file

@ -0,0 +1,78 @@
package nonces
import (
"errors"
"fmt"
"net/http"
"sync"
"github.com/xenolf/lego/acme/api/internal/sender"
)
// Manager Manages nonces.
type Manager struct {
do *sender.Doer
nonceURL string
nonces []string
sync.Mutex
}
// NewManager Creates a new Manager.
func NewManager(do *sender.Doer, nonceURL string) *Manager {
return &Manager{
do: do,
nonceURL: nonceURL,
}
}
// Pop Pops a nonce.
func (n *Manager) Pop() (string, bool) {
n.Lock()
defer n.Unlock()
if len(n.nonces) == 0 {
return "", false
}
nonce := n.nonces[len(n.nonces)-1]
n.nonces = n.nonces[:len(n.nonces)-1]
return nonce, true
}
// Push Pushes a nonce.
func (n *Manager) Push(nonce string) {
n.Lock()
defer n.Unlock()
n.nonces = append(n.nonces, nonce)
}
// Nonce implement jose.NonceSource
func (n *Manager) Nonce() (string, error) {
if nonce, ok := n.Pop(); ok {
return nonce, nil
}
return n.getNonce()
}
func (n *Manager) getNonce() (string, error) {
resp, err := n.do.Head(n.nonceURL)
if err != nil {
return "", fmt.Errorf("failed to get nonce from HTTP HEAD -> %v", err)
}
return GetFromResponse(resp)
}
// GetFromResponse Extracts a nonce from a HTTP response.
func GetFromResponse(resp *http.Response) (string, error) {
if resp == nil {
return "", errors.New("nil response")
}
nonce := resp.Header.Get("Replay-Nonce")
if nonce == "" {
return "", fmt.Errorf("server did not respond with a proper nonce header")
}
return nonce, nil
}

View file

@ -0,0 +1,134 @@
package secure
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"encoding/base64"
"errors"
"fmt"
"github.com/xenolf/lego/acme/api/internal/nonces"
"gopkg.in/square/go-jose.v2"
)
// JWS Represents a JWS.
type JWS struct {
privKey crypto.PrivateKey
kid string // Key identifier
nonces *nonces.Manager
}
// NewJWS Create a new JWS.
func NewJWS(privateKey crypto.PrivateKey, kid string, nonceManager *nonces.Manager) *JWS {
return &JWS{
privKey: privateKey,
nonces: nonceManager,
kid: kid,
}
}
// SetKid Sets a key identifier.
func (j *JWS) SetKid(kid string) {
j.kid = kid
}
// SignContent Signs a content with the JWS.
func (j *JWS) SignContent(url string, content []byte) (*jose.JSONWebSignature, error) {
var alg jose.SignatureAlgorithm
switch k := j.privKey.(type) {
case *rsa.PrivateKey:
alg = jose.RS256
case *ecdsa.PrivateKey:
if k.Curve == elliptic.P256() {
alg = jose.ES256
} else if k.Curve == elliptic.P384() {
alg = jose.ES384
}
}
signKey := jose.SigningKey{
Algorithm: alg,
Key: jose.JSONWebKey{Key: j.privKey, KeyID: j.kid},
}
options := jose.SignerOptions{
NonceSource: j.nonces,
ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url,
},
}
if j.kid == "" {
options.EmbedJWK = true
}
signer, err := jose.NewSigner(signKey, &options)
if err != nil {
return nil, fmt.Errorf("failed to create jose signer -> %v", err)
}
signed, err := signer.Sign(content)
if err != nil {
return nil, fmt.Errorf("failed to sign content -> %v", err)
}
return signed, nil
}
// SignEABContent Signs an external account binding content with the JWS.
func (j *JWS) SignEABContent(url, kid string, hmac []byte) (*jose.JSONWebSignature, error) {
jwk := jose.JSONWebKey{Key: j.privKey}
jwkJSON, err := jwk.Public().MarshalJSON()
if err != nil {
return nil, fmt.Errorf("acme: error encoding eab jwk key: %v", err)
}
signer, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.HS256, Key: hmac},
&jose.SignerOptions{
EmbedJWK: false,
ExtraHeaders: map[jose.HeaderKey]interface{}{
"kid": kid,
"url": url,
},
},
)
if err != nil {
return nil, fmt.Errorf("failed to create External Account Binding jose signer -> %v", err)
}
signed, err := signer.Sign(jwkJSON)
if err != nil {
return nil, fmt.Errorf("failed to External Account Binding sign content -> %v", err)
}
return signed, nil
}
// GetKeyAuthorization Gets the key authorization for a token.
func (j *JWS) GetKeyAuthorization(token string) (string, error) {
var publicKey crypto.PublicKey
switch k := j.privKey.(type) {
case *ecdsa.PrivateKey:
publicKey = k.Public()
case *rsa.PrivateKey:
publicKey = k.Public()
}
// Generate the Key Authorization for the challenge
jwk := &jose.JSONWebKey{Key: publicKey}
if jwk == nil {
return "", errors.New("could not generate JWK from key")
}
thumbBytes, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
return "", err
}
// unpad the base64URL
keyThumb := base64.RawURLEncoding.EncodeToString(thumbBytes)
return token + "." + keyThumb, nil
}

View file

@ -0,0 +1,146 @@
package sender
import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"runtime"
"strings"
"github.com/xenolf/lego/acme"
)
type RequestOption func(*http.Request) error
func contentType(ct string) RequestOption {
return func(req *http.Request) error {
req.Header.Set("Content-Type", ct)
return nil
}
}
type Doer struct {
httpClient *http.Client
userAgent string
}
// NewDoer Creates a new Doer.
func NewDoer(client *http.Client, userAgent string) *Doer {
return &Doer{
httpClient: client,
userAgent: userAgent,
}
}
// Get performs a GET request with a proper User-Agent string.
// If "response" is not provided, callers should close resp.Body when done reading from it.
func (d *Doer) Get(url string, response interface{}) (*http.Response, error) {
req, err := d.newRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
return d.do(req, response)
}
// Head performs a HEAD request with a proper User-Agent string.
// The response body (resp.Body) is already closed when this function returns.
func (d *Doer) Head(url string) (*http.Response, error) {
req, err := d.newRequest(http.MethodHead, url, nil)
if err != nil {
return nil, err
}
return d.do(req, nil)
}
// Post performs a POST request with a proper User-Agent string.
// If "response" is not provided, callers should close resp.Body when done reading from it.
func (d *Doer) Post(url string, body io.Reader, bodyType string, response interface{}) (*http.Response, error) {
req, err := d.newRequest(http.MethodPost, url, body, contentType(bodyType))
if err != nil {
return nil, err
}
return d.do(req, response)
}
func (d *Doer) newRequest(method, uri string, body io.Reader, opts ...RequestOption) (*http.Request, error) {
req, err := http.NewRequest(method, uri, body)
if err != nil {
return nil, fmt.Errorf("failed to create request: %v", err)
}
req.Header.Set("User-Agent", d.formatUserAgent())
for _, opt := range opts {
err = opt(req)
if err != nil {
return nil, fmt.Errorf("failed to create request: %v", err)
}
}
return req, nil
}
func (d *Doer) do(req *http.Request, response interface{}) (*http.Response, error) {
resp, err := d.httpClient.Do(req)
if err != nil {
return nil, err
}
if err = checkError(req, resp); err != nil {
return resp, err
}
if response != nil {
raw, err := ioutil.ReadAll(resp.Body)
if err != nil {
return resp, err
}
defer resp.Body.Close()
err = json.Unmarshal(raw, response)
if err != nil {
return resp, fmt.Errorf("failed to unmarshal %q to type %T: %v", raw, response, err)
}
}
return resp, nil
}
// formatUserAgent builds and returns the User-Agent string to use in requests.
func (d *Doer) formatUserAgent() string {
ua := fmt.Sprintf("%s %s (%s; %s; %s)", d.userAgent, ourUserAgent, ourUserAgentComment, runtime.GOOS, runtime.GOARCH)
return strings.TrimSpace(ua)
}
func checkError(req *http.Request, resp *http.Response) error {
if resp.StatusCode >= http.StatusBadRequest {
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("%d :: %s :: %s :: %v", resp.StatusCode, req.Method, req.URL, err)
}
var errorDetails *acme.ProblemDetails
err = json.Unmarshal(body, &errorDetails)
if err != nil {
return fmt.Errorf("%d ::%s :: %s :: %v :: %s", resp.StatusCode, req.Method, req.URL, err, string(body))
}
errorDetails.Method = req.Method
errorDetails.URL = req.URL.String()
// Check for errors we handle specifically
if errorDetails.HTTPStatus == http.StatusBadRequest && errorDetails.Type == acme.BadNonceErr {
return &acme.NonceError{ProblemDetails: errorDetails}
}
return errorDetails
}
return nil
}

View file

@ -0,0 +1,14 @@
package sender
// CODE GENERATED AUTOMATICALLY
// THIS FILE MUST NOT BE EDITED BY HAND
const (
// ourUserAgent is the User-Agent of this underlying library package.
ourUserAgent = "xenolf-acme/1.2.1"
// ourUserAgentComment is part of the UA comment linked to the version status of this underlying library package.
// values: detach|release
// NOTE: Update this with each tagged release.
ourUserAgentComment = "detach"
)

65
vendor/github.com/xenolf/lego/acme/api/order.go generated vendored Normal file
View file

@ -0,0 +1,65 @@
package api
import (
"encoding/base64"
"errors"
"github.com/xenolf/lego/acme"
)
type OrderService service
// New Creates a new order.
func (o *OrderService) New(domains []string) (acme.ExtendedOrder, error) {
var identifiers []acme.Identifier
for _, domain := range domains {
identifiers = append(identifiers, acme.Identifier{Type: "dns", Value: domain})
}
orderReq := acme.Order{Identifiers: identifiers}
var order acme.Order
resp, err := o.core.post(o.core.GetDirectory().NewOrderURL, orderReq, &order)
if err != nil {
return acme.ExtendedOrder{}, err
}
return acme.ExtendedOrder{
Location: resp.Header.Get("Location"),
Order: order,
}, nil
}
// Get Gets an order.
func (o *OrderService) Get(orderURL string) (acme.Order, error) {
if len(orderURL) == 0 {
return acme.Order{}, errors.New("order[get]: empty URL")
}
var order acme.Order
_, err := o.core.postAsGet(orderURL, &order)
if err != nil {
return acme.Order{}, err
}
return order, nil
}
// UpdateForCSR Updates an order for a CSR.
func (o *OrderService) UpdateForCSR(orderURL string, csr []byte) (acme.Order, error) {
csrMsg := acme.CSRMessage{
Csr: base64.RawURLEncoding.EncodeToString(csr),
}
var order acme.Order
_, err := o.core.post(orderURL, csrMsg, &order)
if err != nil {
return acme.Order{}, err
}
if order.Status == acme.StatusInvalid {
return acme.Order{}, order.Error
}
return order, nil
}

45
vendor/github.com/xenolf/lego/acme/api/service.go generated vendored Normal file
View file

@ -0,0 +1,45 @@
package api
import (
"net/http"
"regexp"
)
type service struct {
core *Core
}
// getLink get a rel into the Link header
func getLink(header http.Header, rel string) string {
var linkExpr = regexp.MustCompile(`<(.+?)>;\s*rel="(.+?)"`)
for _, link := range header["Link"] {
for _, m := range linkExpr.FindAllStringSubmatch(link, -1) {
if len(m) != 3 {
continue
}
if m[2] == rel {
return m[1]
}
}
}
return ""
}
// getLocation get the value of the header Location
func getLocation(resp *http.Response) string {
if resp == nil {
return ""
}
return resp.Header.Get("Location")
}
// getRetryAfter get the value of the header Retry-After
func getRetryAfter(resp *http.Response) string {
if resp == nil {
return ""
}
return resp.Header.Get("Retry-After")
}

View file

@ -1,17 +0,0 @@
package acme
// Challenge is a string that identifies a particular type and version of ACME challenge.
type Challenge string
const (
// HTTP01 is the "http-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#http
// Note: HTTP01ChallengePath returns the URL path to fulfill this challenge
HTTP01 = Challenge("http-01")
// DNS01 is the "dns-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#dns
// Note: DNS01Record returns a DNS record which will fulfill this challenge
DNS01 = Challenge("dns-01")
// TLSALPN01 is the "tls-alpn-01" ACME challenge https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-01
TLSALPN01 = Challenge("tls-alpn-01")
)

View file

@ -1,957 +0,0 @@
// Package acme implements the ACME protocol for Let's Encrypt and other conforming providers.
package acme
import (
"crypto"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
"net"
"regexp"
"strconv"
"strings"
"time"
"github.com/xenolf/lego/log"
)
const (
// maxBodySize is the maximum size of body that we will read.
maxBodySize = 1024 * 1024
// overallRequestLimit is the overall number of request per second limited on the
// “new-reg”, “new-authz” and “new-cert” endpoints. From the documentation the
// limitation is 20 requests per second, but using 20 as value doesn't work but 18 do
overallRequestLimit = 18
statusValid = "valid"
statusInvalid = "invalid"
)
// User interface is to be implemented by users of this library.
// It is used by the client type to get user specific information.
type User interface {
GetEmail() string
GetRegistration() *RegistrationResource
GetPrivateKey() crypto.PrivateKey
}
// Interface for all challenge solvers to implement.
type solver interface {
Solve(challenge challenge, domain string) error
}
// Interface for challenges like dns, where we can set a record in advance for ALL challenges.
// This saves quite a bit of time vs creating the records and solving them serially.
type preSolver interface {
PreSolve(challenge challenge, domain string) error
}
// Interface for challenges like dns, where we can solve all the challenges before to delete them.
type cleanup interface {
CleanUp(challenge challenge, domain string) error
}
type validateFunc func(j *jws, domain, uri string, chlng challenge) error
// Client is the user-friendy way to ACME
type Client struct {
directory directory
user User
jws *jws
keyType KeyType
solvers map[Challenge]solver
}
// NewClient creates a new ACME client on behalf of the user. The client will depend on
// the ACME directory located at caDirURL for the rest of its actions. A private
// key of type keyType (see KeyType contants) will be generated when requesting a new
// certificate if one isn't provided.
func NewClient(caDirURL string, user User, keyType KeyType) (*Client, error) {
privKey := user.GetPrivateKey()
if privKey == nil {
return nil, errors.New("private key was nil")
}
var dir directory
if _, err := getJSON(caDirURL, &dir); err != nil {
return nil, fmt.Errorf("get directory at '%s': %v", caDirURL, err)
}
if dir.NewAccountURL == "" {
return nil, errors.New("directory missing new registration URL")
}
if dir.NewOrderURL == "" {
return nil, errors.New("directory missing new order URL")
}
jws := &jws{privKey: privKey, getNonceURL: dir.NewNonceURL}
if reg := user.GetRegistration(); reg != nil {
jws.kid = reg.URI
}
// REVIEW: best possibility?
// Add all available solvers with the right index as per ACME
// spec to this map. Otherwise they won`t be found.
solvers := map[Challenge]solver{
HTTP01: &httpChallenge{jws: jws, validate: validate, provider: &HTTPProviderServer{}},
TLSALPN01: &tlsALPNChallenge{jws: jws, validate: validate, provider: &TLSALPNProviderServer{}},
}
return &Client{directory: dir, user: user, jws: jws, keyType: keyType, solvers: solvers}, nil
}
// SetChallengeProvider specifies a custom provider p that can solve the given challenge type.
func (c *Client) SetChallengeProvider(challenge Challenge, p ChallengeProvider) error {
switch challenge {
case HTTP01:
c.solvers[challenge] = &httpChallenge{jws: c.jws, validate: validate, provider: p}
case DNS01:
c.solvers[challenge] = &dnsChallenge{jws: c.jws, validate: validate, provider: p}
case TLSALPN01:
c.solvers[challenge] = &tlsALPNChallenge{jws: c.jws, validate: validate, provider: p}
default:
return fmt.Errorf("unknown challenge %v", challenge)
}
return nil
}
// SetHTTPAddress specifies a custom interface:port to be used for HTTP based challenges.
// If this option is not used, the default port 80 and all interfaces will be used.
// To only specify a port and no interface use the ":port" notation.
//
// NOTE: This REPLACES any custom HTTP provider previously set by calling
// c.SetChallengeProvider with the default HTTP challenge provider.
func (c *Client) SetHTTPAddress(iface string) error {
host, port, err := net.SplitHostPort(iface)
if err != nil {
return err
}
if chlng, ok := c.solvers[HTTP01]; ok {
chlng.(*httpChallenge).provider = NewHTTPProviderServer(host, port)
}
return nil
}
// SetTLSAddress specifies a custom interface:port to be used for TLS based challenges.
// If this option is not used, the default port 443 and all interfaces will be used.
// To only specify a port and no interface use the ":port" notation.
//
// NOTE: This REPLACES any custom TLS-ALPN provider previously set by calling
// c.SetChallengeProvider with the default TLS-ALPN challenge provider.
func (c *Client) SetTLSAddress(iface string) error {
host, port, err := net.SplitHostPort(iface)
if err != nil {
return err
}
if chlng, ok := c.solvers[TLSALPN01]; ok {
chlng.(*tlsALPNChallenge).provider = NewTLSALPNProviderServer(host, port)
}
return nil
}
// ExcludeChallenges explicitly removes challenges from the pool for solving.
func (c *Client) ExcludeChallenges(challenges []Challenge) {
// Loop through all challenges and delete the requested one if found.
for _, challenge := range challenges {
delete(c.solvers, challenge)
}
}
// GetToSURL returns the current ToS URL from the Directory
func (c *Client) GetToSURL() string {
return c.directory.Meta.TermsOfService
}
// GetExternalAccountRequired returns the External Account Binding requirement of the Directory
func (c *Client) GetExternalAccountRequired() bool {
return c.directory.Meta.ExternalAccountRequired
}
// Register the current account to the ACME server.
func (c *Client) Register(tosAgreed bool) (*RegistrationResource, error) {
if c == nil || c.user == nil {
return nil, errors.New("acme: cannot register a nil client or user")
}
log.Infof("acme: Registering account for %s", c.user.GetEmail())
accMsg := accountMessage{}
if c.user.GetEmail() != "" {
accMsg.Contact = []string{"mailto:" + c.user.GetEmail()}
} else {
accMsg.Contact = []string{}
}
accMsg.TermsOfServiceAgreed = tosAgreed
var serverReg accountMessage
hdr, err := postJSON(c.jws, c.directory.NewAccountURL, accMsg, &serverReg)
if err != nil {
remoteErr, ok := err.(RemoteError)
if ok && remoteErr.StatusCode == 409 {
} else {
return nil, err
}
}
reg := &RegistrationResource{
URI: hdr.Get("Location"),
Body: serverReg,
}
c.jws.kid = reg.URI
return reg, nil
}
// RegisterWithExternalAccountBinding Register the current account to the ACME server.
func (c *Client) RegisterWithExternalAccountBinding(tosAgreed bool, kid string, hmacEncoded string) (*RegistrationResource, error) {
if c == nil || c.user == nil {
return nil, errors.New("acme: cannot register a nil client or user")
}
log.Infof("acme: Registering account (EAB) for %s", c.user.GetEmail())
accMsg := accountMessage{}
if c.user.GetEmail() != "" {
accMsg.Contact = []string{"mailto:" + c.user.GetEmail()}
} else {
accMsg.Contact = []string{}
}
accMsg.TermsOfServiceAgreed = tosAgreed
hmac, err := base64.RawURLEncoding.DecodeString(hmacEncoded)
if err != nil {
return nil, fmt.Errorf("acme: could not decode hmac key: %s", err.Error())
}
eabJWS, err := c.jws.signEABContent(c.directory.NewAccountURL, kid, hmac)
if err != nil {
return nil, fmt.Errorf("acme: error signing eab content: %s", err.Error())
}
eabPayload := eabJWS.FullSerialize()
accMsg.ExternalAccountBinding = []byte(eabPayload)
var serverReg accountMessage
hdr, err := postJSON(c.jws, c.directory.NewAccountURL, accMsg, &serverReg)
if err != nil {
remoteErr, ok := err.(RemoteError)
if ok && remoteErr.StatusCode == 409 {
} else {
return nil, err
}
}
reg := &RegistrationResource{
URI: hdr.Get("Location"),
Body: serverReg,
}
c.jws.kid = reg.URI
return reg, nil
}
// ResolveAccountByKey will attempt to look up an account using the given account key
// and return its registration resource.
func (c *Client) ResolveAccountByKey() (*RegistrationResource, error) {
log.Infof("acme: Trying to resolve account by key")
acc := accountMessage{OnlyReturnExisting: true}
hdr, err := postJSON(c.jws, c.directory.NewAccountURL, acc, nil)
if err != nil {
return nil, err
}
accountLink := hdr.Get("Location")
if accountLink == "" {
return nil, errors.New("Server did not return the account link")
}
var retAccount accountMessage
c.jws.kid = accountLink
_, err = postJSON(c.jws, accountLink, accountMessage{}, &retAccount)
if err != nil {
return nil, err
}
return &RegistrationResource{URI: accountLink, Body: retAccount}, nil
}
// DeleteRegistration deletes the client's user registration from the ACME
// server.
func (c *Client) DeleteRegistration() error {
if c == nil || c.user == nil {
return errors.New("acme: cannot unregister a nil client or user")
}
log.Infof("acme: Deleting account for %s", c.user.GetEmail())
accMsg := accountMessage{
Status: "deactivated",
}
_, err := postJSON(c.jws, c.user.GetRegistration().URI, accMsg, nil)
return err
}
// QueryRegistration runs a POST request on the client's registration and
// returns the result.
//
// This is similar to the Register function, but acting on an existing
// registration link and resource.
func (c *Client) QueryRegistration() (*RegistrationResource, error) {
if c == nil || c.user == nil {
return nil, errors.New("acme: cannot query the registration of a nil client or user")
}
// Log the URL here instead of the email as the email may not be set
log.Infof("acme: Querying account for %s", c.user.GetRegistration().URI)
accMsg := accountMessage{}
var serverReg accountMessage
_, err := postJSON(c.jws, c.user.GetRegistration().URI, accMsg, &serverReg)
if err != nil {
return nil, err
}
reg := &RegistrationResource{Body: serverReg}
// Location: header is not returned so this needs to be populated off of
// existing URI
reg.URI = c.user.GetRegistration().URI
return reg, nil
}
// ObtainCertificateForCSR tries to obtain a certificate matching the CSR passed into it.
// The domains are inferred from the CommonName and SubjectAltNames, if any. The private key
// for this CSR is not required.
// If bundle is true, the []byte contains both the issuer certificate and
// your issued certificate as a bundle.
// This function will never return a partial certificate. If one domain in the list fails,
// the whole certificate will fail.
func (c *Client) ObtainCertificateForCSR(csr x509.CertificateRequest, bundle bool) (*CertificateResource, error) {
// figure out what domains it concerns
// start with the common name
domains := []string{csr.Subject.CommonName}
// loop over the SubjectAltName DNS names
DNSNames:
for _, sanName := range csr.DNSNames {
for _, existingName := range domains {
if existingName == sanName {
// duplicate; skip this name
continue DNSNames
}
}
// name is unique
domains = append(domains, sanName)
}
if bundle {
log.Infof("[%s] acme: Obtaining bundled SAN certificate given a CSR", strings.Join(domains, ", "))
} else {
log.Infof("[%s] acme: Obtaining SAN certificate given a CSR", strings.Join(domains, ", "))
}
order, err := c.createOrderForIdentifiers(domains)
if err != nil {
return nil, err
}
authz, err := c.getAuthzForOrder(order)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
/*for _, auth := range authz {
c.disableAuthz(auth)
}*/
return nil, err
}
err = c.solveChallengeForAuthz(authz)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
return nil, err
}
log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", "))
failures := make(ObtainError)
cert, err := c.requestCertificateForCsr(order, bundle, csr.Raw, nil)
if err != nil {
for _, chln := range authz {
failures[chln.Identifier.Value] = err
}
}
if cert != nil {
// Add the CSR to the certificate so that it can be used for renewals.
cert.CSR = pemEncode(&csr)
}
// do not return an empty failures map, because
// it would still be a non-nil error value
if len(failures) > 0 {
return cert, failures
}
return cert, nil
}
// ObtainCertificate tries to obtain a single certificate using all domains passed into it.
// The first domain in domains is used for the CommonName field of the certificate, all other
// domains are added using the Subject Alternate Names extension. A new private key is generated
// for every invocation of this function. If you do not want that you can supply your own private key
// in the privKey parameter. If this parameter is non-nil it will be used instead of generating a new one.
// If bundle is true, the []byte contains both the issuer certificate and
// your issued certificate as a bundle.
// This function will never return a partial certificate. If one domain in the list fails,
// the whole certificate will fail.
func (c *Client) ObtainCertificate(domains []string, bundle bool, privKey crypto.PrivateKey, mustStaple bool) (*CertificateResource, error) {
if len(domains) == 0 {
return nil, errors.New("no domains to obtain a certificate for")
}
if bundle {
log.Infof("[%s] acme: Obtaining bundled SAN certificate", strings.Join(domains, ", "))
} else {
log.Infof("[%s] acme: Obtaining SAN certificate", strings.Join(domains, ", "))
}
order, err := c.createOrderForIdentifiers(domains)
if err != nil {
return nil, err
}
authz, err := c.getAuthzForOrder(order)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
/*for _, auth := range authz {
c.disableAuthz(auth)
}*/
return nil, err
}
err = c.solveChallengeForAuthz(authz)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
return nil, err
}
log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", "))
failures := make(ObtainError)
cert, err := c.requestCertificateForOrder(order, bundle, privKey, mustStaple)
if err != nil {
for _, auth := range authz {
failures[auth.Identifier.Value] = err
}
}
// do not return an empty failures map, because
// it would still be a non-nil error value
if len(failures) > 0 {
return cert, failures
}
return cert, nil
}
// RevokeCertificate takes a PEM encoded certificate or bundle and tries to revoke it at the CA.
func (c *Client) RevokeCertificate(certificate []byte) error {
certificates, err := parsePEMBundle(certificate)
if err != nil {
return err
}
x509Cert := certificates[0]
if x509Cert.IsCA {
return fmt.Errorf("Certificate bundle starts with a CA certificate")
}
encodedCert := base64.URLEncoding.EncodeToString(x509Cert.Raw)
_, err = postJSON(c.jws, c.directory.RevokeCertURL, revokeCertMessage{Certificate: encodedCert}, nil)
return err
}
// RenewCertificate takes a CertificateResource and tries to renew the certificate.
// If the renewal process succeeds, the new certificate will ge returned in a new CertResource.
// Please be aware that this function will return a new certificate in ANY case that is not an error.
// If the server does not provide us with a new cert on a GET request to the CertURL
// this function will start a new-cert flow where a new certificate gets generated.
// If bundle is true, the []byte contains both the issuer certificate and
// your issued certificate as a bundle.
// For private key reuse the PrivateKey property of the passed in CertificateResource should be non-nil.
func (c *Client) RenewCertificate(cert CertificateResource, bundle, mustStaple bool) (*CertificateResource, error) {
// Input certificate is PEM encoded. Decode it here as we may need the decoded
// cert later on in the renewal process. The input may be a bundle or a single certificate.
certificates, err := parsePEMBundle(cert.Certificate)
if err != nil {
return nil, err
}
x509Cert := certificates[0]
if x509Cert.IsCA {
return nil, fmt.Errorf("[%s] Certificate bundle starts with a CA certificate", cert.Domain)
}
// This is just meant to be informal for the user.
timeLeft := x509Cert.NotAfter.Sub(time.Now().UTC())
log.Infof("[%s] acme: Trying renewal with %d hours remaining", cert.Domain, int(timeLeft.Hours()))
// We always need to request a new certificate to renew.
// Start by checking to see if the certificate was based off a CSR, and
// use that if it's defined.
if len(cert.CSR) > 0 {
csr, errP := pemDecodeTox509CSR(cert.CSR)
if errP != nil {
return nil, errP
}
newCert, failures := c.ObtainCertificateForCSR(*csr, bundle)
return newCert, failures
}
var privKey crypto.PrivateKey
if cert.PrivateKey != nil {
privKey, err = parsePEMPrivateKey(cert.PrivateKey)
if err != nil {
return nil, err
}
}
var domains []string
// check for SAN certificate
if len(x509Cert.DNSNames) > 1 {
domains = append(domains, x509Cert.Subject.CommonName)
for _, sanDomain := range x509Cert.DNSNames {
if sanDomain == x509Cert.Subject.CommonName {
continue
}
domains = append(domains, sanDomain)
}
} else {
domains = append(domains, x509Cert.Subject.CommonName)
}
newCert, err := c.ObtainCertificate(domains, bundle, privKey, mustStaple)
return newCert, err
}
func (c *Client) createOrderForIdentifiers(domains []string) (orderResource, error) {
var identifiers []identifier
for _, domain := range domains {
identifiers = append(identifiers, identifier{Type: "dns", Value: domain})
}
order := orderMessage{
Identifiers: identifiers,
}
var response orderMessage
hdr, err := postJSON(c.jws, c.directory.NewOrderURL, order, &response)
if err != nil {
return orderResource{}, err
}
orderRes := orderResource{
URL: hdr.Get("Location"),
Domains: domains,
orderMessage: response,
}
return orderRes, nil
}
// an authz with the solver we have chosen and the index of the challenge associated with it
type selectedAuthSolver struct {
authz authorization
challengeIndex int
solver solver
}
// Looks through the challenge combinations to find a solvable match.
// Then solves the challenges in series and returns.
func (c *Client) solveChallengeForAuthz(authorizations []authorization) error {
failures := make(ObtainError)
authSolvers := []*selectedAuthSolver{}
// loop through the resources, basically through the domains. First pass just selects a solver for each authz.
for _, authz := range authorizations {
if authz.Status == statusValid {
// Boulder might recycle recent validated authz (see issue #267)
log.Infof("[%s] acme: Authorization already valid; skipping challenge", authz.Identifier.Value)
continue
}
if i, solvr := c.chooseSolver(authz, authz.Identifier.Value); solvr != nil {
authSolvers = append(authSolvers, &selectedAuthSolver{
authz: authz,
challengeIndex: i,
solver: solvr,
})
} else {
failures[authz.Identifier.Value] = fmt.Errorf("[%s] acme: Could not determine solvers", authz.Identifier.Value)
}
}
// for all valid presolvers, first submit the challenges so they have max time to propagate
for _, item := range authSolvers {
authz := item.authz
i := item.challengeIndex
if presolver, ok := item.solver.(preSolver); ok {
if err := presolver.PreSolve(authz.Challenges[i], authz.Identifier.Value); err != nil {
failures[authz.Identifier.Value] = err
}
}
}
defer func() {
// clean all created TXT records
for _, item := range authSolvers {
if clean, ok := item.solver.(cleanup); ok {
if failures[item.authz.Identifier.Value] != nil {
// already failed in previous loop
continue
}
err := clean.CleanUp(item.authz.Challenges[item.challengeIndex], item.authz.Identifier.Value)
if err != nil {
log.Warnf("Error cleaning up %s: %v ", item.authz.Identifier.Value, err)
}
}
}
}()
// finally solve all challenges for real
for _, item := range authSolvers {
authz := item.authz
i := item.challengeIndex
if failures[authz.Identifier.Value] != nil {
// already failed in previous loop
continue
}
if err := item.solver.Solve(authz.Challenges[i], authz.Identifier.Value); err != nil {
failures[authz.Identifier.Value] = err
}
}
// be careful not to return an empty failures map, for
// even an empty ObtainError is a non-nil error value
if len(failures) > 0 {
return failures
}
return nil
}
// Checks all challenges from the server in order and returns the first matching solver.
func (c *Client) chooseSolver(auth authorization, domain string) (int, solver) {
for i, challenge := range auth.Challenges {
if solver, ok := c.solvers[Challenge(challenge.Type)]; ok {
return i, solver
}
log.Infof("[%s] acme: Could not find solver for: %s", domain, challenge.Type)
}
return 0, nil
}
// Get the challenges needed to proof our identifier to the ACME server.
func (c *Client) getAuthzForOrder(order orderResource) ([]authorization, error) {
resc, errc := make(chan authorization), make(chan domainError)
delay := time.Second / overallRequestLimit
for _, authzURL := range order.Authorizations {
time.Sleep(delay)
go func(authzURL string) {
var authz authorization
_, err := postAsGet(c.jws, authzURL, &authz)
if err != nil {
errc <- domainError{Domain: authz.Identifier.Value, Error: err}
return
}
resc <- authz
}(authzURL)
}
var responses []authorization
failures := make(ObtainError)
for i := 0; i < len(order.Authorizations); i++ {
select {
case res := <-resc:
responses = append(responses, res)
case err := <-errc:
failures[err.Domain] = err.Error
}
}
logAuthz(order)
close(resc)
close(errc)
// be careful to not return an empty failures map;
// even if empty, they become non-nil error values
if len(failures) > 0 {
return responses, failures
}
return responses, nil
}
func logAuthz(order orderResource) {
for i, auth := range order.Authorizations {
log.Infof("[%s] AuthURL: %s", order.Identifiers[i].Value, auth)
}
}
// cleanAuthz loops through the passed in slice and disables any auths which are not "valid"
func (c *Client) disableAuthz(authURL string) error {
var disabledAuth authorization
_, err := postJSON(c.jws, authURL, deactivateAuthMessage{Status: "deactivated"}, &disabledAuth)
return err
}
func (c *Client) requestCertificateForOrder(order orderResource, bundle bool, privKey crypto.PrivateKey, mustStaple bool) (*CertificateResource, error) {
var err error
if privKey == nil {
privKey, err = generatePrivateKey(c.keyType)
if err != nil {
return nil, err
}
}
// determine certificate name(s) based on the authorization resources
commonName := order.Domains[0]
// ACME draft Section 7.4 "Applying for Certificate Issuance"
// https://tools.ietf.org/html/draft-ietf-acme-acme-12#section-7.4
// says:
// Clients SHOULD NOT make any assumptions about the sort order of
// "identifiers" or "authorizations" elements in the returned order
// object.
san := []string{commonName}
for _, auth := range order.Identifiers {
if auth.Value != commonName {
san = append(san, auth.Value)
}
}
// TODO: should the CSR be customizable?
csr, err := generateCsr(privKey, commonName, san, mustStaple)
if err != nil {
return nil, err
}
return c.requestCertificateForCsr(order, bundle, csr, pemEncode(privKey))
}
func (c *Client) requestCertificateForCsr(order orderResource, bundle bool, csr []byte, privateKeyPem []byte) (*CertificateResource, error) {
commonName := order.Domains[0]
csrString := base64.RawURLEncoding.EncodeToString(csr)
var retOrder orderMessage
_, err := postJSON(c.jws, order.Finalize, csrMessage{Csr: csrString}, &retOrder)
if err != nil {
return nil, err
}
if retOrder.Status == statusInvalid {
return nil, err
}
certRes := CertificateResource{
Domain: commonName,
CertURL: retOrder.Certificate,
PrivateKey: privateKeyPem,
}
if retOrder.Status == statusValid {
// if the certificate is available right away, short cut!
ok, err := c.checkCertResponse(retOrder, &certRes, bundle)
if err != nil {
return nil, err
}
if ok {
return &certRes, nil
}
}
stopTimer := time.NewTimer(30 * time.Second)
defer stopTimer.Stop()
retryTick := time.NewTicker(500 * time.Millisecond)
defer retryTick.Stop()
for {
select {
case <-stopTimer.C:
return nil, errors.New("certificate polling timed out")
case <-retryTick.C:
_, err := postAsGet(c.jws, order.URL, &retOrder)
if err != nil {
return nil, err
}
done, err := c.checkCertResponse(retOrder, &certRes, bundle)
if err != nil {
return nil, err
}
if done {
return &certRes, nil
}
}
}
}
// checkCertResponse checks to see if the certificate is ready and a link is contained in the
// response. if so, loads it into certRes and returns true. If the cert
// is not yet ready, it returns false. The certRes input
// should already have the Domain (common name) field populated. If bundle is
// true, the certificate will be bundled with the issuer's cert.
func (c *Client) checkCertResponse(order orderMessage, certRes *CertificateResource, bundle bool) (bool, error) {
switch order.Status {
case statusValid:
resp, err := postAsGet(c.jws, order.Certificate, nil)
if err != nil {
return false, err
}
cert, err := ioutil.ReadAll(limitReader(resp.Body, maxBodySize))
if err != nil {
return false, err
}
// The issuer certificate link may be supplied via an "up" link
// in the response headers of a new certificate. See
// https://tools.ietf.org/html/draft-ietf-acme-acme-12#section-7.4.2
links := parseLinks(resp.Header["Link"])
if link, ok := links["up"]; ok {
issuerCert, err := c.getIssuerCertificate(link)
if err != nil {
// If we fail to acquire the issuer cert, return the issued certificate - do not fail.
log.Warnf("[%s] acme: Could not bundle issuer certificate: %v", certRes.Domain, err)
} else {
issuerCert = pemEncode(derCertificateBytes(issuerCert))
// If bundle is true, we want to return a certificate bundle.
// To do this, we append the issuer cert to the issued cert.
if bundle {
cert = append(cert, issuerCert...)
}
certRes.IssuerCertificate = issuerCert
}
} else {
// Get issuerCert from bundled response from Let's Encrypt
// See https://community.letsencrypt.org/t/acme-v2-no-up-link-in-response/64962
_, rest := pem.Decode(cert)
if rest != nil {
certRes.IssuerCertificate = rest
}
}
certRes.Certificate = cert
certRes.CertURL = order.Certificate
certRes.CertStableURL = order.Certificate
log.Infof("[%s] Server responded with a certificate.", certRes.Domain)
return true, nil
case "processing":
return false, nil
case statusInvalid:
return false, errors.New("order has invalid state: invalid")
default:
return false, nil
}
}
// getIssuerCertificate requests the issuer certificate
func (c *Client) getIssuerCertificate(url string) ([]byte, error) {
log.Infof("acme: Requesting issuer cert from %s", url)
resp, err := postAsGet(c.jws, url, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
issuerBytes, err := ioutil.ReadAll(limitReader(resp.Body, maxBodySize))
if err != nil {
return nil, err
}
_, err = x509.ParseCertificate(issuerBytes)
if err != nil {
return nil, err
}
return issuerBytes, err
}
func parseLinks(links []string) map[string]string {
aBrkt := regexp.MustCompile("[<>]")
slver := regexp.MustCompile("(.+) *= *\"(.+)\"")
linkMap := make(map[string]string)
for _, link := range links {
link = aBrkt.ReplaceAllString(link, "")
parts := strings.Split(link, ";")
matches := slver.FindStringSubmatch(parts[1])
if len(matches) > 0 {
linkMap[matches[2]] = parts[0]
}
}
return linkMap
}
// validate makes the ACME server start validating a
// challenge response, only returning once it is done.
func validate(j *jws, domain, uri string, c challenge) error {
var chlng challenge
// Challenge initiation is done by sending a JWS payload containing the
// trivial JSON object `{}`. We use an empty struct instance as the postJSON
// payload here to achieve this result.
hdr, err := postJSON(j, uri, struct{}{}, &chlng)
if err != nil {
return err
}
// After the path is sent, the ACME server will access our server.
// Repeatedly check the server for an updated status on our request.
for {
switch chlng.Status {
case statusValid:
log.Infof("[%s] The server validated our request", domain)
return nil
case "pending":
case "processing":
case statusInvalid:
return handleChallengeError(chlng)
default:
return errors.New("the server returned an unexpected state")
}
ra, err := strconv.Atoi(hdr.Get("Retry-After"))
if err != nil {
// The ACME server MUST return a Retry-After.
// If it doesn't, we'll just poll hard.
ra = 5
}
time.Sleep(time.Duration(ra) * time.Second)
resp, err := postAsGet(j, uri, &chlng)
if err != nil {
return err
}
if resp != nil {
hdr = resp.Header
}
}
}

284
vendor/github.com/xenolf/lego/acme/commons.go generated vendored Normal file
View file

@ -0,0 +1,284 @@
// Package acme contains all objects related the ACME endpoints.
// https://tools.ietf.org/html/draft-ietf-acme-acme-16
package acme
import (
"encoding/json"
"time"
)
// Challenge statuses
// https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-7.1.6
const (
StatusPending = "pending"
StatusInvalid = "invalid"
StatusValid = "valid"
StatusProcessing = "processing"
StatusDeactivated = "deactivated"
StatusExpired = "expired"
StatusRevoked = "revoked"
)
// Directory the ACME directory object.
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-7.1.1
type Directory struct {
NewNonceURL string `json:"newNonce"`
NewAccountURL string `json:"newAccount"`
NewOrderURL string `json:"newOrder"`
NewAuthzURL string `json:"newAuthz"`
RevokeCertURL string `json:"revokeCert"`
KeyChangeURL string `json:"keyChange"`
Meta Meta `json:"meta"`
}
// Meta the ACME meta object (related to Directory).
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-7.1.1
type Meta struct {
// termsOfService (optional, string):
// A URL identifying the current terms of service.
TermsOfService string `json:"termsOfService"`
// website (optional, string):
// An HTTP or HTTPS URL locating a website providing more information about the ACME server.
Website string `json:"website"`
// caaIdentities (optional, array of string):
// The hostnames that the ACME server recognizes as referring to itself
// for the purposes of CAA record validation as defined in [RFC6844].
// Each string MUST represent the same sequence of ASCII code points
// that the server will expect to see as the "Issuer Domain Name" in a CAA issue or issuewild property tag.
// This allows clients to determine the correct issuer domain name to use when configuring CAA records.
CaaIdentities []string `json:"caaIdentities"`
// externalAccountRequired (optional, boolean):
// If this field is present and set to "true",
// then the CA requires that all new- account requests include an "externalAccountBinding" field
// associating the new account with an external account.
ExternalAccountRequired bool `json:"externalAccountRequired"`
}
// ExtendedAccount a extended Account.
type ExtendedAccount struct {
Account
// Contains the value of the response header `Location`
Location string `json:"-"`
}
// Account the ACME account Object.
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-7.1.2
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-7.3
type Account struct {
// status (required, string):
// The status of this account.
// Possible values are: "valid", "deactivated", and "revoked".
// The value "deactivated" should be used to indicate client-initiated deactivation
// whereas "revoked" should be used to indicate server- initiated deactivation. (See Section 7.1.6)
Status string `json:"status,omitempty"`
// contact (optional, array of string):
// An array of URLs that the server can use to contact the client for issues related to this account.
// For example, the server may wish to notify the client about server-initiated revocation or certificate expiration.
// For information on supported URL schemes, see Section 7.3
Contact []string `json:"contact,omitempty"`
// termsOfServiceAgreed (optional, boolean):
// Including this field in a new-account request,
// with a value of true, indicates the client's agreement with the terms of service.
// This field is not updateable by the client.
TermsOfServiceAgreed bool `json:"termsOfServiceAgreed,omitempty"`
// orders (required, string):
// A URL from which a list of orders submitted by this account can be fetched via a POST-as-GET request,
// as described in Section 7.1.2.1.
Orders string `json:"orders,omitempty"`
// onlyReturnExisting (optional, boolean):
// If this field is present with the value "true",
// then the server MUST NOT create a new account if one does not already exist.
// This allows a client to look up an account URL based on an account key (see Section 7.3.1).
OnlyReturnExisting bool `json:"onlyReturnExisting,omitempty"`
// externalAccountBinding (optional, object):
// An optional field for binding the new account with an existing non-ACME account (see Section 7.3.4).
ExternalAccountBinding json.RawMessage `json:"externalAccountBinding,omitempty"`
}
// ExtendedOrder a extended Order.
type ExtendedOrder struct {
Order
// The order URL, contains the value of the response header `Location`
Location string `json:"-"`
}
// Order the ACME order Object.
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-7.1.3
type Order struct {
// status (required, string):
// The status of this order.
// Possible values are: "pending", "ready", "processing", "valid", and "invalid".
Status string `json:"status,omitempty"`
// expires (optional, string):
// The timestamp after which the server will consider this order invalid,
// encoded in the format specified in RFC 3339 [RFC3339].
// This field is REQUIRED for objects with "pending" or "valid" in the status field.
Expires string `json:"expires,omitempty"`
// identifiers (required, array of object):
// An array of identifier objects that the order pertains to.
Identifiers []Identifier `json:"identifiers"`
// notBefore (optional, string):
// The requested value of the notBefore field in the certificate,
// in the date format defined in [RFC3339].
NotBefore string `json:"notBefore,omitempty"`
// notAfter (optional, string):
// The requested value of the notAfter field in the certificate,
// in the date format defined in [RFC3339].
NotAfter string `json:"notAfter,omitempty"`
// error (optional, object):
// The error that occurred while processing the order, if any.
// This field is structured as a problem document [RFC7807].
Error *ProblemDetails `json:"error,omitempty"`
// authorizations (required, array of string):
// For pending orders,
// the authorizations that the client needs to complete before the requested certificate can be issued (see Section 7.5),
// including unexpired authorizations that the client has completed in the past for identifiers specified in the order.
// The authorizations required are dictated by server policy
// and there may not be a 1:1 relationship between the order identifiers and the authorizations required.
// For final orders (in the "valid" or "invalid" state), the authorizations that were completed.
// Each entry is a URL from which an authorization can be fetched with a POST-as-GET request.
Authorizations []string `json:"authorizations,omitempty"`
// finalize (required, string):
// A URL that a CSR must be POSTed to once all of the order's authorizations are satisfied to finalize the order.
// The result of a successful finalization will be the population of the certificate URL for the order.
Finalize string `json:"finalize,omitempty"`
// certificate (optional, string):
// A URL for the certificate that has been issued in response to this order
Certificate string `json:"certificate,omitempty"`
}
// Authorization the ACME authorization object.
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-7.1.4
type Authorization struct {
// status (required, string):
// The status of this authorization.
// Possible values are: "pending", "valid", "invalid", "deactivated", "expired", and "revoked".
Status string `json:"status"`
// expires (optional, string):
// The timestamp after which the server will consider this authorization invalid,
// encoded in the format specified in RFC 3339 [RFC3339].
// This field is REQUIRED for objects with "valid" in the "status" field.
Expires time.Time `json:"expires,omitempty"`
// identifier (required, object):
// The identifier that the account is authorized to represent
Identifier Identifier `json:"identifier,omitempty"`
// challenges (required, array of objects):
// For pending authorizations, the challenges that the client can fulfill in order to prove possession of the identifier.
// For valid authorizations, the challenge that was validated.
// For invalid authorizations, the challenge that was attempted and failed.
// Each array entry is an object with parameters required to validate the challenge.
// A client should attempt to fulfill one of these challenges,
// and a server should consider any one of the challenges sufficient to make the authorization valid.
Challenges []Challenge `json:"challenges,omitempty"`
// wildcard (optional, boolean):
// For authorizations created as a result of a newOrder request containing a DNS identifier
// with a value that contained a wildcard prefix this field MUST be present, and true.
Wildcard bool `json:"wildcard,omitempty"`
}
// ExtendedChallenge a extended Challenge.
type ExtendedChallenge struct {
Challenge
// Contains the value of the response header `Retry-After`
RetryAfter string `json:"-"`
// Contains the value of the response header `Link` rel="up"
AuthorizationURL string `json:"-"`
}
// Challenge the ACME challenge object.
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-7.1.5
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-8
type Challenge struct {
// type (required, string):
// The type of challenge encoded in the object.
Type string `json:"type"`
// url (required, string):
// The URL to which a response can be posted.
URL string `json:"url"`
// status (required, string):
// The status of this challenge. Possible values are: "pending", "processing", "valid", and "invalid".
Status string `json:"status"`
// validated (optional, string):
// The time at which the server validated this challenge,
// encoded in the format specified in RFC 3339 [RFC3339].
// This field is REQUIRED if the "status" field is "valid".
Validated time.Time `json:"validated,omitempty"`
// error (optional, object):
// Error that occurred while the server was validating the challenge, if any,
// structured as a problem document [RFC7807].
// Multiple errors can be indicated by using subproblems Section 6.7.1.
// A challenge object with an error MUST have status equal to "invalid".
Error *ProblemDetails `json:"error,omitempty"`
// token (required, string):
// A random value that uniquely identifies the challenge.
// This value MUST have at least 128 bits of entropy.
// It MUST NOT contain any characters outside the base64url alphabet,
// and MUST NOT include base64 padding characters ("=").
// See [RFC4086] for additional information on randomness requirements.
// https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-8.3
// https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-8.4
Token string `json:"token"`
// https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-8.1
KeyAuthorization string `json:"keyAuthorization"`
}
// Identifier the ACME identifier object.
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-9.7.7
type Identifier struct {
Type string `json:"type"`
Value string `json:"value"`
}
// CSRMessage Certificate Signing Request
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-7.4
type CSRMessage struct {
// csr (required, string):
// A CSR encoding the parameters for the certificate being requested [RFC2986].
// The CSR is sent in the base64url-encoded version of the DER format.
// (Note: Because this field uses base64url, and does not include headers, it is different from PEM.).
Csr string `json:"csr"`
}
// RevokeCertMessage a certificate revocation message
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-7.6
// - https://tools.ietf.org/html/rfc5280#section-5.3.1
type RevokeCertMessage struct {
// certificate (required, string):
// The certificate to be revoked, in the base64url-encoded version of the DER format.
// (Note: Because this field uses base64url, and does not include headers, it is different from PEM.)
Certificate string `json:"certificate"`
// reason (optional, int):
// One of the revocation reasonCodes defined in Section 5.3.1 of [RFC5280] to be used when generating OCSP responses and CRLs.
// If this field is not set the server SHOULD omit the reasonCode CRL entry extension when generating OCSP responses and CRLs.
// The server MAY disallow a subset of reasonCodes from being used by the user.
// If a request contains a disallowed reasonCode the server MUST reject it with the error type "urn:ietf:params:acme:error:badRevocationReason".
// The problem document detail SHOULD indicate which reasonCodes are allowed.
Reason *uint `json:"reason,omitempty"`
}

View file

@ -1,334 +0,0 @@
package acme
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"io"
"io/ioutil"
"math/big"
"net/http"
"time"
"golang.org/x/crypto/ocsp"
jose "gopkg.in/square/go-jose.v2"
)
// KeyType represents the key algo as well as the key size or curve to use.
type KeyType string
type derCertificateBytes []byte
// Constants for all key types we support.
const (
EC256 = KeyType("P256")
EC384 = KeyType("P384")
RSA2048 = KeyType("2048")
RSA4096 = KeyType("4096")
RSA8192 = KeyType("8192")
)
const (
// OCSPGood means that the certificate is valid.
OCSPGood = ocsp.Good
// OCSPRevoked means that the certificate has been deliberately revoked.
OCSPRevoked = ocsp.Revoked
// OCSPUnknown means that the OCSP responder doesn't know about the certificate.
OCSPUnknown = ocsp.Unknown
// OCSPServerFailed means that the OCSP responder failed to process the request.
OCSPServerFailed = ocsp.ServerFailed
)
// Constants for OCSP must staple
var (
tlsFeatureExtensionOID = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 24}
ocspMustStapleFeature = []byte{0x30, 0x03, 0x02, 0x01, 0x05}
)
// GetOCSPForCert takes a PEM encoded cert or cert bundle returning the raw OCSP response,
// the parsed response, and an error, if any. The returned []byte can be passed directly
// into the OCSPStaple property of a tls.Certificate. If the bundle only contains the
// issued certificate, this function will try to get the issuer certificate from the
// IssuingCertificateURL in the certificate. If the []byte and/or ocsp.Response return
// values are nil, the OCSP status may be assumed OCSPUnknown.
func GetOCSPForCert(bundle []byte) ([]byte, *ocsp.Response, error) {
certificates, err := parsePEMBundle(bundle)
if err != nil {
return nil, nil, err
}
// We expect the certificate slice to be ordered downwards the chain.
// SRV CRT -> CA. We need to pull the leaf and issuer certs out of it,
// which should always be the first two certificates. If there's no
// OCSP server listed in the leaf cert, there's nothing to do. And if
// we have only one certificate so far, we need to get the issuer cert.
issuedCert := certificates[0]
if len(issuedCert.OCSPServer) == 0 {
return nil, nil, errors.New("no OCSP server specified in cert")
}
if len(certificates) == 1 {
// TODO: build fallback. If this fails, check the remaining array entries.
if len(issuedCert.IssuingCertificateURL) == 0 {
return nil, nil, errors.New("no issuing certificate URL")
}
resp, errC := httpGet(issuedCert.IssuingCertificateURL[0])
if errC != nil {
return nil, nil, errC
}
defer resp.Body.Close()
issuerBytes, errC := ioutil.ReadAll(limitReader(resp.Body, 1024*1024))
if errC != nil {
return nil, nil, errC
}
issuerCert, errC := x509.ParseCertificate(issuerBytes)
if errC != nil {
return nil, nil, errC
}
// Insert it into the slice on position 0
// We want it ordered right SRV CRT -> CA
certificates = append(certificates, issuerCert)
}
issuerCert := certificates[1]
// Finally kick off the OCSP request.
ocspReq, err := ocsp.CreateRequest(issuedCert, issuerCert, nil)
if err != nil {
return nil, nil, err
}
reader := bytes.NewReader(ocspReq)
req, err := httpPost(issuedCert.OCSPServer[0], "application/ocsp-request", reader)
if err != nil {
return nil, nil, err
}
defer req.Body.Close()
ocspResBytes, err := ioutil.ReadAll(limitReader(req.Body, 1024*1024))
if err != nil {
return nil, nil, err
}
ocspRes, err := ocsp.ParseResponse(ocspResBytes, issuerCert)
if err != nil {
return nil, nil, err
}
return ocspResBytes, ocspRes, nil
}
func getKeyAuthorization(token string, key interface{}) (string, error) {
var publicKey crypto.PublicKey
switch k := key.(type) {
case *ecdsa.PrivateKey:
publicKey = k.Public()
case *rsa.PrivateKey:
publicKey = k.Public()
}
// Generate the Key Authorization for the challenge
jwk := &jose.JSONWebKey{Key: publicKey}
if jwk == nil {
return "", errors.New("could not generate JWK from key")
}
thumbBytes, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
return "", err
}
// unpad the base64URL
keyThumb := base64.RawURLEncoding.EncodeToString(thumbBytes)
return token + "." + keyThumb, nil
}
// parsePEMBundle parses a certificate bundle from top to bottom and returns
// a slice of x509 certificates. This function will error if no certificates are found.
func parsePEMBundle(bundle []byte) ([]*x509.Certificate, error) {
var certificates []*x509.Certificate
var certDERBlock *pem.Block
for {
certDERBlock, bundle = pem.Decode(bundle)
if certDERBlock == nil {
break
}
if certDERBlock.Type == "CERTIFICATE" {
cert, err := x509.ParseCertificate(certDERBlock.Bytes)
if err != nil {
return nil, err
}
certificates = append(certificates, cert)
}
}
if len(certificates) == 0 {
return nil, errors.New("no certificates were found while parsing the bundle")
}
return certificates, nil
}
func parsePEMPrivateKey(key []byte) (crypto.PrivateKey, error) {
keyBlock, _ := pem.Decode(key)
switch keyBlock.Type {
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(keyBlock.Bytes)
default:
return nil, errors.New("unknown PEM header value")
}
}
func generatePrivateKey(keyType KeyType) (crypto.PrivateKey, error) {
switch keyType {
case EC256:
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case EC384:
return ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case RSA2048:
return rsa.GenerateKey(rand.Reader, 2048)
case RSA4096:
return rsa.GenerateKey(rand.Reader, 4096)
case RSA8192:
return rsa.GenerateKey(rand.Reader, 8192)
}
return nil, fmt.Errorf("invalid KeyType: %s", keyType)
}
func generateCsr(privateKey crypto.PrivateKey, domain string, san []string, mustStaple bool) ([]byte, error) {
template := x509.CertificateRequest{
Subject: pkix.Name{CommonName: domain},
}
if len(san) > 0 {
template.DNSNames = san
}
if mustStaple {
template.ExtraExtensions = append(template.ExtraExtensions, pkix.Extension{
Id: tlsFeatureExtensionOID,
Value: ocspMustStapleFeature,
})
}
return x509.CreateCertificateRequest(rand.Reader, &template, privateKey)
}
func pemEncode(data interface{}) []byte {
var pemBlock *pem.Block
switch key := data.(type) {
case *ecdsa.PrivateKey:
keyBytes, _ := x509.MarshalECPrivateKey(key)
pemBlock = &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes}
case *rsa.PrivateKey:
pemBlock = &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
case *x509.CertificateRequest:
pemBlock = &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: key.Raw}
case derCertificateBytes:
pemBlock = &pem.Block{Type: "CERTIFICATE", Bytes: []byte(data.(derCertificateBytes))}
}
return pem.EncodeToMemory(pemBlock)
}
func pemDecode(data []byte) (*pem.Block, error) {
pemBlock, _ := pem.Decode(data)
if pemBlock == nil {
return nil, fmt.Errorf("Pem decode did not yield a valid block. Is the certificate in the right format?")
}
return pemBlock, nil
}
func pemDecodeTox509CSR(pem []byte) (*x509.CertificateRequest, error) {
pemBlock, err := pemDecode(pem)
if pemBlock == nil {
return nil, err
}
if pemBlock.Type != "CERTIFICATE REQUEST" {
return nil, fmt.Errorf("PEM block is not a certificate request")
}
return x509.ParseCertificateRequest(pemBlock.Bytes)
}
// GetPEMCertExpiration returns the "NotAfter" date of a PEM encoded certificate.
// The certificate has to be PEM encoded. Any other encodings like DER will fail.
func GetPEMCertExpiration(cert []byte) (time.Time, error) {
pemBlock, err := pemDecode(cert)
if pemBlock == nil {
return time.Time{}, err
}
return getCertExpiration(pemBlock.Bytes)
}
// getCertExpiration returns the "NotAfter" date of a DER encoded certificate.
func getCertExpiration(cert []byte) (time.Time, error) {
pCert, err := x509.ParseCertificate(cert)
if err != nil {
return time.Time{}, err
}
return pCert.NotAfter, nil
}
func generatePemCert(privKey *rsa.PrivateKey, domain string, extensions []pkix.Extension) ([]byte, error) {
derBytes, err := generateDerCert(privKey, time.Time{}, domain, extensions)
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), nil
}
func generateDerCert(privKey *rsa.PrivateKey, expiration time.Time, domain string, extensions []pkix.Extension) ([]byte, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, err
}
if expiration.IsZero() {
expiration = time.Now().Add(365)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "ACME Challenge TEMP",
},
NotBefore: time.Now(),
NotAfter: expiration,
KeyUsage: x509.KeyUsageKeyEncipherment,
BasicConstraintsValid: true,
DNSNames: []string{domain},
ExtraExtensions: extensions,
}
return x509.CreateCertificate(rand.Reader, &template, &template, &privKey.PublicKey, privKey)
}
func limitReader(rd io.ReadCloser, numBytes int64) io.ReadCloser {
return http.MaxBytesReader(nil, rd, numBytes)
}

View file

@ -1,343 +0,0 @@
package acme
import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/miekg/dns"
"github.com/xenolf/lego/log"
)
type preCheckDNSFunc func(fqdn, value string) (bool, error)
var (
// PreCheckDNS checks DNS propagation before notifying ACME that
// the DNS challenge is ready.
PreCheckDNS preCheckDNSFunc = checkDNSPropagation
fqdnToZone = map[string]string{}
muFqdnToZone sync.Mutex
)
const defaultResolvConf = "/etc/resolv.conf"
const (
// DefaultPropagationTimeout default propagation timeout
DefaultPropagationTimeout = 60 * time.Second
// DefaultPollingInterval default polling interval
DefaultPollingInterval = 2 * time.Second
// DefaultTTL default TTL
DefaultTTL = 120
)
var defaultNameservers = []string{
"google-public-dns-a.google.com:53",
"google-public-dns-b.google.com:53",
}
// RecursiveNameservers are used to pre-check DNS propagation
var RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
// DNSTimeout is used to override the default DNS timeout of 10 seconds.
var DNSTimeout = 10 * time.Second
// getNameservers attempts to get systems nameservers before falling back to the defaults
func getNameservers(path string, defaults []string) []string {
config, err := dns.ClientConfigFromFile(path)
if err != nil || len(config.Servers) == 0 {
return defaults
}
systemNameservers := []string{}
for _, server := range config.Servers {
// ensure all servers have a port number
if _, _, err := net.SplitHostPort(server); err != nil {
systemNameservers = append(systemNameservers, net.JoinHostPort(server, "53"))
} else {
systemNameservers = append(systemNameservers, server)
}
}
return systemNameservers
}
// DNS01Record returns a DNS record which will fulfill the `dns-01` challenge
func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) {
keyAuthShaBytes := sha256.Sum256([]byte(keyAuth))
// base64URL encoding without padding
value = base64.RawURLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size])
ttl = DefaultTTL
fqdn = fmt.Sprintf("_acme-challenge.%s.", domain)
return
}
// dnsChallenge implements the dns-01 challenge according to ACME 7.5
type dnsChallenge struct {
jws *jws
validate validateFunc
provider ChallengeProvider
}
// PreSolve just submits the txt record to the dns provider. It does not validate record propagation, or
// do anything at all with the acme server.
func (s *dnsChallenge) PreSolve(chlng challenge, domain string) error {
log.Infof("[%s] acme: Preparing to solve DNS-01", domain)
if s.provider == nil {
return errors.New("no DNS Provider configured")
}
// Generate the Key Authorization for the challenge
keyAuth, err := getKeyAuthorization(chlng.Token, s.jws.privKey)
if err != nil {
return err
}
err = s.provider.Present(domain, chlng.Token, keyAuth)
if err != nil {
return fmt.Errorf("error presenting token: %s", err)
}
return nil
}
func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
log.Infof("[%s] acme: Trying to solve DNS-01", domain)
// Generate the Key Authorization for the challenge
keyAuth, err := getKeyAuthorization(chlng.Token, s.jws.privKey)
if err != nil {
return err
}
fqdn, value, _ := DNS01Record(domain, keyAuth)
log.Infof("[%s] Checking DNS record propagation using %+v", domain, RecursiveNameservers)
var timeout, interval time.Duration
switch provider := s.provider.(type) {
case ChallengeProviderTimeout:
timeout, interval = provider.Timeout()
default:
timeout, interval = DefaultPropagationTimeout, DefaultPollingInterval
}
err = WaitFor(timeout, interval, func() (bool, error) {
return PreCheckDNS(fqdn, value)
})
if err != nil {
return err
}
return s.validate(s.jws, domain, chlng.URL, challenge{Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
}
// CleanUp cleans the challenge
func (s *dnsChallenge) CleanUp(chlng challenge, domain string) error {
keyAuth, err := getKeyAuthorization(chlng.Token, s.jws.privKey)
if err != nil {
return err
}
return s.provider.CleanUp(domain, chlng.Token, keyAuth)
}
// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
func checkDNSPropagation(fqdn, value string) (bool, error) {
// Initial attempt to resolve at the recursive NS
r, err := dnsQuery(fqdn, dns.TypeTXT, RecursiveNameservers, true)
if err != nil {
return false, err
}
if r.Rcode == dns.RcodeSuccess {
// If we see a CNAME here then use the alias
for _, rr := range r.Answer {
if cn, ok := rr.(*dns.CNAME); ok {
if cn.Hdr.Name == fqdn {
fqdn = cn.Target
break
}
}
}
}
authoritativeNss, err := lookupNameservers(fqdn)
if err != nil {
return false, err
}
return checkAuthoritativeNss(fqdn, value, authoritativeNss)
}
// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
for _, ns := range nameservers {
r, err := dnsQuery(fqdn, dns.TypeTXT, []string{net.JoinHostPort(ns, "53")}, false)
if err != nil {
return false, err
}
if r.Rcode != dns.RcodeSuccess {
return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
}
var found bool
for _, rr := range r.Answer {
if txt, ok := rr.(*dns.TXT); ok {
if strings.Join(txt.Txt, "") == value {
found = true
break
}
}
}
if !found {
return false, fmt.Errorf("NS %s did not return the expected TXT record [fqdn: %s]", ns, fqdn)
}
}
return true, nil
}
// dnsQuery will query a nameserver, iterating through the supplied servers as it retries
// The nameserver should include a port, to facilitate testing where we talk to a mock dns server.
func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (in *dns.Msg, err error) {
m := new(dns.Msg)
m.SetQuestion(fqdn, rtype)
m.SetEdns0(4096, false)
if !recursive {
m.RecursionDesired = false
}
// Will retry the request based on the number of servers (n+1)
for i := 1; i <= len(nameservers)+1; i++ {
ns := nameservers[i%len(nameservers)]
udp := &dns.Client{Net: "udp", Timeout: DNSTimeout}
in, _, err = udp.Exchange(m, ns)
if err == dns.ErrTruncated {
tcp := &dns.Client{Net: "tcp", Timeout: DNSTimeout}
// If the TCP request succeeds, the err will reset to nil
in, _, err = tcp.Exchange(m, ns)
}
if err == nil {
break
}
}
return
}
// lookupNameservers returns the authoritative nameservers for the given fqdn.
func lookupNameservers(fqdn string) ([]string, error) {
var authoritativeNss []string
zone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
if err != nil {
return nil, fmt.Errorf("could not determine the zone: %v", err)
}
r, err := dnsQuery(zone, dns.TypeNS, RecursiveNameservers, true)
if err != nil {
return nil, err
}
for _, rr := range r.Answer {
if ns, ok := rr.(*dns.NS); ok {
authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
}
}
if len(authoritativeNss) > 0 {
return authoritativeNss, nil
}
return nil, fmt.Errorf("could not determine authoritative nameservers")
}
// FindZoneByFqdn determines the zone apex for the given fqdn by recursing up the
// domain labels until the nameserver returns a SOA record in the answer section.
func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) {
muFqdnToZone.Lock()
defer muFqdnToZone.Unlock()
// Do we have it cached?
if zone, ok := fqdnToZone[fqdn]; ok {
return zone, nil
}
labelIndexes := dns.Split(fqdn)
for _, index := range labelIndexes {
domain := fqdn[index:]
in, err := dnsQuery(domain, dns.TypeSOA, nameservers, true)
if err != nil {
return "", err
}
// Any response code other than NOERROR and NXDOMAIN is treated as error
if in.Rcode != dns.RcodeNameError && in.Rcode != dns.RcodeSuccess {
return "", fmt.Errorf("unexpected response code '%s' for %s",
dns.RcodeToString[in.Rcode], domain)
}
// Check if we got a SOA RR in the answer section
if in.Rcode == dns.RcodeSuccess {
// CNAME records cannot/should not exist at the root of a zone.
// So we skip a domain when a CNAME is found.
if dnsMsgContainsCNAME(in) {
continue
}
for _, ans := range in.Answer {
if soa, ok := ans.(*dns.SOA); ok {
zone := soa.Hdr.Name
fqdnToZone[fqdn] = zone
return zone, nil
}
}
}
}
return "", fmt.Errorf("could not find the start of authority")
}
// dnsMsgContainsCNAME checks for a CNAME answer in msg
func dnsMsgContainsCNAME(msg *dns.Msg) bool {
for _, ans := range msg.Answer {
if _, ok := ans.(*dns.CNAME); ok {
return true
}
}
return false
}
// ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing.
func ClearFqdnCache() {
fqdnToZone = map[string]string{}
}
// ToFqdn converts the name into a fqdn appending a trailing dot.
func ToFqdn(name string) string {
n := len(name)
if n == 0 || name[n-1] == '.' {
return name
}
return name + "."
}
// UnFqdn converts the fqdn into a name removing the trailing dot.
func UnFqdn(name string) string {
n := len(name)
if n != 0 && name[n-1] == '.' {
return name[:n-1]
}
return name
}

View file

@ -1,55 +0,0 @@
package acme
import (
"bufio"
"fmt"
"os"
"github.com/xenolf/lego/log"
)
const (
dnsTemplate = "%s %d IN TXT \"%s\""
)
// DNSProviderManual is an implementation of the ChallengeProvider interface
type DNSProviderManual struct{}
// NewDNSProviderManual returns a DNSProviderManual instance.
func NewDNSProviderManual() (*DNSProviderManual, error) {
return &DNSProviderManual{}, nil
}
// Present prints instructions for manually creating the TXT record
func (*DNSProviderManual) Present(domain, token, keyAuth string) error {
fqdn, value, ttl := DNS01Record(domain, keyAuth)
dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, value)
authZone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
if err != nil {
return err
}
log.Infof("acme: Please create the following TXT record in your %s zone:", authZone)
log.Infof("acme: %s", dnsRecord)
log.Infof("acme: Press 'Enter' when you are done")
reader := bufio.NewReader(os.Stdin)
_, _ = reader.ReadString('\n')
return nil
}
// CleanUp prints instructions for manually removing the TXT record
func (*DNSProviderManual) CleanUp(domain, token, keyAuth string) error {
fqdn, _, ttl := DNS01Record(domain, keyAuth)
dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, "...")
authZone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
if err != nil {
return err
}
log.Infof("acme: You can now remove this TXT record from your %s zone:", authZone)
log.Infof("acme: %s", dnsRecord)
return nil
}

View file

@ -1,91 +0,0 @@
package acme
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
)
const (
tosAgreementError = "Terms of service have changed"
invalidNonceError = "urn:ietf:params:acme:error:badNonce"
)
// RemoteError is the base type for all errors specific to the ACME protocol.
type RemoteError struct {
StatusCode int `json:"status,omitempty"`
Type string `json:"type"`
Detail string `json:"detail"`
}
func (e RemoteError) Error() string {
return fmt.Sprintf("acme: Error %d - %s - %s", e.StatusCode, e.Type, e.Detail)
}
// TOSError represents the error which is returned if the user needs to
// accept the TOS.
// TODO: include the new TOS url if we can somehow obtain it.
type TOSError struct {
RemoteError
}
// NonceError represents the error which is returned if the
// nonce sent by the client was not accepted by the server.
type NonceError struct {
RemoteError
}
type domainError struct {
Domain string
Error error
}
// ObtainError is returned when there are specific errors available
// per domain. For example in ObtainCertificate
type ObtainError map[string]error
func (e ObtainError) Error() string {
buffer := bytes.NewBufferString("acme: Error -> One or more domains had a problem:\n")
for dom, err := range e {
buffer.WriteString(fmt.Sprintf("[%s] %s\n", dom, err))
}
return buffer.String()
}
func handleHTTPError(resp *http.Response) error {
var errorDetail RemoteError
contentType := resp.Header.Get("Content-Type")
if contentType == "application/json" || strings.HasPrefix(contentType, "application/problem+json") {
err := json.NewDecoder(resp.Body).Decode(&errorDetail)
if err != nil {
return err
}
} else {
detailBytes, err := ioutil.ReadAll(limitReader(resp.Body, maxBodySize))
if err != nil {
return err
}
errorDetail.Detail = string(detailBytes)
}
errorDetail.StatusCode = resp.StatusCode
// Check for errors we handle specifically
if errorDetail.StatusCode == http.StatusForbidden && errorDetail.Detail == tosAgreementError {
return TOSError{errorDetail}
}
if errorDetail.StatusCode == http.StatusBadRequest && errorDetail.Type == invalidNonceError {
return NonceError{errorDetail}
}
return errorDetail
}
func handleChallengeError(chlng challenge) error {
return chlng.Error
}

58
vendor/github.com/xenolf/lego/acme/errors.go generated vendored Normal file
View file

@ -0,0 +1,58 @@
package acme
import (
"fmt"
)
// Errors types
const (
errNS = "urn:ietf:params:acme:error:"
BadNonceErr = errNS + "badNonce"
)
// ProblemDetails the problem details object
// - https://tools.ietf.org/html/rfc7807#section-3.1
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-7.3.3
type ProblemDetails struct {
Type string `json:"type,omitempty"`
Detail string `json:"detail,omitempty"`
HTTPStatus int `json:"status,omitempty"`
Instance string `json:"instance,omitempty"`
SubProblems []SubProblem `json:"subproblems,omitempty"`
// additional values to have a better error message (Not defined by the RFC)
Method string `json:"method,omitempty"`
URL string `json:"url,omitempty"`
}
// SubProblem a "subproblems"
// - https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-6.7.1
type SubProblem struct {
Type string `json:"type,omitempty"`
Detail string `json:"detail,omitempty"`
Identifier Identifier `json:"identifier,omitempty"`
}
func (p ProblemDetails) Error() string {
msg := fmt.Sprintf("acme: error: %d", p.HTTPStatus)
if len(p.Method) != 0 || len(p.URL) != 0 {
msg += fmt.Sprintf(" :: %s :: %s", p.Method, p.URL)
}
msg += fmt.Sprintf(" :: %s :: %s", p.Type, p.Detail)
for _, sub := range p.SubProblems {
msg += fmt.Sprintf(", problem: %q :: %s", sub.Type, sub.Detail)
}
if len(p.Instance) == 0 {
msg += ", url: " + p.Instance
}
return msg
}
// NonceError represents the error which is returned
// if the nonce sent by the client was not accepted by the server.
type NonceError struct {
*ProblemDetails
}

View file

@ -1,212 +0,0 @@
package acme
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"os"
"runtime"
"strings"
"time"
)
var (
// UserAgent (if non-empty) will be tacked onto the User-Agent string in requests.
UserAgent string
// HTTPClient is an HTTP client with a reasonable timeout value and
// potentially a custom *x509.CertPool based on the caCertificatesEnvVar
// environment variable (see the `initCertPool` function)
HTTPClient = http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 15 * time.Second,
ResponseHeaderTimeout: 15 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{
ServerName: os.Getenv(caServerNameEnvVar),
RootCAs: initCertPool(),
},
},
}
)
const (
// ourUserAgent is the User-Agent of this underlying library package.
// NOTE: Update this with each tagged release.
ourUserAgent = "xenolf-acme/1.2.1"
// ourUserAgentComment is part of the UA comment linked to the version status of this underlying library package.
// values: detach|release
// NOTE: Update this with each tagged release.
ourUserAgentComment = "detach"
// caCertificatesEnvVar is the environment variable name that can be used to
// specify the path to PEM encoded CA Certificates that can be used to
// authenticate an ACME server with a HTTPS certificate not issued by a CA in
// the system-wide trusted root list.
caCertificatesEnvVar = "LEGO_CA_CERTIFICATES"
// caServerNameEnvVar is the environment variable name that can be used to
// specify the CA server name that can be used to
// authenticate an ACME server with a HTTPS certificate not issued by a CA in
// the system-wide trusted root list.
caServerNameEnvVar = "LEGO_CA_SERVER_NAME"
)
// initCertPool creates a *x509.CertPool populated with the PEM certificates
// found in the filepath specified in the caCertificatesEnvVar OS environment
// variable. If the caCertificatesEnvVar is not set then initCertPool will
// return nil. If there is an error creating a *x509.CertPool from the provided
// caCertificatesEnvVar value then initCertPool will panic.
func initCertPool() *x509.CertPool {
if customCACertsPath := os.Getenv(caCertificatesEnvVar); customCACertsPath != "" {
customCAs, err := ioutil.ReadFile(customCACertsPath)
if err != nil {
panic(fmt.Sprintf("error reading %s=%q: %v",
caCertificatesEnvVar, customCACertsPath, err))
}
certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM(customCAs); !ok {
panic(fmt.Sprintf("error creating x509 cert pool from %s=%q: %v",
caCertificatesEnvVar, customCACertsPath, err))
}
return certPool
}
return nil
}
// httpHead performs a HEAD request with a proper User-Agent string.
// The response body (resp.Body) is already closed when this function returns.
func httpHead(url string) (resp *http.Response, err error) {
req, err := http.NewRequest(http.MethodHead, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to head %q: %v", url, err)
}
req.Header.Set("User-Agent", userAgent())
resp, err = HTTPClient.Do(req)
if err != nil {
return resp, fmt.Errorf("failed to do head %q: %v", url, err)
}
resp.Body.Close()
return resp, err
}
// httpPost performs a POST request with a proper User-Agent string.
// Callers should close resp.Body when done reading from it.
func httpPost(url string, bodyType string, body io.Reader) (resp *http.Response, err error) {
req, err := http.NewRequest(http.MethodPost, url, body)
if err != nil {
return nil, fmt.Errorf("failed to post %q: %v", url, err)
}
req.Header.Set("Content-Type", bodyType)
req.Header.Set("User-Agent", userAgent())
return HTTPClient.Do(req)
}
// httpGet performs a GET request with a proper User-Agent string.
// Callers should close resp.Body when done reading from it.
func httpGet(url string) (resp *http.Response, err error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to get %q: %v", url, err)
}
req.Header.Set("User-Agent", userAgent())
return HTTPClient.Do(req)
}
// getJSON performs an HTTP GET request and parses the response body
// as JSON, into the provided respBody object.
func getJSON(uri string, respBody interface{}) (http.Header, error) {
resp, err := httpGet(uri)
if err != nil {
return nil, fmt.Errorf("failed to get json %q: %v", uri, err)
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
return resp.Header, handleHTTPError(resp)
}
return resp.Header, json.NewDecoder(resp.Body).Decode(respBody)
}
// postJSON performs an HTTP POST request and parses the response body
// as JSON, into the provided respBody object.
func postJSON(j *jws, uri string, reqBody, respBody interface{}) (http.Header, error) {
jsonBytes, err := json.Marshal(reqBody)
if err != nil {
return nil, errors.New("failed to marshal network message")
}
resp, err := post(j, uri, jsonBytes, respBody)
if resp == nil {
return nil, err
}
defer resp.Body.Close()
return resp.Header, err
}
func postAsGet(j *jws, uri string, respBody interface{}) (*http.Response, error) {
return post(j, uri, []byte{}, respBody)
}
func post(j *jws, uri string, reqBody []byte, respBody interface{}) (*http.Response, error) {
resp, err := j.post(uri, reqBody)
if err != nil {
return nil, fmt.Errorf("failed to post JWS message. -> %v", err)
}
if resp.StatusCode >= http.StatusBadRequest {
err = handleHTTPError(resp)
switch err.(type) {
case NonceError:
// Retry once if the nonce was invalidated
retryResp, errP := j.post(uri, reqBody)
if errP != nil {
return nil, fmt.Errorf("failed to post JWS message. -> %v", errP)
}
if retryResp.StatusCode >= http.StatusBadRequest {
return retryResp, handleHTTPError(retryResp)
}
if respBody == nil {
return retryResp, nil
}
return retryResp, json.NewDecoder(retryResp.Body).Decode(respBody)
default:
return resp, err
}
}
if respBody == nil {
return resp, nil
}
return resp, json.NewDecoder(resp.Body).Decode(respBody)
}
// userAgent builds and returns the User-Agent string to use in requests.
func userAgent() string {
ua := fmt.Sprintf("%s %s (%s; %s; %s)", UserAgent, ourUserAgent, ourUserAgentComment, runtime.GOOS, runtime.GOARCH)
return strings.TrimSpace(ua)
}

View file

@ -1,42 +0,0 @@
package acme
import (
"fmt"
"github.com/xenolf/lego/log"
)
type httpChallenge struct {
jws *jws
validate validateFunc
provider ChallengeProvider
}
// HTTP01ChallengePath returns the URL path for the `http-01` challenge
func HTTP01ChallengePath(token string) string {
return "/.well-known/acme-challenge/" + token
}
func (s *httpChallenge) Solve(chlng challenge, domain string) error {
log.Infof("[%s] acme: Trying to solve HTTP-01", domain)
// Generate the Key Authorization for the challenge
keyAuth, err := getKeyAuthorization(chlng.Token, s.jws.privKey)
if err != nil {
return err
}
err = s.provider.Present(domain, chlng.Token, keyAuth)
if err != nil {
return fmt.Errorf("[%s] error presenting token: %v", domain, err)
}
defer func() {
err := s.provider.CleanUp(domain, chlng.Token, keyAuth)
if err != nil {
log.Warnf("[%s] error cleaning up: %v", domain, err)
}
}()
return s.validate(s.jws, domain, chlng.URL, challenge{Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
}

View file

@ -1,167 +0,0 @@
package acme
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"fmt"
"net/http"
"sync"
"gopkg.in/square/go-jose.v2"
)
type jws struct {
getNonceURL string
privKey crypto.PrivateKey
kid string
nonces nonceManager
}
// Posts a JWS signed message to the specified URL.
// It does NOT close the response body, so the caller must
// do that if no error was returned.
func (j *jws) post(url string, content []byte) (*http.Response, error) {
signedContent, err := j.signContent(url, content)
if err != nil {
return nil, fmt.Errorf("failed to sign content -> %s", err.Error())
}
data := bytes.NewBuffer([]byte(signedContent.FullSerialize()))
resp, err := httpPost(url, "application/jose+json", data)
if err != nil {
return nil, fmt.Errorf("failed to HTTP POST to %s -> %s", url, err.Error())
}
nonce, nonceErr := getNonceFromResponse(resp)
if nonceErr == nil {
j.nonces.Push(nonce)
}
return resp, nil
}
func (j *jws) signContent(url string, content []byte) (*jose.JSONWebSignature, error) {
var alg jose.SignatureAlgorithm
switch k := j.privKey.(type) {
case *rsa.PrivateKey:
alg = jose.RS256
case *ecdsa.PrivateKey:
if k.Curve == elliptic.P256() {
alg = jose.ES256
} else if k.Curve == elliptic.P384() {
alg = jose.ES384
}
}
jsonKey := jose.JSONWebKey{
Key: j.privKey,
KeyID: j.kid,
}
signKey := jose.SigningKey{
Algorithm: alg,
Key: jsonKey,
}
options := jose.SignerOptions{
NonceSource: j,
ExtraHeaders: make(map[jose.HeaderKey]interface{}),
}
options.ExtraHeaders["url"] = url
if j.kid == "" {
options.EmbedJWK = true
}
signer, err := jose.NewSigner(signKey, &options)
if err != nil {
return nil, fmt.Errorf("failed to create jose signer -> %s", err.Error())
}
signed, err := signer.Sign(content)
if err != nil {
return nil, fmt.Errorf("failed to sign content -> %s", err.Error())
}
return signed, nil
}
func (j *jws) signEABContent(url, kid string, hmac []byte) (*jose.JSONWebSignature, error) {
jwk := jose.JSONWebKey{Key: j.privKey}
jwkJSON, err := jwk.Public().MarshalJSON()
if err != nil {
return nil, fmt.Errorf("acme: error encoding eab jwk key: %s", err.Error())
}
signer, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.HS256, Key: hmac},
&jose.SignerOptions{
EmbedJWK: false,
ExtraHeaders: map[jose.HeaderKey]interface{}{
"kid": kid,
"url": url,
},
},
)
if err != nil {
return nil, fmt.Errorf("failed to create External Account Binding jose signer -> %s", err.Error())
}
signed, err := signer.Sign(jwkJSON)
if err != nil {
return nil, fmt.Errorf("failed to External Account Binding sign content -> %s", err.Error())
}
return signed, nil
}
func (j *jws) Nonce() (string, error) {
if nonce, ok := j.nonces.Pop(); ok {
return nonce, nil
}
return getNonce(j.getNonceURL)
}
type nonceManager struct {
nonces []string
sync.Mutex
}
func (n *nonceManager) Pop() (string, bool) {
n.Lock()
defer n.Unlock()
if len(n.nonces) == 0 {
return "", false
}
nonce := n.nonces[len(n.nonces)-1]
n.nonces = n.nonces[:len(n.nonces)-1]
return nonce, true
}
func (n *nonceManager) Push(nonce string) {
n.Lock()
defer n.Unlock()
n.nonces = append(n.nonces, nonce)
}
func getNonce(url string) (string, error) {
resp, err := httpHead(url)
if err != nil {
return "", fmt.Errorf("failed to get nonce from HTTP HEAD -> %s", err.Error())
}
return getNonceFromResponse(resp)
}
func getNonceFromResponse(resp *http.Response) (string, error) {
nonce := resp.Header.Get("Replay-Nonce")
if nonce == "" {
return "", fmt.Errorf("server did not respond with a proper nonce header")
}
return nonce, nil
}

View file

@ -1,103 +0,0 @@
package acme
import (
"encoding/json"
"time"
)
// RegistrationResource represents all important informations about a registration
// of which the client needs to keep track itself.
type RegistrationResource struct {
Body accountMessage `json:"body,omitempty"`
URI string `json:"uri,omitempty"`
}
type directory struct {
NewNonceURL string `json:"newNonce"`
NewAccountURL string `json:"newAccount"`
NewOrderURL string `json:"newOrder"`
RevokeCertURL string `json:"revokeCert"`
KeyChangeURL string `json:"keyChange"`
Meta struct {
TermsOfService string `json:"termsOfService"`
Website string `json:"website"`
CaaIdentities []string `json:"caaIdentities"`
ExternalAccountRequired bool `json:"externalAccountRequired"`
} `json:"meta"`
}
type accountMessage struct {
Status string `json:"status,omitempty"`
Contact []string `json:"contact,omitempty"`
TermsOfServiceAgreed bool `json:"termsOfServiceAgreed,omitempty"`
Orders string `json:"orders,omitempty"`
OnlyReturnExisting bool `json:"onlyReturnExisting,omitempty"`
ExternalAccountBinding json.RawMessage `json:"externalAccountBinding,omitempty"`
}
type orderResource struct {
URL string `json:"url,omitempty"`
Domains []string `json:"domains,omitempty"`
orderMessage `json:"body,omitempty"`
}
type orderMessage struct {
Status string `json:"status,omitempty"`
Expires string `json:"expires,omitempty"`
Identifiers []identifier `json:"identifiers"`
NotBefore string `json:"notBefore,omitempty"`
NotAfter string `json:"notAfter,omitempty"`
Authorizations []string `json:"authorizations,omitempty"`
Finalize string `json:"finalize,omitempty"`
Certificate string `json:"certificate,omitempty"`
}
type authorization struct {
Status string `json:"status"`
Expires time.Time `json:"expires"`
Identifier identifier `json:"identifier"`
Challenges []challenge `json:"challenges"`
}
type identifier struct {
Type string `json:"type"`
Value string `json:"value"`
}
type challenge struct {
URL string `json:"url"`
Type string `json:"type"`
Status string `json:"status"`
Token string `json:"token"`
Validated time.Time `json:"validated"`
KeyAuthorization string `json:"keyAuthorization"`
Error RemoteError `json:"error"`
}
type csrMessage struct {
Csr string `json:"csr"`
}
type revokeCertMessage struct {
Certificate string `json:"certificate"`
}
type deactivateAuthMessage struct {
Status string `jsom:"status"`
}
// CertificateResource represents a CA issued certificate.
// PrivateKey, Certificate and IssuerCertificate are all
// already PEM encoded and can be directly written to disk.
// Certificate may be a certificate bundle, depending on the
// options supplied to create it.
type CertificateResource struct {
Domain string `json:"domain"`
CertURL string `json:"certUrl"`
CertStableURL string `json:"certStableUrl"`
AccountRef string `json:"accountRef,omitempty"`
PrivateKey []byte `json:"-"`
Certificate []byte `json:"-"`
IssuerCertificate []byte `json:"-"`
CSR []byte `json:"-"`
}

View file

@ -1,104 +0,0 @@
package acme
import (
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"github.com/xenolf/lego/log"
)
// idPeAcmeIdentifierV1 is the SMI Security for PKIX Certification Extension OID referencing the ACME extension.
// Reference: https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-05#section-5.1
var idPeAcmeIdentifierV1 = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31}
type tlsALPNChallenge struct {
jws *jws
validate validateFunc
provider ChallengeProvider
}
// Solve manages the provider to validate and solve the challenge.
func (t *tlsALPNChallenge) Solve(chlng challenge, domain string) error {
log.Infof("[%s] acme: Trying to solve TLS-ALPN-01", domain)
// Generate the Key Authorization for the challenge
keyAuth, err := getKeyAuthorization(chlng.Token, t.jws.privKey)
if err != nil {
return err
}
err = t.provider.Present(domain, chlng.Token, keyAuth)
if err != nil {
return fmt.Errorf("[%s] error presenting token: %v", domain, err)
}
defer func() {
err := t.provider.CleanUp(domain, chlng.Token, keyAuth)
if err != nil {
log.Warnf("[%s] error cleaning up: %v", domain, err)
}
}()
return t.validate(t.jws, domain, chlng.URL, challenge{Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
}
// TLSALPNChallengeBlocks returns PEM blocks (certPEMBlock, keyPEMBlock) with the acmeValidation-v1 extension
// and domain name for the `tls-alpn-01` challenge.
func TLSALPNChallengeBlocks(domain, keyAuth string) ([]byte, []byte, error) {
// Compute the SHA-256 digest of the key authorization.
zBytes := sha256.Sum256([]byte(keyAuth))
value, err := asn1.Marshal(zBytes[:sha256.Size])
if err != nil {
return nil, nil, err
}
// Add the keyAuth digest as the acmeValidation-v1 extension
// (marked as critical such that it won't be used by non-ACME software).
// Reference: https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-05#section-3
extensions := []pkix.Extension{
{
Id: idPeAcmeIdentifierV1,
Critical: true,
Value: value,
},
}
// Generate a new RSA key for the certificates.
tempPrivKey, err := generatePrivateKey(RSA2048)
if err != nil {
return nil, nil, err
}
rsaPrivKey := tempPrivKey.(*rsa.PrivateKey)
// Generate the PEM certificate using the provided private key, domain, and extra extensions.
tempCertPEM, err := generatePemCert(rsaPrivKey, domain, extensions)
if err != nil {
return nil, nil, err
}
// Encode the private key into a PEM format. We'll need to use it to generate the x509 keypair.
rsaPrivPEM := pemEncode(rsaPrivKey)
return tempCertPEM, rsaPrivPEM, nil
}
// TLSALPNChallengeCert returns a certificate with the acmeValidation-v1 extension
// and domain name for the `tls-alpn-01` challenge.
func TLSALPNChallengeCert(domain, keyAuth string) (*tls.Certificate, error) {
tempCertPEM, rsaPrivPEM, err := TLSALPNChallengeBlocks(domain, keyAuth)
if err != nil {
return nil, err
}
certificate, err := tls.X509KeyPair(tempCertPEM, rsaPrivPEM)
if err != nil {
return nil, err
}
return &certificate, nil
}

252
vendor/github.com/xenolf/lego/certcrypto/crypto.go generated vendored Normal file
View file

@ -0,0 +1,252 @@
package certcrypto
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"errors"
"fmt"
"math/big"
"time"
"golang.org/x/crypto/ocsp"
)
// Constants for all key types we support.
const (
EC256 = KeyType("P256")
EC384 = KeyType("P384")
RSA2048 = KeyType("2048")
RSA4096 = KeyType("4096")
RSA8192 = KeyType("8192")
)
const (
// OCSPGood means that the certificate is valid.
OCSPGood = ocsp.Good
// OCSPRevoked means that the certificate has been deliberately revoked.
OCSPRevoked = ocsp.Revoked
// OCSPUnknown means that the OCSP responder doesn't know about the certificate.
OCSPUnknown = ocsp.Unknown
// OCSPServerFailed means that the OCSP responder failed to process the request.
OCSPServerFailed = ocsp.ServerFailed
)
// Constants for OCSP must staple
var (
tlsFeatureExtensionOID = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 24}
ocspMustStapleFeature = []byte{0x30, 0x03, 0x02, 0x01, 0x05}
)
// KeyType represents the key algo as well as the key size or curve to use.
type KeyType string
type DERCertificateBytes []byte
// ParsePEMBundle parses a certificate bundle from top to bottom and returns
// a slice of x509 certificates. This function will error if no certificates are found.
func ParsePEMBundle(bundle []byte) ([]*x509.Certificate, error) {
var certificates []*x509.Certificate
var certDERBlock *pem.Block
for {
certDERBlock, bundle = pem.Decode(bundle)
if certDERBlock == nil {
break
}
if certDERBlock.Type == "CERTIFICATE" {
cert, err := x509.ParseCertificate(certDERBlock.Bytes)
if err != nil {
return nil, err
}
certificates = append(certificates, cert)
}
}
if len(certificates) == 0 {
return nil, errors.New("no certificates were found while parsing the bundle")
}
return certificates, nil
}
func ParsePEMPrivateKey(key []byte) (crypto.PrivateKey, error) {
keyBlock, _ := pem.Decode(key)
switch keyBlock.Type {
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(keyBlock.Bytes)
default:
return nil, errors.New("unknown PEM header value")
}
}
func GeneratePrivateKey(keyType KeyType) (crypto.PrivateKey, error) {
switch keyType {
case EC256:
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case EC384:
return ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
case RSA2048:
return rsa.GenerateKey(rand.Reader, 2048)
case RSA4096:
return rsa.GenerateKey(rand.Reader, 4096)
case RSA8192:
return rsa.GenerateKey(rand.Reader, 8192)
}
return nil, fmt.Errorf("invalid KeyType: %s", keyType)
}
func GenerateCSR(privateKey crypto.PrivateKey, domain string, san []string, mustStaple bool) ([]byte, error) {
template := x509.CertificateRequest{
Subject: pkix.Name{CommonName: domain},
DNSNames: san,
}
if mustStaple {
template.ExtraExtensions = append(template.ExtraExtensions, pkix.Extension{
Id: tlsFeatureExtensionOID,
Value: ocspMustStapleFeature,
})
}
return x509.CreateCertificateRequest(rand.Reader, &template, privateKey)
}
func PEMEncode(data interface{}) []byte {
var pemBlock *pem.Block
switch key := data.(type) {
case *ecdsa.PrivateKey:
keyBytes, _ := x509.MarshalECPrivateKey(key)
pemBlock = &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes}
case *rsa.PrivateKey:
pemBlock = &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
case *x509.CertificateRequest:
pemBlock = &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: key.Raw}
case DERCertificateBytes:
pemBlock = &pem.Block{Type: "CERTIFICATE", Bytes: []byte(data.(DERCertificateBytes))}
}
return pem.EncodeToMemory(pemBlock)
}
func pemDecode(data []byte) (*pem.Block, error) {
pemBlock, _ := pem.Decode(data)
if pemBlock == nil {
return nil, fmt.Errorf("PEM decode did not yield a valid block. Is the certificate in the right format?")
}
return pemBlock, nil
}
func PemDecodeTox509CSR(pem []byte) (*x509.CertificateRequest, error) {
pemBlock, err := pemDecode(pem)
if pemBlock == nil {
return nil, err
}
if pemBlock.Type != "CERTIFICATE REQUEST" {
return nil, fmt.Errorf("PEM block is not a certificate request")
}
return x509.ParseCertificateRequest(pemBlock.Bytes)
}
// ParsePEMCertificate returns Certificate from a PEM encoded certificate.
// The certificate has to be PEM encoded. Any other encodings like DER will fail.
func ParsePEMCertificate(cert []byte) (*x509.Certificate, error) {
pemBlock, err := pemDecode(cert)
if pemBlock == nil {
return nil, err
}
// from a DER encoded certificate
return x509.ParseCertificate(pemBlock.Bytes)
}
func ExtractDomains(cert *x509.Certificate) []string {
domains := []string{cert.Subject.CommonName}
// Check for SAN certificate
for _, sanDomain := range cert.DNSNames {
if sanDomain == cert.Subject.CommonName {
continue
}
domains = append(domains, sanDomain)
}
return domains
}
func ExtractDomainsCSR(csr *x509.CertificateRequest) []string {
domains := []string{csr.Subject.CommonName}
// loop over the SubjectAltName DNS names
for _, sanName := range csr.DNSNames {
if containsSAN(domains, sanName) {
// Duplicate; skip this name
continue
}
// Name is unique
domains = append(domains, sanName)
}
return domains
}
func containsSAN(domains []string, sanName string) bool {
for _, existingName := range domains {
if existingName == sanName {
return true
}
}
return false
}
func GeneratePemCert(privateKey *rsa.PrivateKey, domain string, extensions []pkix.Extension) ([]byte, error) {
derBytes, err := generateDerCert(privateKey, time.Time{}, domain, extensions)
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), nil
}
func generateDerCert(privateKey *rsa.PrivateKey, expiration time.Time, domain string, extensions []pkix.Extension) ([]byte, error) {
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, err
}
if expiration.IsZero() {
expiration = time.Now().Add(365)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "ACME Challenge TEMP",
},
NotBefore: time.Now(),
NotAfter: expiration,
KeyUsage: x509.KeyUsageKeyEncipherment,
BasicConstraintsValid: true,
DNSNames: []string{domain},
ExtraExtensions: extensions,
}
return x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
}

View file

@ -0,0 +1,69 @@
package certificate
import (
"time"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/log"
)
const (
// overallRequestLimit is the overall number of request per second
// limited on the "new-reg", "new-authz" and "new-cert" endpoints.
// From the documentation the limitation is 20 requests per second,
// but using 20 as value doesn't work but 18 do
overallRequestLimit = 18
)
func (c *Certifier) getAuthorizations(order acme.ExtendedOrder) ([]acme.Authorization, error) {
resc, errc := make(chan acme.Authorization), make(chan domainError)
delay := time.Second / overallRequestLimit
for _, authzURL := range order.Authorizations {
time.Sleep(delay)
go func(authzURL string) {
authz, err := c.core.Authorizations.Get(authzURL)
if err != nil {
errc <- domainError{Domain: authz.Identifier.Value, Error: err}
return
}
resc <- authz
}(authzURL)
}
var responses []acme.Authorization
failures := make(obtainError)
for i := 0; i < len(order.Authorizations); i++ {
select {
case res := <-resc:
responses = append(responses, res)
case err := <-errc:
failures[err.Domain] = err.Error
}
}
for i, auth := range order.Authorizations {
log.Infof("[%s] AuthURL: %s", order.Identifiers[i].Value, auth)
}
close(resc)
close(errc)
// be careful to not return an empty failures map;
// even if empty, they become non-nil error values
if len(failures) > 0 {
return responses, failures
}
return responses, nil
}
func (c *Certifier) deactivateAuthorizations(order acme.ExtendedOrder) {
for _, auth := range order.Authorizations {
if err := c.core.Authorizations.Deactivate(auth); err != nil {
log.Infof("Unable to deactivated authorizations: %s", auth)
}
}
}

View file

@ -0,0 +1,493 @@
package certificate
import (
"bytes"
"crypto"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acme/api"
"github.com/xenolf/lego/certcrypto"
"github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/log"
"golang.org/x/crypto/ocsp"
"golang.org/x/net/idna"
)
// maxBodySize is the maximum size of body that we will read.
const maxBodySize = 1024 * 1024
// Resource represents a CA issued certificate.
// PrivateKey, Certificate and IssuerCertificate are all
// already PEM encoded and can be directly written to disk.
// Certificate may be a certificate bundle,
// depending on the options supplied to create it.
type Resource struct {
Domain string `json:"domain"`
CertURL string `json:"certUrl"`
CertStableURL string `json:"certStableUrl"`
PrivateKey []byte `json:"-"`
Certificate []byte `json:"-"`
IssuerCertificate []byte `json:"-"`
CSR []byte `json:"-"`
}
// ObtainRequest The request to obtain certificate.
//
// The first domain in domains is used for the CommonName field of the certificate,
// all other domains are added using the Subject Alternate Names extension.
//
// A new private key is generated for every invocation of the function Obtain.
// If you do not want that you can supply your own private key in the privateKey parameter.
// If this parameter is non-nil it will be used instead of generating a new one.
//
// If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
type ObtainRequest struct {
Domains []string
Bundle bool
PrivateKey crypto.PrivateKey
MustStaple bool
}
type resolver interface {
Solve(authorizations []acme.Authorization) error
}
type Certifier struct {
core *api.Core
keyType certcrypto.KeyType
resolver resolver
}
func NewCertifier(core *api.Core, keyType certcrypto.KeyType, resolver resolver) *Certifier {
return &Certifier{
core: core,
keyType: keyType,
resolver: resolver,
}
}
// Obtain tries to obtain a single certificate using all domains passed into it.
//
// This function will never return a partial certificate.
// If one domain in the list fails, the whole certificate will fail.
func (c *Certifier) Obtain(request ObtainRequest) (*Resource, error) {
if len(request.Domains) == 0 {
return nil, errors.New("no domains to obtain a certificate for")
}
domains := sanitizeDomain(request.Domains)
if request.Bundle {
log.Infof("[%s] acme: Obtaining bundled SAN certificate", strings.Join(domains, ", "))
} else {
log.Infof("[%s] acme: Obtaining SAN certificate", strings.Join(domains, ", "))
}
order, err := c.core.Orders.New(domains)
if err != nil {
return nil, err
}
authz, err := c.getAuthorizations(order)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
c.deactivateAuthorizations(order)
return nil, err
}
err = c.resolver.Solve(authz)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
return nil, err
}
log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", "))
failures := make(obtainError)
cert, err := c.getForOrder(domains, order, request.Bundle, request.PrivateKey, request.MustStaple)
if err != nil {
for _, auth := range authz {
failures[challenge.GetTargetedDomain(auth)] = err
}
}
// Do not return an empty failures map, because
// it would still be a non-nil error value
if len(failures) > 0 {
return cert, failures
}
return cert, nil
}
// ObtainForCSR tries to obtain a certificate matching the CSR passed into it.
//
// The domains are inferred from the CommonName and SubjectAltNames, if any.
// The private key for this CSR is not required.
//
// If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
//
// This function will never return a partial certificate.
// If one domain in the list fails, the whole certificate will fail.
func (c *Certifier) ObtainForCSR(csr x509.CertificateRequest, bundle bool) (*Resource, error) {
// figure out what domains it concerns
// start with the common name
domains := certcrypto.ExtractDomainsCSR(&csr)
if bundle {
log.Infof("[%s] acme: Obtaining bundled SAN certificate given a CSR", strings.Join(domains, ", "))
} else {
log.Infof("[%s] acme: Obtaining SAN certificate given a CSR", strings.Join(domains, ", "))
}
order, err := c.core.Orders.New(domains)
if err != nil {
return nil, err
}
authz, err := c.getAuthorizations(order)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
c.deactivateAuthorizations(order)
return nil, err
}
err = c.resolver.Solve(authz)
if err != nil {
// If any challenge fails, return. Do not generate partial SAN certificates.
return nil, err
}
log.Infof("[%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", "))
failures := make(obtainError)
cert, err := c.getForCSR(domains, order, bundle, csr.Raw, nil)
if err != nil {
for _, auth := range authz {
failures[challenge.GetTargetedDomain(auth)] = err
}
}
if cert != nil {
// Add the CSR to the certificate so that it can be used for renewals.
cert.CSR = certcrypto.PEMEncode(&csr)
}
// Do not return an empty failures map,
// because it would still be a non-nil error value
if len(failures) > 0 {
return cert, failures
}
return cert, nil
}
func (c *Certifier) getForOrder(domains []string, order acme.ExtendedOrder, bundle bool, privateKey crypto.PrivateKey, mustStaple bool) (*Resource, error) {
if privateKey == nil {
var err error
privateKey, err = certcrypto.GeneratePrivateKey(c.keyType)
if err != nil {
return nil, err
}
}
// Determine certificate name(s) based on the authorization resources
commonName := domains[0]
// ACME draft Section 7.4 "Applying for Certificate Issuance"
// https://tools.ietf.org/html/draft-ietf-acme-acme-12#section-7.4
// says:
// Clients SHOULD NOT make any assumptions about the sort order of
// "identifiers" or "authorizations" elements in the returned order
// object.
san := []string{commonName}
for _, auth := range order.Identifiers {
if auth.Value != commonName {
san = append(san, auth.Value)
}
}
// TODO: should the CSR be customizable?
csr, err := certcrypto.GenerateCSR(privateKey, commonName, san, mustStaple)
if err != nil {
return nil, err
}
return c.getForCSR(domains, order, bundle, csr, certcrypto.PEMEncode(privateKey))
}
func (c *Certifier) getForCSR(domains []string, order acme.ExtendedOrder, bundle bool, csr []byte, privateKeyPem []byte) (*Resource, error) {
respOrder, err := c.core.Orders.UpdateForCSR(order.Finalize, csr)
if err != nil {
return nil, err
}
commonName := domains[0]
certRes := &Resource{
Domain: commonName,
CertURL: respOrder.Certificate,
PrivateKey: privateKeyPem,
}
if respOrder.Status == acme.StatusValid {
// if the certificate is available right away, short cut!
ok, err := c.checkResponse(respOrder, certRes, bundle)
if err != nil {
return nil, err
}
if ok {
return certRes, nil
}
}
return c.waitForCertificate(certRes, order.Location, bundle)
}
func (c *Certifier) waitForCertificate(certRes *Resource, orderURL string, bundle bool) (*Resource, error) {
stopTimer := time.NewTimer(30 * time.Second)
defer stopTimer.Stop()
retryTick := time.NewTicker(500 * time.Millisecond)
defer retryTick.Stop()
for {
select {
case <-stopTimer.C:
return nil, errors.New("certificate polling timed out")
case <-retryTick.C:
order, err := c.core.Orders.Get(orderURL)
if err != nil {
return nil, err
}
done, err := c.checkResponse(order, certRes, bundle)
if err != nil {
return nil, err
}
if done {
return certRes, nil
}
}
}
}
// checkResponse checks to see if the certificate is ready and a link is contained in the response.
//
// If so, loads it into certRes and returns true.
// If the cert is not yet ready, it returns false.
//
// The certRes input should already have the Domain (common name) field populated.
//
// If bundle is true, the certificate will be bundled with the issuer's cert.
func (c *Certifier) checkResponse(order acme.Order, certRes *Resource, bundle bool) (bool, error) {
valid, err := checkOrderStatus(order)
if err != nil || !valid {
return valid, err
}
cert, issuer, err := c.core.Certificates.Get(order.Certificate, bundle)
if err != nil {
return false, err
}
log.Infof("[%s] Server responded with a certificate.", certRes.Domain)
certRes.IssuerCertificate = issuer
certRes.Certificate = cert
certRes.CertURL = order.Certificate
certRes.CertStableURL = order.Certificate
return true, nil
}
// Revoke takes a PEM encoded certificate or bundle and tries to revoke it at the CA.
func (c *Certifier) Revoke(cert []byte) error {
certificates, err := certcrypto.ParsePEMBundle(cert)
if err != nil {
return err
}
x509Cert := certificates[0]
if x509Cert.IsCA {
return fmt.Errorf("certificate bundle starts with a CA certificate")
}
revokeMsg := acme.RevokeCertMessage{
Certificate: base64.RawURLEncoding.EncodeToString(x509Cert.Raw),
}
return c.core.Certificates.Revoke(revokeMsg)
}
// Renew takes a Resource and tries to renew the certificate.
//
// If the renewal process succeeds, the new certificate will ge returned in a new CertResource.
// Please be aware that this function will return a new certificate in ANY case that is not an error.
// If the server does not provide us with a new cert on a GET request to the CertURL
// this function will start a new-cert flow where a new certificate gets generated.
//
// If bundle is true, the []byte contains both the issuer certificate and your issued certificate as a bundle.
//
// For private key reuse the PrivateKey property of the passed in Resource should be non-nil.
func (c *Certifier) Renew(certRes Resource, bundle, mustStaple bool) (*Resource, error) {
// Input certificate is PEM encoded.
// Decode it here as we may need the decoded cert later on in the renewal process.
// The input may be a bundle or a single certificate.
certificates, err := certcrypto.ParsePEMBundle(certRes.Certificate)
if err != nil {
return nil, err
}
x509Cert := certificates[0]
if x509Cert.IsCA {
return nil, fmt.Errorf("[%s] Certificate bundle starts with a CA certificate", certRes.Domain)
}
// This is just meant to be informal for the user.
timeLeft := x509Cert.NotAfter.Sub(time.Now().UTC())
log.Infof("[%s] acme: Trying renewal with %d hours remaining", certRes.Domain, int(timeLeft.Hours()))
// We always need to request a new certificate to renew.
// Start by checking to see if the certificate was based off a CSR,
// and use that if it's defined.
if len(certRes.CSR) > 0 {
csr, errP := certcrypto.PemDecodeTox509CSR(certRes.CSR)
if errP != nil {
return nil, errP
}
return c.ObtainForCSR(*csr, bundle)
}
var privateKey crypto.PrivateKey
if certRes.PrivateKey != nil {
privateKey, err = certcrypto.ParsePEMPrivateKey(certRes.PrivateKey)
if err != nil {
return nil, err
}
}
query := ObtainRequest{
Domains: certcrypto.ExtractDomains(x509Cert),
Bundle: bundle,
PrivateKey: privateKey,
MustStaple: mustStaple,
}
return c.Obtain(query)
}
// GetOCSP takes a PEM encoded cert or cert bundle returning the raw OCSP response,
// the parsed response, and an error, if any.
//
// The returned []byte can be passed directly into the OCSPStaple property of a tls.Certificate.
// If the bundle only contains the issued certificate,
// this function will try to get the issuer certificate from the IssuingCertificateURL in the certificate.
//
// If the []byte and/or ocsp.Response return values are nil, the OCSP status may be assumed OCSPUnknown.
func (c *Certifier) GetOCSP(bundle []byte) ([]byte, *ocsp.Response, error) {
certificates, err := certcrypto.ParsePEMBundle(bundle)
if err != nil {
return nil, nil, err
}
// We expect the certificate slice to be ordered downwards the chain.
// SRV CRT -> CA. We need to pull the leaf and issuer certs out of it,
// which should always be the first two certificates.
// If there's no OCSP server listed in the leaf cert, there's nothing to do.
// And if we have only one certificate so far, we need to get the issuer cert.
issuedCert := certificates[0]
if len(issuedCert.OCSPServer) == 0 {
return nil, nil, errors.New("no OCSP server specified in cert")
}
if len(certificates) == 1 {
// TODO: build fallback. If this fails, check the remaining array entries.
if len(issuedCert.IssuingCertificateURL) == 0 {
return nil, nil, errors.New("no issuing certificate URL")
}
resp, errC := c.core.HTTPClient.Get(issuedCert.IssuingCertificateURL[0])
if errC != nil {
return nil, nil, errC
}
defer resp.Body.Close()
issuerBytes, errC := ioutil.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize))
if errC != nil {
return nil, nil, errC
}
issuerCert, errC := x509.ParseCertificate(issuerBytes)
if errC != nil {
return nil, nil, errC
}
// Insert it into the slice on position 0
// We want it ordered right SRV CRT -> CA
certificates = append(certificates, issuerCert)
}
issuerCert := certificates[1]
// Finally kick off the OCSP request.
ocspReq, err := ocsp.CreateRequest(issuedCert, issuerCert, nil)
if err != nil {
return nil, nil, err
}
resp, err := c.core.HTTPClient.Post(issuedCert.OCSPServer[0], "application/ocsp-request", bytes.NewReader(ocspReq))
if err != nil {
return nil, nil, err
}
defer resp.Body.Close()
ocspResBytes, err := ioutil.ReadAll(http.MaxBytesReader(nil, resp.Body, maxBodySize))
if err != nil {
return nil, nil, err
}
ocspRes, err := ocsp.ParseResponse(ocspResBytes, issuerCert)
if err != nil {
return nil, nil, err
}
return ocspResBytes, ocspRes, nil
}
func checkOrderStatus(order acme.Order) (bool, error) {
switch order.Status {
case acme.StatusValid:
return true, nil
case acme.StatusInvalid:
return false, order.Error
default:
return false, nil
}
}
// https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-7.1.4
// The domain name MUST be encoded
// in the form in which it would appear in a certificate. That is, it
// MUST be encoded according to the rules in Section 7 of [RFC5280].
//
// https://tools.ietf.org/html/rfc5280#section-7
func sanitizeDomain(domains []string) []string {
var sanitizedDomains []string
for _, domain := range domains {
sanitizedDomain, err := idna.ToASCII(domain)
if err != nil {
log.Infof("skip domain %q: unable to sanitize (punnycode): %v", domain, err)
} else {
sanitizedDomains = append(sanitizedDomains, sanitizedDomain)
}
}
return sanitizedDomains
}

30
vendor/github.com/xenolf/lego/certificate/errors.go generated vendored Normal file
View file

@ -0,0 +1,30 @@
package certificate
import (
"bytes"
"fmt"
"sort"
)
// obtainError is returned when there are specific errors available per domain.
type obtainError map[string]error
func (e obtainError) Error() string {
buffer := bytes.NewBufferString("acme: Error -> One or more domains had a problem:\n")
var domains []string
for domain := range e {
domains = append(domains, domain)
}
sort.Strings(domains)
for _, domain := range domains {
buffer.WriteString(fmt.Sprintf("[%s] %s\n", domain, e[domain]))
}
return buffer.String()
}
type domainError struct {
Domain string
Error error
}

44
vendor/github.com/xenolf/lego/challenge/challenges.go generated vendored Normal file
View file

@ -0,0 +1,44 @@
package challenge
import (
"fmt"
"github.com/xenolf/lego/acme"
)
// Type is a string that identifies a particular challenge type and version of ACME challenge.
type Type string
const (
// HTTP01 is the "http-01" ACME challenge https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-8.3
// Note: ChallengePath returns the URL path to fulfill this challenge
HTTP01 = Type("http-01")
// DNS01 is the "dns-01" ACME challenge https://tools.ietf.org/html/draft-ietf-acme-acme-16#section-8.4
// Note: GetRecord returns a DNS record which will fulfill this challenge
DNS01 = Type("dns-01")
// TLSALPN01 is the "tls-alpn-01" ACME challenge https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-05
TLSALPN01 = Type("tls-alpn-01")
)
func (t Type) String() string {
return string(t)
}
func FindChallenge(chlgType Type, authz acme.Authorization) (acme.Challenge, error) {
for _, chlg := range authz.Challenges {
if chlg.Type == string(chlgType) {
return chlg, nil
}
}
return acme.Challenge{}, fmt.Errorf("[%s] acme: unable to find challenge %s", GetTargetedDomain(authz), chlgType)
}
func GetTargetedDomain(authz acme.Authorization) string {
if authz.Wildcard {
return "*." + authz.Identifier.Value
}
return authz.Identifier.Value
}

View file

@ -0,0 +1,176 @@
package dns01
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"time"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acme/api"
"github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/log"
"github.com/xenolf/lego/platform/wait"
)
const (
// DefaultPropagationTimeout default propagation timeout
DefaultPropagationTimeout = 60 * time.Second
// DefaultPollingInterval default polling interval
DefaultPollingInterval = 2 * time.Second
// DefaultTTL default TTL
DefaultTTL = 120
)
type ValidateFunc func(core *api.Core, domain string, chlng acme.Challenge) error
type ChallengeOption func(*Challenge) error
// CondOption Conditional challenge option.
func CondOption(condition bool, opt ChallengeOption) ChallengeOption {
if !condition {
// NoOp options
return func(*Challenge) error {
return nil
}
}
return opt
}
// Challenge implements the dns-01 challenge
type Challenge struct {
core *api.Core
validate ValidateFunc
provider challenge.Provider
preCheck preCheck
dnsTimeout time.Duration
}
func NewChallenge(core *api.Core, validate ValidateFunc, provider challenge.Provider, opts ...ChallengeOption) *Challenge {
chlg := &Challenge{
core: core,
validate: validate,
provider: provider,
preCheck: newPreCheck(),
dnsTimeout: 10 * time.Second,
}
for _, opt := range opts {
err := opt(chlg)
if err != nil {
log.Infof("challenge option error: %v", err)
}
}
return chlg
}
// PreSolve just submits the txt record to the dns provider.
// It does not validate record propagation, or do anything at all with the acme server.
func (c *Challenge) PreSolve(authz acme.Authorization) error {
domain := challenge.GetTargetedDomain(authz)
log.Infof("[%s] acme: Preparing to solve DNS-01", domain)
chlng, err := challenge.FindChallenge(challenge.DNS01, authz)
if err != nil {
return err
}
if c.provider == nil {
return fmt.Errorf("[%s] acme: no DNS Provider configured", domain)
}
// Generate the Key Authorization for the challenge
keyAuth, err := c.core.GetKeyAuthorization(chlng.Token)
if err != nil {
return err
}
err = c.provider.Present(authz.Identifier.Value, chlng.Token, keyAuth)
if err != nil {
return fmt.Errorf("[%s] acme: error presenting token: %s", domain, err)
}
return nil
}
func (c *Challenge) Solve(authz acme.Authorization) error {
domain := challenge.GetTargetedDomain(authz)
log.Infof("[%s] acme: Trying to solve DNS-01", domain)
chlng, err := challenge.FindChallenge(challenge.DNS01, authz)
if err != nil {
return err
}
// Generate the Key Authorization for the challenge
keyAuth, err := c.core.GetKeyAuthorization(chlng.Token)
if err != nil {
return err
}
fqdn, value := GetRecord(authz.Identifier.Value, keyAuth)
var timeout, interval time.Duration
switch provider := c.provider.(type) {
case challenge.ProviderTimeout:
timeout, interval = provider.Timeout()
default:
timeout, interval = DefaultPropagationTimeout, DefaultPollingInterval
}
log.Infof("[%s] acme: Checking DNS record propagation using %+v", domain, recursiveNameservers)
err = wait.For("propagation", timeout, interval, func() (bool, error) {
stop, errP := c.preCheck.call(fqdn, value)
if !stop || errP != nil {
log.Infof("[%s] acme: Waiting for DNS record propagation.", domain)
}
return stop, errP
})
if err != nil {
return err
}
chlng.KeyAuthorization = keyAuth
return c.validate(c.core, authz.Identifier.Value, chlng)
}
// CleanUp cleans the challenge.
func (c *Challenge) CleanUp(authz acme.Authorization) error {
log.Infof("[%s] acme: Cleaning DNS-01 challenge", challenge.GetTargetedDomain(authz))
chlng, err := challenge.FindChallenge(challenge.DNS01, authz)
if err != nil {
return err
}
keyAuth, err := c.core.GetKeyAuthorization(chlng.Token)
if err != nil {
return err
}
return c.provider.CleanUp(authz.Identifier.Value, chlng.Token, keyAuth)
}
func (c *Challenge) Sequential() (bool, time.Duration) {
if p, ok := c.provider.(sequential); ok {
return ok, p.Sequential()
}
return false, 0
}
type sequential interface {
Sequential() time.Duration
}
// GetRecord returns a DNS record which will fulfill the `dns-01` challenge
func GetRecord(domain, keyAuth string) (fqdn string, value string) {
keyAuthShaBytes := sha256.Sum256([]byte(keyAuth))
// base64URL encoding without padding
value = base64.RawURLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size])
fqdn = fmt.Sprintf("_acme-challenge.%s.", domain)
return
}

View file

@ -0,0 +1,52 @@
package dns01
import (
"bufio"
"fmt"
"os"
)
const (
dnsTemplate = `%s %d IN TXT "%s"`
)
// DNSProviderManual is an implementation of the ChallengeProvider interface
type DNSProviderManual struct{}
// NewDNSProviderManual returns a DNSProviderManual instance.
func NewDNSProviderManual() (*DNSProviderManual, error) {
return &DNSProviderManual{}, nil
}
// Present prints instructions for manually creating the TXT record
func (*DNSProviderManual) Present(domain, token, keyAuth string) error {
fqdn, value := GetRecord(domain, keyAuth)
authZone, err := FindZoneByFqdn(fqdn)
if err != nil {
return err
}
fmt.Printf("lego: Please create the following TXT record in your %s zone:\n", authZone)
fmt.Printf(dnsTemplate+"\n", fqdn, DefaultTTL, value)
fmt.Printf("lego: Press 'Enter' when you are done\n")
_, err = bufio.NewReader(os.Stdin).ReadBytes('\n')
return err
}
// CleanUp prints instructions for manually removing the TXT record
func (*DNSProviderManual) CleanUp(domain, token, keyAuth string) error {
fqdn, _ := GetRecord(domain, keyAuth)
authZone, err := FindZoneByFqdn(fqdn)
if err != nil {
return err
}
fmt.Printf("lego: You can now remove this TXT record from your %s zone:\n", authZone)
fmt.Printf(dnsTemplate+"\n", fqdn, DefaultTTL, "...")
return nil
}

19
vendor/github.com/xenolf/lego/challenge/dns01/fqdn.go generated vendored Normal file
View file

@ -0,0 +1,19 @@
package dns01
// ToFqdn converts the name into a fqdn appending a trailing dot.
func ToFqdn(name string) string {
n := len(name)
if n == 0 || name[n-1] == '.' {
return name
}
return name + "."
}
// UnFqdn converts the fqdn into a name removing the trailing dot.
func UnFqdn(name string) string {
n := len(name)
if n != 0 && name[n-1] == '.' {
return name[:n-1]
}
return name
}

View file

@ -0,0 +1,232 @@
package dns01
import (
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/miekg/dns"
)
const defaultResolvConf = "/etc/resolv.conf"
// dnsTimeout is used to override the default DNS timeout of 10 seconds.
var dnsTimeout = 10 * time.Second
var (
fqdnToZone = map[string]string{}
muFqdnToZone sync.Mutex
)
var defaultNameservers = []string{
"google-public-dns-a.google.com:53",
"google-public-dns-b.google.com:53",
}
// recursiveNameservers are used to pre-check DNS propagation
var recursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
// ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing.
func ClearFqdnCache() {
muFqdnToZone.Lock()
fqdnToZone = map[string]string{}
muFqdnToZone.Unlock()
}
func AddDNSTimeout(timeout time.Duration) ChallengeOption {
return func(_ *Challenge) error {
dnsTimeout = timeout
return nil
}
}
func AddRecursiveNameservers(nameservers []string) ChallengeOption {
return func(_ *Challenge) error {
recursiveNameservers = ParseNameservers(nameservers)
return nil
}
}
// getNameservers attempts to get systems nameservers before falling back to the defaults
func getNameservers(path string, defaults []string) []string {
config, err := dns.ClientConfigFromFile(path)
if err != nil || len(config.Servers) == 0 {
return defaults
}
return ParseNameservers(config.Servers)
}
func ParseNameservers(servers []string) []string {
var resolvers []string
for _, resolver := range servers {
// ensure all servers have a port number
if _, _, err := net.SplitHostPort(resolver); err != nil {
resolvers = append(resolvers, net.JoinHostPort(resolver, "53"))
} else {
resolvers = append(resolvers, resolver)
}
}
return resolvers
}
// lookupNameservers returns the authoritative nameservers for the given fqdn.
func lookupNameservers(fqdn string) ([]string, error) {
var authoritativeNss []string
zone, err := FindZoneByFqdn(fqdn)
if err != nil {
return nil, fmt.Errorf("could not determine the zone: %v", err)
}
r, err := dnsQuery(zone, dns.TypeNS, recursiveNameservers, true)
if err != nil {
return nil, err
}
for _, rr := range r.Answer {
if ns, ok := rr.(*dns.NS); ok {
authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
}
}
if len(authoritativeNss) > 0 {
return authoritativeNss, nil
}
return nil, fmt.Errorf("could not determine authoritative nameservers")
}
// FindZoneByFqdn determines the zone apex for the given fqdn
// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
func FindZoneByFqdn(fqdn string) (string, error) {
return FindZoneByFqdnCustom(fqdn, recursiveNameservers)
}
// FindZoneByFqdnCustom determines the zone apex for the given fqdn
// by recursing up the domain labels until the nameserver returns a SOA record in the answer section.
func FindZoneByFqdnCustom(fqdn string, nameservers []string) (string, error) {
muFqdnToZone.Lock()
defer muFqdnToZone.Unlock()
// Do we have it cached?
if zone, ok := fqdnToZone[fqdn]; ok {
return zone, nil
}
var err error
var in *dns.Msg
labelIndexes := dns.Split(fqdn)
for _, index := range labelIndexes {
domain := fqdn[index:]
in, err = dnsQuery(domain, dns.TypeSOA, nameservers, true)
if err != nil {
continue
}
if in == nil {
continue
}
switch in.Rcode {
case dns.RcodeSuccess:
// Check if we got a SOA RR in the answer section
if len(in.Answer) == 0 {
continue
}
// CNAME records cannot/should not exist at the root of a zone.
// So we skip a domain when a CNAME is found.
if dnsMsgContainsCNAME(in) {
continue
}
for _, ans := range in.Answer {
if soa, ok := ans.(*dns.SOA); ok {
zone := soa.Hdr.Name
fqdnToZone[fqdn] = zone
return zone, nil
}
}
case dns.RcodeNameError:
// NXDOMAIN
default:
// Any response code other than NOERROR and NXDOMAIN is treated as error
return "", fmt.Errorf("unexpected response code '%s' for %s", dns.RcodeToString[in.Rcode], domain)
}
}
return "", fmt.Errorf("could not find the start of authority for %s%s", fqdn, formatDNSError(in, err))
}
// dnsMsgContainsCNAME checks for a CNAME answer in msg
func dnsMsgContainsCNAME(msg *dns.Msg) bool {
for _, ans := range msg.Answer {
if _, ok := ans.(*dns.CNAME); ok {
return true
}
}
return false
}
func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (*dns.Msg, error) {
m := createDNSMsg(fqdn, rtype, recursive)
var in *dns.Msg
var err error
for _, ns := range nameservers {
in, err = sendDNSQuery(m, ns)
if err == nil && len(in.Answer) > 0 {
break
}
}
return in, err
}
func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {
m := new(dns.Msg)
m.SetQuestion(fqdn, rtype)
m.SetEdns0(4096, false)
if !recursive {
m.RecursionDesired = false
}
return m
}
func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
in, _, err := udp.Exchange(m, ns)
if in != nil && in.Truncated {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
// If the TCP request succeeds, the err will reset to nil
in, _, err = tcp.Exchange(m, ns)
}
return in, err
}
func formatDNSError(msg *dns.Msg, err error) string {
var parts []string
if msg != nil {
parts = append(parts, dns.RcodeToString[msg.Rcode])
}
if err != nil {
parts = append(parts, fmt.Sprintf("%v", err))
}
if len(parts) > 0 {
return ": " + strings.Join(parts, " ")
}
return ""
}

View file

@ -0,0 +1,114 @@
package dns01
import (
"fmt"
"net"
"strings"
"github.com/miekg/dns"
)
// PreCheckFunc checks DNS propagation before notifying ACME that the DNS challenge is ready.
type PreCheckFunc func(fqdn, value string) (bool, error)
func AddPreCheck(preCheck PreCheckFunc) ChallengeOption {
// Prevent race condition
check := preCheck
return func(chlg *Challenge) error {
chlg.preCheck.checkFunc = check
return nil
}
}
func DisableCompletePropagationRequirement() ChallengeOption {
return func(chlg *Challenge) error {
chlg.preCheck.requireCompletePropagation = false
return nil
}
}
type preCheck struct {
// checks DNS propagation before notifying ACME that the DNS challenge is ready.
checkFunc PreCheckFunc
// require the TXT record to be propagated to all authoritative name servers
requireCompletePropagation bool
}
func newPreCheck() preCheck {
return preCheck{
requireCompletePropagation: true,
}
}
func (p preCheck) call(fqdn, value string) (bool, error) {
if p.checkFunc == nil {
return p.checkDNSPropagation(fqdn, value)
}
return p.checkFunc(fqdn, value)
}
// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
func (p preCheck) checkDNSPropagation(fqdn, value string) (bool, error) {
// Initial attempt to resolve at the recursive NS
r, err := dnsQuery(fqdn, dns.TypeTXT, recursiveNameservers, true)
if err != nil {
return false, err
}
if !p.requireCompletePropagation {
return true, nil
}
if r.Rcode == dns.RcodeSuccess {
// If we see a CNAME here then use the alias
for _, rr := range r.Answer {
if cn, ok := rr.(*dns.CNAME); ok {
if cn.Hdr.Name == fqdn {
fqdn = cn.Target
break
}
}
}
}
authoritativeNss, err := lookupNameservers(fqdn)
if err != nil {
return false, err
}
return checkAuthoritativeNss(fqdn, value, authoritativeNss)
}
// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
for _, ns := range nameservers {
r, err := dnsQuery(fqdn, dns.TypeTXT, []string{net.JoinHostPort(ns, "53")}, false)
if err != nil {
return false, err
}
if r.Rcode != dns.RcodeSuccess {
return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
}
var records []string
var found bool
for _, rr := range r.Answer {
if txt, ok := rr.(*dns.TXT); ok {
record := strings.Join(txt.Txt, "")
records = append(records, record)
if record == value {
found = true
break
}
}
}
if !found {
return false, fmt.Errorf("NS %s did not return the expected TXT record [fqdn: %s, value: %s]: %s", ns, fqdn, value, strings.Join(records, " ,"))
}
}
return true, nil
}

View file

@ -0,0 +1,65 @@
package http01
import (
"fmt"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acme/api"
"github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/log"
)
type ValidateFunc func(core *api.Core, domain string, chlng acme.Challenge) error
// ChallengePath returns the URL path for the `http-01` challenge
func ChallengePath(token string) string {
return "/.well-known/acme-challenge/" + token
}
type Challenge struct {
core *api.Core
validate ValidateFunc
provider challenge.Provider
}
func NewChallenge(core *api.Core, validate ValidateFunc, provider challenge.Provider) *Challenge {
return &Challenge{
core: core,
validate: validate,
provider: provider,
}
}
func (c *Challenge) SetProvider(provider challenge.Provider) {
c.provider = provider
}
func (c *Challenge) Solve(authz acme.Authorization) error {
domain := challenge.GetTargetedDomain(authz)
log.Infof("[%s] acme: Trying to solve HTTP-01", domain)
chlng, err := challenge.FindChallenge(challenge.HTTP01, authz)
if err != nil {
return err
}
// Generate the Key Authorization for the challenge
keyAuth, err := c.core.GetKeyAuthorization(chlng.Token)
if err != nil {
return err
}
err = c.provider.Present(authz.Identifier.Value, chlng.Token, keyAuth)
if err != nil {
return fmt.Errorf("[%s] acme: error presenting token: %v", domain, err)
}
defer func() {
err := c.provider.CleanUp(authz.Identifier.Value, chlng.Token, keyAuth)
if err != nil {
log.Warnf("[%s] acme: error cleaning up: %v", domain, err)
}
}()
chlng.KeyAuthorization = keyAuth
return c.validate(c.core, authz.Identifier.Value, chlng)
}

View file

@ -1,4 +1,4 @@
package acme package http01
import ( import (
"fmt" "fmt"
@ -9,31 +9,31 @@ import (
"github.com/xenolf/lego/log" "github.com/xenolf/lego/log"
) )
// HTTPProviderServer implements ChallengeProvider for `http-01` challenge // ProviderServer implements ChallengeProvider for `http-01` challenge
// It may be instantiated without using the NewHTTPProviderServer function if // It may be instantiated without using the NewProviderServer function if
// you want only to use the default values. // you want only to use the default values.
type HTTPProviderServer struct { type ProviderServer struct {
iface string iface string
port string port string
done chan bool done chan bool
listener net.Listener listener net.Listener
} }
// NewHTTPProviderServer creates a new HTTPProviderServer on the selected interface and port. // NewProviderServer creates a new ProviderServer on the selected interface and port.
// Setting iface and / or port to an empty string will make the server fall back to // Setting iface and / or port to an empty string will make the server fall back to
// the "any" interface and port 80 respectively. // the "any" interface and port 80 respectively.
func NewHTTPProviderServer(iface, port string) *HTTPProviderServer { func NewProviderServer(iface, port string) *ProviderServer {
return &HTTPProviderServer{iface: iface, port: port} return &ProviderServer{iface: iface, port: port}
} }
// Present starts a web server and makes the token available at `HTTP01ChallengePath(token)` for web requests. // Present starts a web server and makes the token available at `ChallengePath(token)` for web requests.
func (s *HTTPProviderServer) Present(domain, token, keyAuth string) error { func (s *ProviderServer) Present(domain, token, keyAuth string) error {
if s.port == "" { if s.port == "" {
s.port = "80" s.port = "80"
} }
var err error var err error
s.listener, err = net.Listen("tcp", net.JoinHostPort(s.iface, s.port)) s.listener, err = net.Listen("tcp", s.GetAddress())
if err != nil { if err != nil {
return fmt.Errorf("could not start HTTP server for challenge -> %v", err) return fmt.Errorf("could not start HTTP server for challenge -> %v", err)
} }
@ -43,8 +43,12 @@ func (s *HTTPProviderServer) Present(domain, token, keyAuth string) error {
return nil return nil
} }
// CleanUp closes the HTTP server and removes the token from `HTTP01ChallengePath(token)` func (s *ProviderServer) GetAddress() string {
func (s *HTTPProviderServer) CleanUp(domain, token, keyAuth string) error { return net.JoinHostPort(s.iface, s.port)
}
// CleanUp closes the HTTP server and removes the token from `ChallengePath(token)`
func (s *ProviderServer) CleanUp(domain, token, keyAuth string) error {
if s.listener == nil { if s.listener == nil {
return nil return nil
} }
@ -53,8 +57,8 @@ func (s *HTTPProviderServer) CleanUp(domain, token, keyAuth string) error {
return nil return nil
} }
func (s *HTTPProviderServer) serve(domain, token, keyAuth string) { func (s *ProviderServer) serve(domain, token, keyAuth string) {
path := HTTP01ChallengePath(token) path := ChallengePath(token)
// The handler validates the HOST header and request type. // The handler validates the HOST header and request type.
// For validation it then writes the token the server returned with the challenge // For validation it then writes the token the server returned with the challenge
@ -80,12 +84,12 @@ func (s *HTTPProviderServer) serve(domain, token, keyAuth string) {
httpServer := &http.Server{Handler: mux} httpServer := &http.Server{Handler: mux}
// Once httpServer is shut down we don't want any lingering // Once httpServer is shut down
// connections, so disable KeepAlives. // we don't want any lingering connections, so disable KeepAlives.
httpServer.SetKeepAlivesEnabled(false) httpServer.SetKeepAlivesEnabled(false)
err := httpServer.Serve(s.listener) err := httpServer.Serve(s.listener)
if err != nil { if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
log.Println(err) log.Println(err)
} }
s.done <- true s.done <- true

View file

@ -1,28 +1,28 @@
package acme package challenge
import "time" import "time"
// ChallengeProvider enables implementing a custom challenge // Provider enables implementing a custom challenge
// provider. Present presents the solution to a challenge available to // provider. Present presents the solution to a challenge available to
// be solved. CleanUp will be called by the challenge if Present ends // be solved. CleanUp will be called by the challenge if Present ends
// in a non-error state. // in a non-error state.
type ChallengeProvider interface { type Provider interface {
Present(domain, token, keyAuth string) error Present(domain, token, keyAuth string) error
CleanUp(domain, token, keyAuth string) error CleanUp(domain, token, keyAuth string) error
} }
// ChallengeProviderTimeout allows for implementing a // ProviderTimeout allows for implementing a
// ChallengeProvider where an unusually long timeout is required when // Provider where an unusually long timeout is required when
// waiting for an ACME challenge to be satisfied, such as when // waiting for an ACME challenge to be satisfied, such as when
// checking for DNS record progagation. If an implementor of a // checking for DNS record propagation. If an implementor of a
// ChallengeProvider provides a Timeout method, then the return values // Provider provides a Timeout method, then the return values
// of the Timeout method will be used when appropriate by the acme // of the Timeout method will be used when appropriate by the acme
// package. The interval value is the time between checks. // package. The interval value is the time between checks.
// //
// The default values used for timeout and interval are 60 seconds and // The default values used for timeout and interval are 60 seconds and
// 2 seconds respectively. These are used when no Timeout method is // 2 seconds respectively. These are used when no Timeout method is
// defined for the ChallengeProvider. // defined for the Provider.
type ChallengeProviderTimeout interface { type ProviderTimeout interface {
ChallengeProvider Provider
Timeout() (timeout, interval time.Duration) Timeout() (timeout, interval time.Duration)
} }

View file

@ -0,0 +1,25 @@
package resolver
import (
"bytes"
"fmt"
"sort"
)
// obtainError is returned when there are specific errors available per domain.
type obtainError map[string]error
func (e obtainError) Error() string {
buffer := bytes.NewBufferString("acme: Error -> One or more domains had a problem:\n")
var domains []string
for domain := range e {
domains = append(domains, domain)
}
sort.Strings(domains)
for _, domain := range domains {
buffer.WriteString(fmt.Sprintf("[%s] %s\n", domain, e[domain]))
}
return buffer.String()
}

View file

@ -0,0 +1,173 @@
package resolver
import (
"fmt"
"time"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/log"
)
// Interface for all challenge solvers to implement.
type solver interface {
Solve(authorization acme.Authorization) error
}
// Interface for challenges like dns, where we can set a record in advance for ALL challenges.
// This saves quite a bit of time vs creating the records and solving them serially.
type preSolver interface {
PreSolve(authorization acme.Authorization) error
}
// Interface for challenges like dns, where we can solve all the challenges before to delete them.
type cleanup interface {
CleanUp(authorization acme.Authorization) error
}
type sequential interface {
Sequential() (bool, time.Duration)
}
// an authz with the solver we have chosen and the index of the challenge associated with it
type selectedAuthSolver struct {
authz acme.Authorization
solver solver
}
type Prober struct {
solverManager *SolverManager
}
func NewProber(solverManager *SolverManager) *Prober {
return &Prober{
solverManager: solverManager,
}
}
// Solve Looks through the challenge combinations to find a solvable match.
// Then solves the challenges in series and returns.
func (p *Prober) Solve(authorizations []acme.Authorization) error {
failures := make(obtainError)
var authSolvers []*selectedAuthSolver
var authSolversSequential []*selectedAuthSolver
// Loop through the resources, basically through the domains.
// First pass just selects a solver for each authz.
for _, authz := range authorizations {
domain := challenge.GetTargetedDomain(authz)
if authz.Status == acme.StatusValid {
// Boulder might recycle recent validated authz (see issue #267)
log.Infof("[%s] acme: authorization already valid; skipping challenge", domain)
continue
}
if solvr := p.solverManager.chooseSolver(authz); solvr != nil {
authSolver := &selectedAuthSolver{authz: authz, solver: solvr}
switch s := solvr.(type) {
case sequential:
if ok, _ := s.Sequential(); ok {
authSolversSequential = append(authSolversSequential, authSolver)
} else {
authSolvers = append(authSolvers, authSolver)
}
default:
authSolvers = append(authSolvers, authSolver)
}
} else {
failures[domain] = fmt.Errorf("[%s] acme: could not determine solvers", domain)
}
}
parallelSolve(authSolvers, failures)
sequentialSolve(authSolversSequential, failures)
// Be careful not to return an empty failures map,
// for even an empty obtainError is a non-nil error value
if len(failures) > 0 {
return failures
}
return nil
}
func sequentialSolve(authSolvers []*selectedAuthSolver, failures obtainError) {
for i, authSolver := range authSolvers {
// Submit the challenge
domain := challenge.GetTargetedDomain(authSolver.authz)
if solvr, ok := authSolver.solver.(preSolver); ok {
err := solvr.PreSolve(authSolver.authz)
if err != nil {
failures[domain] = err
cleanUp(authSolver.solver, authSolver.authz)
continue
}
}
// Solve challenge
err := authSolver.solver.Solve(authSolver.authz)
if err != nil {
failures[domain] = err
cleanUp(authSolver.solver, authSolver.authz)
continue
}
// Clean challenge
cleanUp(authSolver.solver, authSolver.authz)
if len(authSolvers)-1 > i {
solvr := authSolver.solver.(sequential)
_, interval := solvr.Sequential()
log.Infof("sequence: wait for %s", interval)
time.Sleep(interval)
}
}
}
func parallelSolve(authSolvers []*selectedAuthSolver, failures obtainError) {
// For all valid preSolvers, first submit the challenges so they have max time to propagate
for _, authSolver := range authSolvers {
authz := authSolver.authz
if solvr, ok := authSolver.solver.(preSolver); ok {
err := solvr.PreSolve(authz)
if err != nil {
failures[challenge.GetTargetedDomain(authz)] = err
}
}
}
defer func() {
// Clean all created TXT records
for _, authSolver := range authSolvers {
cleanUp(authSolver.solver, authSolver.authz)
}
}()
// Finally solve all challenges for real
for _, authSolver := range authSolvers {
authz := authSolver.authz
domain := challenge.GetTargetedDomain(authz)
if failures[domain] != nil {
// already failed in previous loop
continue
}
err := authSolver.solver.Solve(authz)
if err != nil {
failures[domain] = err
}
}
}
func cleanUp(solvr solver, authz acme.Authorization) {
if solvr, ok := solvr.(cleanup); ok {
domain := challenge.GetTargetedDomain(authz)
err := solvr.CleanUp(authz)
if err != nil {
log.Warnf("[%s] acme: error cleaning up: %v ", domain, err)
}
}
}

View file

@ -0,0 +1,154 @@
package resolver
import (
"errors"
"fmt"
"sort"
"strconv"
"time"
"github.com/xenolf/lego/acme"
"github.com/xenolf/lego/acme/api"
"github.com/xenolf/lego/challenge"
"github.com/xenolf/lego/challenge/dns01"
"github.com/xenolf/lego/challenge/http01"
"github.com/xenolf/lego/challenge/tlsalpn01"
"github.com/xenolf/lego/log"
)
type byType []acme.Challenge
func (a byType) Len() int { return len(a) }
func (a byType) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a byType) Less(i, j int) bool { return a[i].Type > a[j].Type }
type SolverManager struct {
core *api.Core
solvers map[challenge.Type]solver
}
func NewSolversManager(core *api.Core) *SolverManager {
return &SolverManager{
solvers: map[challenge.Type]solver{},
core: core,
}
}
// SetHTTP01Provider specifies a custom provider p that can solve the given HTTP-01 challenge.
func (c *SolverManager) SetHTTP01Provider(p challenge.Provider) error {
c.solvers[challenge.HTTP01] = http01.NewChallenge(c.core, validate, p)
return nil
}
// SetTLSALPN01Provider specifies a custom provider p that can solve the given TLS-ALPN-01 challenge.
func (c *SolverManager) SetTLSALPN01Provider(p challenge.Provider) error {
c.solvers[challenge.TLSALPN01] = tlsalpn01.NewChallenge(c.core, validate, p)
return nil
}
// SetDNS01Provider specifies a custom provider p that can solve the given DNS-01 challenge.
func (c *SolverManager) SetDNS01Provider(p challenge.Provider, opts ...dns01.ChallengeOption) error {
c.solvers[challenge.DNS01] = dns01.NewChallenge(c.core, validate, p, opts...)
return nil
}
// Remove Remove a challenge type from the available solvers.
func (c *SolverManager) Remove(chlgType challenge.Type) {
delete(c.solvers, chlgType)
}
// Checks all challenges from the server in order and returns the first matching solver.
func (c *SolverManager) chooseSolver(authz acme.Authorization) solver {
// Allow to have a deterministic challenge order
sort.Sort(byType(authz.Challenges))
domain := challenge.GetTargetedDomain(authz)
for _, chlg := range authz.Challenges {
if solvr, ok := c.solvers[challenge.Type(chlg.Type)]; ok {
log.Infof("[%s] acme: use %s solver", domain, chlg.Type)
return solvr
}
log.Infof("[%s] acme: Could not find solver for: %s", domain, chlg.Type)
}
return nil
}
func validate(core *api.Core, domain string, chlg acme.Challenge) error {
chlng, err := core.Challenges.New(chlg.URL)
if err != nil {
return fmt.Errorf("failed to initiate challenge: %v", err)
}
valid, err := checkChallengeStatus(chlng)
if err != nil {
return err
}
if valid {
log.Infof("[%s] The server validated our request", domain)
return nil
}
// After the path is sent, the ACME server will access our server.
// Repeatedly check the server for an updated status on our request.
for {
authz, err := core.Authorizations.Get(chlng.AuthorizationURL)
if err != nil {
return err
}
valid, err := checkAuthorizationStatus(authz)
if err != nil {
return err
}
if valid {
log.Infof("[%s] The server validated our request", domain)
return nil
}
ra, err := strconv.Atoi(chlng.RetryAfter)
if err != nil {
// The ACME server MUST return a Retry-After.
// If it doesn't, we'll just poll hard.
// Boulder does not implement the ability to retry challenges or the Retry-After header.
// https://github.com/letsencrypt/boulder/blob/master/docs/acme-divergences.md#section-82
ra = 5
}
time.Sleep(time.Duration(ra) * time.Second)
}
}
func checkChallengeStatus(chlng acme.ExtendedChallenge) (bool, error) {
switch chlng.Status {
case acme.StatusValid:
return true, nil
case acme.StatusPending, acme.StatusProcessing:
return false, nil
case acme.StatusInvalid:
return false, chlng.Error
default:
return false, errors.New("the server returned an unexpected state")
}
}
func checkAuthorizationStatus(authz acme.Authorization) (bool, error) {
switch authz.Status {
case acme.StatusValid:
return true, nil
case acme.StatusPending, acme.StatusProcessing:
return false, nil
case acme.StatusDeactivated, acme.StatusExpired, acme.StatusRevoked:
return false, fmt.Errorf("the authorization state %s", authz.Status)
case acme.StatusInvalid:
for _, chlg := range authz.Challenges {
if chlg.Status == acme.StatusInvalid && chlg.Error != nil {
return false, chlg.Error
}
}
return false, fmt.Errorf("the authorization state %s", authz.Status)
default:
return false, errors.New("the server returned an unexpected state")
}
}

Some files were not shown because too many files have changed in this diff Show more