From 4d485e1b6b5a2f0390b47182ba87b6b929c426b0 Mon Sep 17 00:00:00 2001 From: Vincent Demeester Date: Fri, 13 Nov 2015 11:50:32 +0100 Subject: [PATCH] Refactor providers and add tests - Add a `baseProvider` struct with common - Refactor docker, kv(s) and marathon providers (spliting into small pieces) - Add unit tests Signed-off-by: Vincent Demeester --- integration/file_test.go | 4 +- provider/boltdb.go | 17 +- provider/consul.go | 17 +- provider/docker.go | 257 ++++++------- provider/docker_test.go | 788 ++++++++++++++++++++++++++++++++++++++ provider/etcd.go | 17 +- provider/file.go | 13 +- provider/kv.go | 175 +++------ provider/kv_test.go | 312 +++++++++++++++ provider/marathon.go | 255 ++++++------ provider/marathon_test.go | 656 +++++++++++++++++++++++++++++++ provider/provider.go | 59 ++- provider/provider_test.go | 170 ++++++++ provider/zk.go | 17 +- 14 files changed, 2319 insertions(+), 438 deletions(-) create mode 100644 provider/docker_test.go create mode 100644 provider/kv_test.go create mode 100644 provider/marathon_test.go create mode 100644 provider/provider_test.go diff --git a/integration/file_test.go b/integration/file_test.go index 3f8dd6236..7fbed2390 100644 --- a/integration/file_test.go +++ b/integration/file_test.go @@ -15,7 +15,7 @@ func (s *FileSuite) TestSimpleConfiguration(c *check.C) { c.Assert(err, checker.IsNil) defer cmd.Process.Kill() - time.Sleep(500 * time.Millisecond) + time.Sleep(1000 * time.Millisecond) resp, err := http.Get("http://127.0.0.1/") // Expected a 404 as we did not configure anything @@ -30,7 +30,7 @@ func (s *FileSuite) TestSimpleConfigurationNoPanic(c *check.C) { c.Assert(err, checker.IsNil) defer cmd.Process.Kill() - time.Sleep(500 * time.Millisecond) + time.Sleep(1000 * time.Millisecond) resp, err := http.Get("http://127.0.0.1/") // Expected a 404 as we did not configure anything diff --git a/provider/boltdb.go b/provider/boltdb.go index ecad47497..9304f8a8f 100644 --- a/provider/boltdb.go +++ b/provider/boltdb.go @@ -1,19 +1,20 @@ package provider -import "github.com/emilevauge/traefik/types" +import ( + "github.com/docker/libkv/store" + "github.com/docker/libkv/store/boltdb" + "github.com/emilevauge/traefik/types" +) // BoltDb holds configurations of the BoltDb provider. type BoltDb struct { - Watch bool - Endpoint string - Prefix string - Filename string - KvProvider *Kv + Kv } // Provide allows the provider to provide configurations to traefik // using the given configuration channel. func (provider *BoltDb) Provide(configurationChan chan<- types.ConfigMessage) error { - provider.KvProvider = NewBoltDbProvider(provider) - return provider.KvProvider.provide(configurationChan) + provider.StoreType = store.BOLTDB + boltdb.Register() + return provider.provide(configurationChan) } diff --git a/provider/consul.go b/provider/consul.go index 7f46d3ba5..fa0be91cf 100644 --- a/provider/consul.go +++ b/provider/consul.go @@ -1,19 +1,20 @@ package provider -import "github.com/emilevauge/traefik/types" +import ( + "github.com/docker/libkv/store" + "github.com/docker/libkv/store/consul" + "github.com/emilevauge/traefik/types" +) // Consul holds configurations of the Consul provider. type Consul struct { - Watch bool - Endpoint string - Prefix string - Filename string - KvProvider *Kv + Kv } // Provide allows the provider to provide configurations to traefik // using the given configuration channel. func (provider *Consul) Provide(configurationChan chan<- types.ConfigMessage) error { - provider.KvProvider = NewConsulProvider(provider) - return provider.KvProvider.provide(configurationChan) + provider.StoreType = store.CONSUL + consul.Register() + return provider.provide(configurationChan) } diff --git a/provider/docker.go b/provider/docker.go index c03278f0f..07e6e2e47 100644 --- a/provider/docker.go +++ b/provider/docker.go @@ -1,7 +1,6 @@ package provider import ( - "bytes" "errors" "fmt" "strconv" @@ -9,20 +8,17 @@ import ( "text/template" "time" - "github.com/BurntSushi/toml" "github.com/BurntSushi/ty/fun" log "github.com/Sirupsen/logrus" "github.com/cenkalti/backoff" - "github.com/emilevauge/traefik/autogen" "github.com/emilevauge/traefik/types" "github.com/fsouza/go-dockerclient" ) // Docker holds configurations of the Docker provider. type Docker struct { - Watch bool + baseProvider Endpoint string - Filename string Domain string } @@ -55,9 +51,12 @@ func (provider *Docker) Provide(configurationChan chan<- types.ConfigMessage) er } if event.Status == "start" || event.Status == "die" { log.Debugf("Docker event receveived %+v", event) - configuration := provider.loadDockerConfig(dockerClient) + configuration := provider.loadDockerConfig(listContainers(dockerClient)) if configuration != nil { - configurationChan <- types.ConfigMessage{"docker", configuration} + configurationChan <- types.ConfigMessage{ + ProviderName: "docker", + Configuration: configuration, + } } } } @@ -72,94 +71,31 @@ func (provider *Docker) Provide(configurationChan chan<- types.ConfigMessage) er }() } - configuration := provider.loadDockerConfig(dockerClient) - configurationChan <- types.ConfigMessage{"docker", configuration} + configuration := provider.loadDockerConfig(listContainers(dockerClient)) + configurationChan <- types.ConfigMessage{ + ProviderName: "docker", + Configuration: configuration, + } return nil } -func (provider *Docker) loadDockerConfig(dockerClient *docker.Client) *types.Configuration { +func (provider *Docker) loadDockerConfig(containersInspected []docker.Container) *types.Configuration { var DockerFuncMap = template.FuncMap{ - "getBackend": func(container docker.Container) string { - if label, err := provider.getLabel(container, "traefik.backend"); err == nil { - return label - } - return provider.getEscapedName(container.Name) - }, - "getPort": func(container docker.Container) string { - if label, err := provider.getLabel(container, "traefik.port"); err == nil { - return label - } - for key := range container.NetworkSettings.Ports { - return key.Port() - } - return "" - }, - "getWeight": func(container docker.Container) string { - if label, err := provider.getLabel(container, "traefik.weight"); err == nil { - return label - } - return "0" - }, - "getDomain": func(container docker.Container) string { - if label, err := provider.getLabel(container, "traefik.domain"); err == nil { - return label - } - return provider.Domain - }, - "getProtocol": func(container docker.Container) string { - if label, err := provider.getLabel(container, "traefik.protocol"); err == nil { - return label - } - return "http" - }, - "getPassHostHeader": func(container docker.Container) string { - if passHostHeader, err := provider.getLabel(container, "traefik.frontend.passHostHeader"); err == nil { - return passHostHeader - } - return "false" - }, - "getFrontendValue": provider.GetFrontendValue, - "getFrontendRule": provider.GetFrontendRule, - "replace": func(s1 string, s2 string, s3 string) string { - return strings.Replace(s3, s1, s2, -1) - }, - } - configuration := new(types.Configuration) - containerList, _ := dockerClient.ListContainers(docker.ListContainersOptions{}) - containersInspected := []docker.Container{} - frontends := map[string][]docker.Container{} - - // get inspect containers - for _, container := range containerList { - containerInspected, _ := dockerClient.InspectContainer(container.ID) - containersInspected = append(containersInspected, *containerInspected) + "getBackend": provider.getBackend, + "getPort": provider.getPort, + "getWeight": provider.getWeight, + "getDomain": provider.getDomain, + "getProtocol": provider.getProtocol, + "getPassHostHeader": provider.getPassHostHeader, + "getFrontendValue": provider.getFrontendValue, + "getFrontendRule": provider.getFrontendRule, + "replace": replace, } // filter containers - filteredContainers := fun.Filter(func(container docker.Container) bool { - if len(container.NetworkSettings.Ports) == 0 { - log.Debugf("Filtering container without port %s", container.Name) - return false - } - _, err := strconv.Atoi(container.Config.Labels["traefik.port"]) - if len(container.NetworkSettings.Ports) > 1 && err != nil { - log.Debugf("Filtering container with more than 1 port and no traefik.port label %s", container.Name) - return false - } - if container.Config.Labels["traefik.enable"] == "false" { - log.Debugf("Filtering disabled container %s", container.Name) - return false - } - - labels, err := provider.getLabels(container, []string{"traefik.frontend.rule", "traefik.frontend.value"}) - if len(labels) != 0 && err != nil { - log.Debugf("Filtering bad labeled container %s", container.Name) - return false - } - - return true - }, containersInspected).([]docker.Container) + filteredContainers := fun.Filter(containerFilter, containersInspected).([]docker.Container) + frontends := map[string][]docker.Container{} for _, container := range filteredContainers { frontends[provider.getFrontendName(container)] = append(frontends[provider.getFrontendName(container)], container) } @@ -173,53 +109,112 @@ func (provider *Docker) loadDockerConfig(dockerClient *docker.Client) *types.Con frontends, provider.Domain, } - tmpl := template.New(provider.Filename).Funcs(DockerFuncMap) - if len(provider.Filename) > 0 { - _, err := tmpl.ParseFiles(provider.Filename) - if err != nil { - log.Error("Error reading file", err) - return nil - } - } else { - buf, err := autogen.Asset("templates/docker.tmpl") - if err != nil { - log.Error("Error reading file", err) - } - _, err = tmpl.Parse(string(buf)) - if err != nil { - log.Error("Error reading file", err) - return nil - } - } - var buffer bytes.Buffer - err := tmpl.Execute(&buffer, templateObjects) + configuration, err := provider.getConfiguration("templates/docker.tmpl", DockerFuncMap, templateObjects) if err != nil { - log.Error("Error with docker template", err) - return nil - } - - if _, err := toml.Decode(buffer.String(), configuration); err != nil { - log.Error("Error creating docker configuration ", err) - return nil + log.Error(err) } return configuration } +func containerFilter(container docker.Container) bool { + if len(container.NetworkSettings.Ports) == 0 { + log.Debugf("Filtering container without port %s", container.Name) + return false + } + _, err := strconv.Atoi(container.Config.Labels["traefik.port"]) + if len(container.NetworkSettings.Ports) > 1 && err != nil { + log.Debugf("Filtering container with more than 1 port and no traefik.port label %s", container.Name) + return false + } + + if container.Config.Labels["traefik.enable"] == "false" { + log.Debugf("Filtering disabled container %s", container.Name) + return false + } + + labels, err := getLabels(container, []string{"traefik.frontend.rule", "traefik.frontend.value"}) + if len(labels) != 0 && err != nil { + log.Debugf("Filtering bad labeled container %s", container.Name) + return false + } + + return true +} + func (provider *Docker) getFrontendName(container docker.Container) string { // Replace '.' with '-' in quoted keys because of this issue https://github.com/BurntSushi/toml/issues/78 - frontendName := fmt.Sprintf("%s-%s", provider.GetFrontendRule(container), provider.GetFrontendValue(container)) + frontendName := fmt.Sprintf("%s-%s", provider.getFrontendRule(container), provider.getFrontendValue(container)) frontendName = strings.Replace(frontendName, "[", "", -1) frontendName = strings.Replace(frontendName, "]", "", -1) return strings.Replace(frontendName, ".", "-", -1) } -func (provider *Docker) getEscapedName(name string) string { - return strings.Replace(name, "/", "", -1) +// GetFrontendValue returns the frontend value for the specified container, using +// it's label. It returns a default one if the label is not present. +func (provider *Docker) getFrontendValue(container docker.Container) string { + if label, err := getLabel(container, "traefik.frontend.value"); err == nil { + return label + } + return getEscapedName(container.Name) + "." + provider.Domain } -func (provider *Docker) getLabel(container docker.Container, label string) (string, error) { +// GetFrontendRule returns the frontend rule for the specified container, using +// it's label. It returns a default one (Host) if the label is not present. +func (provider *Docker) getFrontendRule(container docker.Container) string { + if label, err := getLabel(container, "traefik.frontend.rule"); err == nil { + return label + } + return "Host" +} + +func (provider *Docker) getBackend(container docker.Container) string { + if label, err := getLabel(container, "traefik.backend"); err == nil { + return label + } + return getEscapedName(container.Name) +} + +func (provider *Docker) getPort(container docker.Container) string { + if label, err := getLabel(container, "traefik.port"); err == nil { + return label + } + for key := range container.NetworkSettings.Ports { + return key.Port() + } + return "" +} + +func (provider *Docker) getWeight(container docker.Container) string { + if label, err := getLabel(container, "traefik.weight"); err == nil { + return label + } + return "0" +} + +func (provider *Docker) getDomain(container docker.Container) string { + if label, err := getLabel(container, "traefik.domain"); err == nil { + return label + } + return provider.Domain +} + +func (provider *Docker) getProtocol(container docker.Container) string { + if label, err := getLabel(container, "traefik.protocol"); err == nil { + return label + } + return "http" +} + +func (provider *Docker) getPassHostHeader(container docker.Container) string { + if passHostHeader, err := getLabel(container, "traefik.frontend.passHostHeader"); err == nil { + return passHostHeader + } + return "false" +} + +func getLabel(container docker.Container, label string) (string, error) { for key, value := range container.Config.Labels { if key == label { return value, nil @@ -228,11 +223,11 @@ func (provider *Docker) getLabel(container docker.Container, label string) (stri return "", errors.New("Label not found:" + label) } -func (provider *Docker) getLabels(container docker.Container, labels []string) (map[string]string, error) { +func getLabels(container docker.Container, labels []string) (map[string]string, error) { var globalErr error foundLabels := map[string]string{} for _, label := range labels { - foundLabel, err := provider.getLabel(container, label) + foundLabel, err := getLabel(container, label) // Error out only if one of them is defined. if err != nil { globalErr = errors.New("Label not found: " + label) @@ -244,20 +239,14 @@ func (provider *Docker) getLabels(container docker.Container, labels []string) ( return foundLabels, globalErr } -// GetFrontendValue returns the frontend value for the specified container, using -// it's label. It returns a default one if the label is not present. -func (provider *Docker) GetFrontendValue(container docker.Container) string { - if label, err := provider.getLabel(container, "traefik.frontend.value"); err == nil { - return label - } - return provider.getEscapedName(container.Name) + "." + provider.Domain -} +func listContainers(dockerClient *docker.Client) []docker.Container { + containerList, _ := dockerClient.ListContainers(docker.ListContainersOptions{}) + containersInspected := []docker.Container{} -// GetFrontendRule returns the frontend rule for the specified container, using -// it's label. It returns a default one (Host) if the label is not present. -func (provider *Docker) GetFrontendRule(container docker.Container) string { - if label, err := provider.getLabel(container, "traefik.frontend.rule"); err == nil { - return label + // get inspect containers + for _, container := range containerList { + containerInspected, _ := dockerClient.InspectContainer(container.ID) + containersInspected = append(containersInspected, *containerInspected) } - return "Host" + return containersInspected } diff --git a/provider/docker_test.go b/provider/docker_test.go new file mode 100644 index 000000000..54b30dc5d --- /dev/null +++ b/provider/docker_test.go @@ -0,0 +1,788 @@ +package provider + +import ( + "reflect" + "strings" + "testing" + + "github.com/emilevauge/traefik/types" + "github.com/fsouza/go-dockerclient" +) + +func TestDockerGetFrontendName(t *testing.T) { + provider := &Docker{ + Domain: "docker.localhost", + } + + containers := []struct { + container docker.Container + expected string + }{ + { + container: docker.Container{ + Name: "foo", + Config: &docker.Config{}, + }, + expected: "Host-foo-docker-localhost", + }, + { + container: docker.Container{ + Name: "bar", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.frontend.rule": "Header", + }, + }, + }, + expected: "Header-bar-docker-localhost", + }, + { + container: docker.Container{ + Name: "test", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.frontend.value": "foo.bar", + }, + }, + }, + expected: "Host-foo-bar", + }, + { + container: docker.Container{ + Name: "test", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.frontend.value": "foo.bar", + "traefik.frontend.rule": "Header", + }, + }, + }, + expected: "Header-foo-bar", + }, + { + container: docker.Container{ + Name: "test", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.frontend.value": "[foo.bar]", + "traefik.frontend.rule": "Header", + }, + }, + }, + expected: "Header-foo-bar", + }, + } + + for _, e := range containers { + actual := provider.getFrontendName(e.container) + if actual != e.expected { + t.Fatalf("expected %q, got %q", e.expected, actual) + } + } +} + +func TestDockerGetFrontendValue(t *testing.T) { + provider := &Docker{ + Domain: "docker.localhost", + } + + containers := []struct { + container docker.Container + expected string + }{ + { + container: docker.Container{ + Name: "foo", + Config: &docker.Config{}, + }, + expected: "foo.docker.localhost", + }, + { + container: docker.Container{ + Name: "bar", + Config: &docker.Config{}, + }, + expected: "bar.docker.localhost", + }, + { + container: docker.Container{ + Name: "test", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.frontend.value": "foo.bar", + }, + }, + }, + expected: "foo.bar", + }, + } + + for _, e := range containers { + actual := provider.getFrontendValue(e.container) + if actual != e.expected { + t.Fatalf("expected %q, got %q", e.expected, actual) + } + } +} + +func TestDockerGetFrontendRule(t *testing.T) { + provider := &Docker{} + + containers := []struct { + container docker.Container + expected string + }{ + { + container: docker.Container{ + Name: "foo", + Config: &docker.Config{}, + }, + expected: "Host", + }, + { + container: docker.Container{ + Name: "test", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.frontend.rule": "foo", + }, + }, + }, + expected: "foo", + }, + } + + for _, e := range containers { + actual := provider.getFrontendRule(e.container) + if actual != e.expected { + t.Fatalf("expected %q, got %q", e.expected, actual) + } + } +} + +func TestDockerGetBackend(t *testing.T) { + provider := &Docker{} + + containers := []struct { + container docker.Container + expected string + }{ + { + container: docker.Container{ + Name: "foo", + Config: &docker.Config{}, + }, + expected: "foo", + }, + { + container: docker.Container{ + Name: "bar", + Config: &docker.Config{}, + }, + expected: "bar", + }, + { + container: docker.Container{ + Name: "test", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.backend": "foobar", + }, + }, + }, + expected: "foobar", + }, + } + + for _, e := range containers { + actual := provider.getBackend(e.container) + if actual != e.expected { + t.Fatalf("expected %q, got %q", e.expected, actual) + } + } +} + +func TestDockerGetPort(t *testing.T) { + provider := &Docker{} + + containers := []struct { + container docker.Container + expected string + }{ + { + container: docker.Container{ + Name: "foo", + Config: &docker.Config{}, + NetworkSettings: &docker.NetworkSettings{}, + }, + expected: "", + }, + { + container: docker.Container{ + Name: "bar", + Config: &docker.Config{}, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + }, + }, + }, + expected: "80", + }, + // FIXME handle this better.. + // { + // container: docker.Container{ + // Name: "bar", + // Config: &docker.Config{}, + // NetworkSettings: &docker.NetworkSettings{ + // Ports: map[docker.Port][]docker.PortBinding{ + // "80/tcp": []docker.PortBinding{}, + // "443/tcp": []docker.PortBinding{}, + // }, + // }, + // }, + // expected: "80", + // }, + { + container: docker.Container{ + Name: "test", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.port": "8080", + }, + }, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + }, + }, + }, + expected: "8080", + }, + } + + for _, e := range containers { + actual := provider.getPort(e.container) + if actual != e.expected { + t.Fatalf("expected %q, got %q", e.expected, actual) + } + } +} + +func TestDockerGetWeight(t *testing.T) { + provider := &Docker{} + + containers := []struct { + container docker.Container + expected string + }{ + { + container: docker.Container{ + Name: "foo", + Config: &docker.Config{}, + }, + expected: "0", + }, + { + container: docker.Container{ + Name: "test", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.weight": "10", + }, + }, + }, + expected: "10", + }, + } + + for _, e := range containers { + actual := provider.getWeight(e.container) + if actual != e.expected { + t.Fatalf("expected %q, got %q", e.expected, actual) + } + } +} + +func TestDockerGetDomain(t *testing.T) { + provider := &Docker{ + Domain: "docker.localhost", + } + + containers := []struct { + container docker.Container + expected string + }{ + { + container: docker.Container{ + Name: "foo", + Config: &docker.Config{}, + }, + expected: "docker.localhost", + }, + { + container: docker.Container{ + Name: "test", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.domain": "foo.bar", + }, + }, + }, + expected: "foo.bar", + }, + } + + for _, e := range containers { + actual := provider.getDomain(e.container) + if actual != e.expected { + t.Fatalf("expected %q, got %q", e.expected, actual) + } + } +} + +func TestDockerGetProtocol(t *testing.T) { + provider := &Docker{} + + containers := []struct { + container docker.Container + expected string + }{ + { + container: docker.Container{ + Name: "foo", + Config: &docker.Config{}, + }, + expected: "http", + }, + { + container: docker.Container{ + Name: "test", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.protocol": "https", + }, + }, + }, + expected: "https", + }, + } + + for _, e := range containers { + actual := provider.getProtocol(e.container) + if actual != e.expected { + t.Fatalf("expected %q, got %q", e.expected, actual) + } + } +} + +func TestDockerGetPassHostHeader(t *testing.T) { + provider := &Docker{} + + containers := []struct { + container docker.Container + expected string + }{ + { + container: docker.Container{ + Name: "foo", + Config: &docker.Config{}, + }, + expected: "false", + }, + { + container: docker.Container{ + Name: "test", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.frontend.passHostHeader": "true", + }, + }, + }, + expected: "true", + }, + } + + for _, e := range containers { + actual := provider.getPassHostHeader(e.container) + if actual != e.expected { + t.Fatalf("expected %q, got %q", e.expected, actual) + } + } +} + +func TestDockerGetLabel(t *testing.T) { + containers := []struct { + container docker.Container + expected string + }{ + { + container: docker.Container{ + Config: &docker.Config{}, + }, + expected: "Label not found:", + }, + { + container: docker.Container{ + Config: &docker.Config{ + Labels: map[string]string{ + "foo": "bar", + }, + }, + }, + expected: "", + }, + } + + for _, e := range containers { + label, err := getLabel(e.container, "foo") + if e.expected != "" { + if err == nil || !strings.Contains(err.Error(), e.expected) { + t.Fatalf("expected an error with %q, got %v", e.expected, err) + } + } else { + if label != "bar" { + t.Fatalf("expected label 'bar', got %s", label) + } + } + } +} + +func TestDockerGetLabels(t *testing.T) { + containers := []struct { + container docker.Container + expectedLabels map[string]string + expectedError string + }{ + { + container: docker.Container{ + Config: &docker.Config{}, + }, + expectedLabels: map[string]string{}, + expectedError: "Label not found:", + }, + { + container: docker.Container{ + Config: &docker.Config{ + Labels: map[string]string{ + "foo": "fooz", + }, + }, + }, + expectedLabels: map[string]string{ + "foo": "fooz", + }, + expectedError: "Label not found: bar", + }, + { + container: docker.Container{ + Config: &docker.Config{ + Labels: map[string]string{ + "foo": "fooz", + "bar": "barz", + }, + }, + }, + expectedLabels: map[string]string{ + "foo": "fooz", + "bar": "barz", + }, + expectedError: "", + }, + } + + for _, e := range containers { + labels, err := getLabels(e.container, []string{"foo", "bar"}) + if !reflect.DeepEqual(labels, e.expectedLabels) { + t.Fatalf("expect %v, got %v", e.expectedLabels, labels) + } + if e.expectedError != "" { + if err == nil || !strings.Contains(err.Error(), e.expectedError) { + t.Fatalf("expected an error with %q, got %v", e.expectedError, err) + } + } + } +} + +func TestDockerTraefikFilter(t *testing.T) { + containers := []struct { + container docker.Container + expected bool + }{ + { + container: docker.Container{ + Config: &docker.Config{}, + NetworkSettings: &docker.NetworkSettings{}, + }, + expected: false, + }, + { + container: docker.Container{ + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.enable": "false", + }, + }, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + }, + }, + }, + expected: false, + }, + { + container: docker.Container{ + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.frontend.rule": "Host", + }, + }, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + }, + }, + }, + expected: false, + }, + { + container: docker.Container{ + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.frontend.value": "foo.bar", + }, + }, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + }, + }, + }, + expected: false, + }, + { + container: docker.Container{ + Config: &docker.Config{}, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + "443/tcp": {}, + }, + }, + }, + expected: false, + }, + { + container: docker.Container{ + Config: &docker.Config{}, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + }, + }, + }, + expected: true, + }, + { + container: docker.Container{ + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.port": "80", + }, + }, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + "443/tcp": {}, + }, + }, + }, + expected: true, + }, + { + container: docker.Container{ + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.enable": "true", + }, + }, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + }, + }, + }, + expected: true, + }, + { + container: docker.Container{ + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.enable": "anything", + }, + }, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + }, + }, + }, + expected: true, + }, + { + container: docker.Container{ + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.frontend.rule": "Host", + "traefik.frontend.value": "foo.bar", + }, + }, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + }, + }, + }, + expected: true, + }, + } + + for _, e := range containers { + actual := containerFilter(e.container) + if actual != e.expected { + t.Fatalf("expected %v, got %v", e.expected, actual) + } + } +} + +func TestDockerLoadDockerConfig(t *testing.T) { + cases := []struct { + containers []docker.Container + expectedFrontends map[string]*types.Frontend + expectedBackends map[string]*types.Backend + }{ + { + containers: []docker.Container{}, + expectedFrontends: map[string]*types.Frontend{}, + expectedBackends: map[string]*types.Backend{}, + }, + { + containers: []docker.Container{ + { + Name: "test", + Config: &docker.Config{}, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + }, + IPAddress: "127.0.0.1", + }, + }, + }, + expectedFrontends: map[string]*types.Frontend{ + `"frontend-Host-test-docker-localhost"`: { + Backend: "backend-test", + Routes: map[string]types.Route{ + `"route-frontend-Host-test-docker-localhost"`: { + Rule: "Host", + Value: "test.docker.localhost", + }, + }, + }, + }, + expectedBackends: map[string]*types.Backend{ + "backend-test": { + Servers: map[string]types.Server{ + "server-test": { + URL: "http://127.0.0.1:80", + }, + }, + CircuitBreaker: nil, + LoadBalancer: nil, + }, + }, + }, + { + containers: []docker.Container{ + { + Name: "test1", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.backend": "foobar", + }, + }, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + }, + IPAddress: "127.0.0.1", + }, + }, + { + Name: "test2", + Config: &docker.Config{ + Labels: map[string]string{ + "traefik.backend": "foobar", + }, + }, + NetworkSettings: &docker.NetworkSettings{ + Ports: map[docker.Port][]docker.PortBinding{ + "80/tcp": {}, + }, + IPAddress: "127.0.0.1", + }, + }, + }, + expectedFrontends: map[string]*types.Frontend{ + `"frontend-Host-test1-docker-localhost"`: { + Backend: "backend-foobar", + Routes: map[string]types.Route{ + `"route-frontend-Host-test1-docker-localhost"`: { + Rule: "Host", + Value: "test1.docker.localhost", + }, + }, + }, + `"frontend-Host-test2-docker-localhost"`: { + Backend: "backend-foobar", + Routes: map[string]types.Route{ + `"route-frontend-Host-test2-docker-localhost"`: { + Rule: "Host", + Value: "test2.docker.localhost", + }, + }, + }, + }, + expectedBackends: map[string]*types.Backend{ + "backend-foobar": { + Servers: map[string]types.Server{ + "server-test1": { + URL: "http://127.0.0.1:80", + }, + "server-test2": { + URL: "http://127.0.0.1:80", + }, + }, + CircuitBreaker: nil, + LoadBalancer: nil, + }, + }, + }, + } + + provider := &Docker{ + Domain: "docker.localhost", + } + + for _, c := range cases { + actualConfig := provider.loadDockerConfig(c.containers) + // Compare backends + if !reflect.DeepEqual(actualConfig.Backends, c.expectedBackends) { + t.Fatalf("expected %#v, got %#v", c.expectedBackends, actualConfig.Backends) + } + if !reflect.DeepEqual(actualConfig.Frontends, c.expectedFrontends) { + t.Fatalf("expected %#v, got %#v", c.expectedFrontends, actualConfig.Frontends) + } + } +} diff --git a/provider/etcd.go b/provider/etcd.go index 82eb32abf..d51ecc380 100644 --- a/provider/etcd.go +++ b/provider/etcd.go @@ -1,19 +1,20 @@ package provider -import "github.com/emilevauge/traefik/types" +import ( + "github.com/docker/libkv/store" + "github.com/docker/libkv/store/etcd" + "github.com/emilevauge/traefik/types" +) // Etcd holds configurations of the Etcd provider. type Etcd struct { - Watch bool - Endpoint string - Prefix string - Filename string - KvProvider *Kv + Kv } // Provide allows the provider to provide configurations to traefik // using the given configuration channel. func (provider *Etcd) Provide(configurationChan chan<- types.ConfigMessage) error { - provider.KvProvider = NewEtcdProvider(provider) - return provider.KvProvider.provide(configurationChan) + provider.StoreType = store.ETCD + etcd.Register() + return provider.provide(configurationChan) } diff --git a/provider/file.go b/provider/file.go index 0e1c2c091..c5ef28c00 100644 --- a/provider/file.go +++ b/provider/file.go @@ -13,8 +13,7 @@ import ( // File holds configurations of the File provider. type File struct { - Watch bool - Filename string + baseProvider } // Provide allows the provider to provide configurations to traefik @@ -44,7 +43,10 @@ func (provider *File) Provide(configurationChan chan<- types.ConfigMessage) erro log.Debug("File event:", event) configuration := provider.loadFileConfig(file.Name()) if configuration != nil { - configurationChan <- types.ConfigMessage{"file", configuration} + configurationChan <- types.ConfigMessage{ + ProviderName: "file", + Configuration: configuration, + } } } case error := <-watcher.Errors: @@ -60,7 +62,10 @@ func (provider *File) Provide(configurationChan chan<- types.ConfigMessage) erro } configuration := provider.loadFileConfig(file.Name()) - configurationChan <- types.ConfigMessage{"file", configuration} + configurationChan <- types.ConfigMessage{ + ProviderName: "file", + Configuration: configuration, + } return nil } diff --git a/provider/kv.go b/provider/kv.go index 1690735a3..0ada556a0 100644 --- a/provider/kv.go +++ b/provider/kv.go @@ -2,92 +2,27 @@ package provider import ( - "bytes" - "errors" "strings" "text/template" "time" - "github.com/BurntSushi/toml" "github.com/BurntSushi/ty/fun" log "github.com/Sirupsen/logrus" "github.com/docker/libkv" "github.com/docker/libkv/store" - "github.com/docker/libkv/store/boltdb" - "github.com/docker/libkv/store/consul" - "github.com/docker/libkv/store/etcd" - "github.com/docker/libkv/store/zookeeper" - "github.com/emilevauge/traefik/autogen" "github.com/emilevauge/traefik/types" ) // Kv holds common configurations of key-value providers. type Kv struct { - Watch bool + baseProvider Endpoint string Prefix string - Filename string StoreType store.Backend kvclient store.Store } -// NewConsulProvider returns a Consul provider. -func NewConsulProvider(provider *Consul) *Kv { - kvProvider := new(Kv) - kvProvider.Watch = provider.Watch - kvProvider.Endpoint = provider.Endpoint - kvProvider.Prefix = provider.Prefix - kvProvider.Filename = provider.Filename - kvProvider.StoreType = store.CONSUL - return kvProvider -} - -// NewEtcdProvider returns a Etcd provider. -func NewEtcdProvider(provider *Etcd) *Kv { - kvProvider := new(Kv) - kvProvider.Watch = provider.Watch - kvProvider.Endpoint = provider.Endpoint - kvProvider.Prefix = provider.Prefix - kvProvider.Filename = provider.Filename - kvProvider.StoreType = store.ETCD - return kvProvider -} - -// NewZkProvider returns a Zookepper provider. -func NewZkProvider(provider *Zookepper) *Kv { - kvProvider := new(Kv) - kvProvider.Watch = provider.Watch - kvProvider.Endpoint = provider.Endpoint - kvProvider.Prefix = provider.Prefix - kvProvider.Filename = provider.Filename - kvProvider.StoreType = store.ZK - return kvProvider -} - -// NewBoltDbProvider returns a BoldDb provider. -func NewBoltDbProvider(provider *BoltDb) *Kv { - kvProvider := new(Kv) - kvProvider.Watch = provider.Watch - kvProvider.Endpoint = provider.Endpoint - kvProvider.Prefix = provider.Prefix - kvProvider.Filename = provider.Filename - kvProvider.StoreType = store.BOLTDB - return kvProvider -} - func (provider *Kv) provide(configurationChan chan<- types.ConfigMessage) error { - switch provider.StoreType { - case store.CONSUL: - consul.Register() - case store.ETCD: - etcd.Register() - case store.ZK: - zookeeper.Register() - case store.BOLTDB: - boltdb.Register() - default: - return errors.New("Invalid kv store: " + string(provider.StoreType)) - } kv, err := libkv.NewStore( provider.StoreType, []string{provider.Endpoint}, @@ -114,88 +49,70 @@ func (provider *Kv) provide(configurationChan chan<- types.ConfigMessage) error <-chanKeys configuration := provider.loadConfig() if configuration != nil { - configurationChan <- types.ConfigMessage{string(provider.StoreType), configuration} + configurationChan <- types.ConfigMessage{ + ProviderName: string(provider.StoreType), + Configuration: configuration, + } } defer close(stopCh) } }() } configuration := provider.loadConfig() - configurationChan <- types.ConfigMessage{string(provider.StoreType), configuration} + configurationChan <- types.ConfigMessage{ + ProviderName: string(provider.StoreType), + Configuration: configuration, + } return nil } func (provider *Kv) loadConfig() *types.Configuration { - configuration := new(types.Configuration) templateObjects := struct { Prefix string }{ provider.Prefix, } var KvFuncMap = template.FuncMap{ - "List": func(keys ...string) []string { - joinedKeys := strings.Join(keys, "") - keysPairs, err := provider.kvclient.List(joinedKeys) - if err != nil { - log.Error("Error getting keys: ", joinedKeys, err) - return nil - } - directoryKeys := make(map[string]string) - for _, key := range keysPairs { - directory := strings.Split(strings.TrimPrefix(key.Key, strings.TrimPrefix(joinedKeys, "/")), "/")[0] - directoryKeys[directory] = joinedKeys + directory - } - return fun.Values(directoryKeys).([]string) - }, - "Get": func(keys ...string) string { - joinedKeys := strings.Join(keys, "") - keyPair, err := provider.kvclient.Get(joinedKeys) - if err != nil { - log.Debug("Error getting key: ", joinedKeys, err) - return "" - } else if keyPair == nil { - return "" - } - return string(keyPair.Value) - }, - "Last": func(key string) string { - splittedKey := strings.Split(key, "/") - return splittedKey[len(splittedKey)-1] - }, + "List": provider.list, + "Get": provider.get, + "Last": provider.last, } - tmpl := template.New(provider.Filename).Funcs(KvFuncMap) - if len(provider.Filename) > 0 { - _, err := tmpl.ParseFiles(provider.Filename) - if err != nil { - log.Error("Error reading file", err) - return nil - } - } else { - buf, err := autogen.Asset("templates/kv.tmpl") - if err != nil { - log.Error("Error reading file", err) - } - _, err = tmpl.Parse(string(buf)) - if err != nil { - log.Error("Error reading file", err) - return nil - } - } - - var buffer bytes.Buffer - - err := tmpl.Execute(&buffer, templateObjects) + configuration, err := provider.getConfiguration("templates/kv.tmpl", KvFuncMap, templateObjects) if err != nil { - log.Error("Error with kv template:", err) - return nil + log.Error(err) } - - if _, err := toml.Decode(buffer.String(), configuration); err != nil { - log.Error("Error creating kv configuration:", err) - log.Error(buffer.String()) - return nil - } - return configuration } + +func (provider *Kv) list(keys ...string) []string { + joinedKeys := strings.Join(keys, "") + keysPairs, err := provider.kvclient.List(joinedKeys) + if err != nil { + log.Error("Error getting keys: ", joinedKeys, err) + return nil + } + directoryKeys := make(map[string]string) + for _, key := range keysPairs { + directory := strings.Split(strings.TrimPrefix(key.Key, strings.TrimPrefix(joinedKeys, "/")), "/")[0] + directoryKeys[directory] = joinedKeys + directory + } + return fun.Values(directoryKeys).([]string) +} + +func (provider *Kv) get(keys ...string) string { + joinedKeys := strings.Join(keys, "") + keyPair, err := provider.kvclient.Get(joinedKeys) + if err != nil { + log.Debug("Error getting key: ", joinedKeys, err) + return "" + } else if keyPair == nil { + return "" + } + return string(keyPair.Value) +} + +func (provider *Kv) last(key string) string { + splittedKey := strings.Split(key, "/") + return splittedKey[len(splittedKey)-1] +} diff --git a/provider/kv_test.go b/provider/kv_test.go new file mode 100644 index 000000000..3cb79f257 --- /dev/null +++ b/provider/kv_test.go @@ -0,0 +1,312 @@ +package provider + +import ( + "errors" + "strings" + "testing" + + "github.com/docker/libkv/store" + "reflect" + "sort" +) + +func TestKvList(t *testing.T) { + cases := []struct { + provider *Kv + keys []string + expected []string + }{ + { + provider: &Kv{ + kvclient: &Mock{}, + }, + keys: []string{}, + expected: []string{}, + }, + { + provider: &Kv{ + kvclient: &Mock{}, + }, + keys: []string{"traefik"}, + expected: []string{}, + }, + { + provider: &Kv{ + kvclient: &Mock{ + KVPairs: []*store.KVPair{ + { + Key: "foo", + Value: []byte("bar"), + }, + }, + }, + }, + keys: []string{"bar"}, + expected: []string{}, + }, + { + provider: &Kv{ + kvclient: &Mock{ + KVPairs: []*store.KVPair{ + { + Key: "foo", + Value: []byte("bar"), + }, + }, + }, + }, + keys: []string{"foo"}, + expected: []string{"foo"}, + }, + { + provider: &Kv{ + kvclient: &Mock{ + KVPairs: []*store.KVPair{ + { + Key: "foo/baz/1", + Value: []byte("bar"), + }, + { + Key: "foo/baz/2", + Value: []byte("bar"), + }, + { + Key: "foo/baz/biz/1", + Value: []byte("bar"), + }, + }, + }, + }, + keys: []string{"foo", "/baz/"}, + expected: []string{"foo/baz/biz", "foo/baz/1", "foo/baz/2"}, + }, + } + + for _, c := range cases { + actual := c.provider.list(c.keys...) + sort.Strings(actual) + sort.Strings(c.expected) + if !reflect.DeepEqual(actual, c.expected) { + t.Fatalf("expected %v, got %v for %v and %v", c.expected, actual, c.keys, c.provider) + } + } + + // Error case + provider := &Kv{ + kvclient: &Mock{ + Error: true, + }, + } + actual := provider.list("anything") + if actual != nil { + t.Fatalf("Should have return nil, got %v", actual) + } +} + +func TestKvGet(t *testing.T) { + cases := []struct { + provider *Kv + keys []string + expected string + }{ + { + provider: &Kv{ + kvclient: &Mock{}, + }, + keys: []string{}, + expected: "", + }, + { + provider: &Kv{ + kvclient: &Mock{}, + }, + keys: []string{"traefik"}, + expected: "", + }, + { + provider: &Kv{ + kvclient: &Mock{ + KVPairs: []*store.KVPair{ + { + Key: "foo", + Value: []byte("bar"), + }, + }, + }, + }, + keys: []string{"bar"}, + expected: "", + }, + { + provider: &Kv{ + kvclient: &Mock{ + KVPairs: []*store.KVPair{ + { + Key: "foo", + Value: []byte("bar"), + }, + }, + }, + }, + keys: []string{"foo"}, + expected: "bar", + }, + { + provider: &Kv{ + kvclient: &Mock{ + KVPairs: []*store.KVPair{ + { + Key: "foo/baz/1", + Value: []byte("bar1"), + }, + { + Key: "foo/baz/2", + Value: []byte("bar2"), + }, + { + Key: "foo/baz/biz/1", + Value: []byte("bar3"), + }, + }, + }, + }, + keys: []string{"foo", "/baz/", "2"}, + expected: "bar2", + }, + } + + for _, c := range cases { + actual := c.provider.get(c.keys...) + if actual != c.expected { + t.Fatalf("expected %v, got %v for %v and %v", c.expected, actual, c.keys, c.provider) + } + } + + // Error case + provider := &Kv{ + kvclient: &Mock{ + Error: true, + }, + } + actual := provider.get("anything") + if actual != "" { + t.Fatalf("Should have return nil, got %v", actual) + } +} + +func TestKvLast(t *testing.T) { + cases := []struct { + key string + expected string + }{ + { + key: "", + expected: "", + }, + { + key: "foo", + expected: "foo", + }, + { + key: "foo/bar", + expected: "bar", + }, + { + key: "foo/bar/baz", + expected: "baz", + }, + // FIXME is this wanted ? + { + key: "foo/bar/", + expected: "", + }, + } + + provider := &Kv{} + for _, c := range cases { + actual := provider.last(c.key) + if actual != c.expected { + t.Fatalf("expected %s, got %s", c.expected, actual) + } + } +} + +// Extremely limited mock store so we can test initialization +type Mock struct { + Error bool + KVPairs []*store.KVPair +} + +func (s *Mock) Put(key string, value []byte, opts *store.WriteOptions) error { + return errors.New("Put not supported") +} + +func (s *Mock) Get(key string) (*store.KVPair, error) { + if s.Error { + return nil, errors.New("Error") + } + for _, kvPair := range s.KVPairs { + if kvPair.Key == key { + return kvPair, nil + } + } + return nil, nil +} + +func (s *Mock) Delete(key string) error { + return errors.New("Delete not supported") +} + +// Exists mock +func (s *Mock) Exists(key string) (bool, error) { + return false, errors.New("Exists not supported") +} + +// Watch mock +func (s *Mock) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, error) { + return nil, errors.New("Watch not supported") +} + +// WatchTree mock +func (s *Mock) WatchTree(prefix string, stopCh <-chan struct{}) (<-chan []*store.KVPair, error) { + return nil, errors.New("WatchTree not supported") +} + +// NewLock mock +func (s *Mock) NewLock(key string, options *store.LockOptions) (store.Locker, error) { + return nil, errors.New("NewLock not supported") +} + +// List mock +func (s *Mock) List(prefix string) ([]*store.KVPair, error) { + if s.Error { + return nil, errors.New("Error") + } + kv := []*store.KVPair{} + for _, kvPair := range s.KVPairs { + if strings.HasPrefix(kvPair.Key, prefix) { + kv = append(kv, kvPair) + } + } + return kv, nil +} + +// DeleteTree mock +func (s *Mock) DeleteTree(prefix string) error { + return errors.New("DeleteTree not supported") +} + +// AtomicPut mock +func (s *Mock) AtomicPut(key string, value []byte, previous *store.KVPair, opts *store.WriteOptions) (bool, *store.KVPair, error) { + return false, nil, errors.New("AtomicPut not supported") +} + +// AtomicDelete mock +func (s *Mock) AtomicDelete(key string, previous *store.KVPair) (bool, error) { + return false, errors.New("AtomicDelete not supported") +} + +// Close mock +func (s *Mock) Close() { + return +} diff --git a/provider/marathon.go b/provider/marathon.go index 890be561c..e01fc5050 100644 --- a/provider/marathon.go +++ b/provider/marathon.go @@ -1,28 +1,29 @@ package provider import ( - "bytes" "errors" + "net/url" "strconv" - "strings" "text/template" - "github.com/BurntSushi/toml" "github.com/BurntSushi/ty/fun" log "github.com/Sirupsen/logrus" - "github.com/emilevauge/traefik/autogen" "github.com/emilevauge/traefik/types" "github.com/gambol99/go-marathon" ) // Marathon holds configuration of the Marathon provider. type Marathon struct { - Watch bool + baseProvider Endpoint string Domain string - Filename string NetworkInterface string - marathonClient marathon.Marathon + marathonClient lightMarathonClient +} + +type lightMarathonClient interface { + Applications(url.Values) (*marathon.Applications, error) + AllTasks() (*marathon.Tasks, error) } // Provide allows the provider to provide configurations to traefik @@ -48,7 +49,10 @@ func (provider *Marathon) Provide(configurationChan chan<- types.ConfigMessage) log.Debug("Marathon event receveived", event) configuration := provider.loadMarathonConfig() if configuration != nil { - configurationChan <- types.ConfigMessage{"marathon", configuration} + configurationChan <- types.ConfigMessage{ + ProviderName: "marathon", + Configuration: configuration, + } } } }() @@ -56,59 +60,24 @@ func (provider *Marathon) Provide(configurationChan chan<- types.ConfigMessage) } configuration := provider.loadMarathonConfig() - configurationChan <- types.ConfigMessage{"marathon", configuration} + configurationChan <- types.ConfigMessage{ + ProviderName: "marathon", + Configuration: configuration, + } return nil } func (provider *Marathon) loadMarathonConfig() *types.Configuration { var MarathonFuncMap = template.FuncMap{ - "getPort": func(task marathon.Task) string { - for _, port := range task.Ports { - return strconv.Itoa(port) - } - return "" - }, - "getWeight": func(task marathon.Task, applications []marathon.Application) string { - application, errApp := getApplication(task, applications) - if errApp != nil { - log.Errorf("Unable to get marathon application from task %s", task.AppID) - return "0" - } - if label, err := provider.getLabel(application, "traefik.weight"); err == nil { - return label - } - return "0" - }, - "getDomain": func(application marathon.Application) string { - if label, err := provider.getLabel(application, "traefik.domain"); err == nil { - return label - } - return provider.Domain - }, - "replace": func(s1 string, s2 string, s3 string) string { - return strings.Replace(s3, s1, s2, -1) - }, - "getProtocol": func(task marathon.Task, applications []marathon.Application) string { - application, errApp := getApplication(task, applications) - if errApp != nil { - log.Errorf("Unable to get marathon application from task %s", task.AppID) - return "http" - } - if label, err := provider.getLabel(application, "traefik.protocol"); err == nil { - return label - } - return "http" - }, - "getPassHostHeader": func(application marathon.Application) string { - if passHostHeader, err := provider.getLabel(application, "traefik.frontend.passHostHeader"); err == nil { - return passHostHeader - } - return "false" - }, - "getFrontendValue": provider.GetFrontendValue, - "getFrontendRule": provider.GetFrontendRule, + "getPort": provider.getPort, + "getWeight": provider.getWeight, + "getDomain": provider.getDomain, + "getProtocol": provider.getProtocol, + "getPassHostHeader": provider.getPassHostHeader, + "getFrontendValue": provider.getFrontendValue, + "getFrontendRule": provider.getFrontendRule, + "replace": replace, } - configuration := new(types.Configuration) applications, err := provider.marathonClient.Applications(nil) if err != nil { @@ -124,54 +93,12 @@ func (provider *Marathon) loadMarathonConfig() *types.Configuration { //filter tasks filteredTasks := fun.Filter(func(task marathon.Task) bool { - if len(task.Ports) == 0 { - log.Debug("Filtering marathon task without port %s", task.AppID) - return false - } - application, errApp := getApplication(task, applications.Apps) - if errApp != nil { - log.Errorf("Unable to get marathon application from task %s", task.AppID) - return false - } - _, err := strconv.Atoi(application.Labels["traefik.port"]) - if len(application.Ports) > 1 && err != nil { - log.Debugf("Filtering marathon task %s with more than 1 port and no traefik.port label", task.AppID) - return false - } - if application.Labels["traefik.enable"] == "false" { - log.Debugf("Filtering disabled marathon task %s", task.AppID) - return false - } - //filter healthchecks - if application.HasHealthChecks() { - if task.HasHealthCheckResults() { - for _, healthcheck := range task.HealthCheckResult { - // found one bad healthcheck, return false - if !healthcheck.Alive { - log.Debugf("Filtering marathon task %s with bad healthcheck", task.AppID) - return false - } - } - } else { - log.Debugf("Filtering marathon task %s with bad healthcheck", task.AppID) - return false - } - } - return true + return taskFilter(task, applications) }, tasks.Tasks).([]marathon.Task) //filter apps filteredApps := fun.Filter(func(app marathon.Application) bool { - //get ports from app tasks - if !fun.Exists(func(task marathon.Task) bool { - if task.AppID == app.ID { - return true - } - return false - }, filteredTasks) { - return false - } - return true + return applicationFilter(app, filteredTasks) }, applications.Apps).([]marathon.Application) templateObjects := struct { @@ -184,41 +111,56 @@ func (provider *Marathon) loadMarathonConfig() *types.Configuration { provider.Domain, } - tmpl := template.New(provider.Filename).Funcs(MarathonFuncMap) - if len(provider.Filename) > 0 { - _, err := tmpl.ParseFiles(provider.Filename) - if err != nil { - log.Error("Error reading file", err) - return nil - } - } else { - buf, err := autogen.Asset("templates/marathon.tmpl") - if err != nil { - log.Error("Error reading file", err) - } - _, err = tmpl.Parse(string(buf)) - if err != nil { - log.Error("Error reading file", err) - return nil - } - } - - var buffer bytes.Buffer - - err = tmpl.Execute(&buffer, templateObjects) + configuration, err := provider.getConfiguration("templates/marathon.tmpl", MarathonFuncMap, templateObjects) if err != nil { - log.Error("Error with marathon template:", err) - return nil + log.Error(err) } - - if _, err := toml.Decode(buffer.String(), configuration); err != nil { - log.Error("Error creating marathon configuration:", err) - return nil - } - return configuration } +func taskFilter(task marathon.Task, applications *marathon.Applications) bool { + if len(task.Ports) == 0 { + log.Debug("Filtering marathon task without port %s", task.AppID) + return false + } + application, errApp := getApplication(task, applications.Apps) + if errApp != nil { + log.Errorf("Unable to get marathon application from task %s", task.AppID) + return false + } + _, err := strconv.Atoi(application.Labels["traefik.port"]) + if len(application.Ports) > 1 && err != nil { + log.Debugf("Filtering marathon task %s with more than 1 port and no traefik.port label", task.AppID) + return false + } + if application.Labels["traefik.enable"] == "false" { + log.Debugf("Filtering disabled marathon task %s", task.AppID) + return false + } + //filter healthchecks + if application.HasHealthChecks() { + if task.HasHealthCheckResults() { + for _, healthcheck := range task.HealthCheckResult { + // found one bad healthcheck, return false + if !healthcheck.Alive { + log.Debugf("Filtering marathon task %s with bad healthcheck", task.AppID) + return false + } + } + } else { + log.Debugf("Filtering marathon task %s with bad healthcheck", task.AppID) + return false + } + } + return true +} + +func applicationFilter(app marathon.Application, filteredTasks []marathon.Task) bool { + return fun.Exists(func(task marathon.Task) bool { + return task.AppID == app.ID + }, filteredTasks) +} + func getApplication(task marathon.Task, apps []marathon.Application) (marathon.Application, error) { for _, application := range apps { if application.ID == task.AppID { @@ -237,22 +179,63 @@ func (provider *Marathon) getLabel(application marathon.Application, label strin return "", errors.New("Label not found:" + label) } -func (provider *Marathon) getEscapedName(name string) string { - return strings.Replace(name, "/", "", -1) +func (provider *Marathon) getPort(task marathon.Task) string { + for _, port := range task.Ports { + return strconv.Itoa(port) + } + return "" } -// GetFrontendValue returns the frontend value for the specified application, using +func (provider *Marathon) getWeight(task marathon.Task, applications []marathon.Application) string { + application, errApp := getApplication(task, applications) + if errApp != nil { + log.Errorf("Unable to get marathon application from task %s", task.AppID) + return "0" + } + if label, err := provider.getLabel(application, "traefik.weight"); err == nil { + return label + } + return "0" +} + +func (provider *Marathon) getDomain(application marathon.Application) string { + if label, err := provider.getLabel(application, "traefik.domain"); err == nil { + return label + } + return provider.Domain +} + +func (provider *Marathon) getProtocol(task marathon.Task, applications []marathon.Application) string { + application, errApp := getApplication(task, applications) + if errApp != nil { + log.Errorf("Unable to get marathon application from task %s", task.AppID) + return "http" + } + if label, err := provider.getLabel(application, "traefik.protocol"); err == nil { + return label + } + return "http" +} + +func (provider *Marathon) getPassHostHeader(application marathon.Application) string { + if passHostHeader, err := provider.getLabel(application, "traefik.frontend.passHostHeader"); err == nil { + return passHostHeader + } + return "false" +} + +// getFrontendValue returns the frontend value for the specified application, using // it's label. It returns a default one if the label is not present. -func (provider *Marathon) GetFrontendValue(application marathon.Application) string { +func (provider *Marathon) getFrontendValue(application marathon.Application) string { if label, err := provider.getLabel(application, "traefik.frontend.value"); err == nil { return label } - return provider.getEscapedName(application.ID) + "." + provider.Domain + return getEscapedName(application.ID) + "." + provider.Domain } -// GetFrontendRule returns the frontend rule for the specified application, using +// getFrontendRule returns the frontend rule for the specified application, using // it's label. It returns a default one (Host) if the label is not present. -func (provider *Marathon) GetFrontendRule(application marathon.Application) string { +func (provider *Marathon) getFrontendRule(application marathon.Application) string { if label, err := provider.getLabel(application, "traefik.frontend.rule"); err == nil { return label } diff --git a/provider/marathon_test.go b/provider/marathon_test.go new file mode 100644 index 000000000..a4508aea8 --- /dev/null +++ b/provider/marathon_test.go @@ -0,0 +1,656 @@ +package provider + +import ( + "errors" + "net/url" + "reflect" + "testing" + + "github.com/emilevauge/traefik/types" + "github.com/gambol99/go-marathon" +) + +type fakeClient struct { + applicationsError bool + applications *marathon.Applications + tasksError bool + tasks *marathon.Tasks +} + +func (c *fakeClient) Applications(url.Values) (*marathon.Applications, error) { + if c.applicationsError { + return nil, errors.New("error") + } + return c.applications, nil +} + +func (c *fakeClient) AllTasks() (*marathon.Tasks, error) { + if c.tasksError { + return nil, errors.New("error") + } + return c.tasks, nil +} + +func TestMarathonLoadConfig(t *testing.T) { + cases := []struct { + applicationsError bool + applications *marathon.Applications + tasksError bool + tasks *marathon.Tasks + expectedNil bool + expectedFrontends map[string]*types.Frontend + expectedBackends map[string]*types.Backend + }{ + { + applications: &marathon.Applications{}, + tasks: &marathon.Tasks{}, + expectedFrontends: map[string]*types.Frontend{}, + expectedBackends: map[string]*types.Backend{}, + }, + { + applicationsError: true, + applications: &marathon.Applications{}, + tasks: &marathon.Tasks{}, + expectedNil: true, + expectedFrontends: map[string]*types.Frontend{}, + expectedBackends: map[string]*types.Backend{}, + }, + { + applications: &marathon.Applications{}, + tasksError: true, + tasks: &marathon.Tasks{}, + expectedNil: true, + expectedFrontends: map[string]*types.Frontend{}, + expectedBackends: map[string]*types.Backend{}, + }, + { + applications: &marathon.Applications{ + Apps: []marathon.Application{ + { + ID: "/test", + Ports: []int{80}, + }, + }, + }, + tasks: &marathon.Tasks{ + Tasks: []marathon.Task{ + { + ID: "test", + AppID: "/test", + Host: "127.0.0.1", + Ports: []int{80}, + }, + }, + }, + expectedFrontends: map[string]*types.Frontend{ + `frontend-test`: { + Backend: "backend-test", + Routes: map[string]types.Route{ + `route-host-test`: { + Rule: "Host", + Value: "test.docker.localhost", + }, + }, + }, + }, + expectedBackends: map[string]*types.Backend{ + "backend-test": { + Servers: map[string]types.Server{ + "server-test": { + URL: "http://127.0.0.1:80", + Weight: 0, + }, + }, + CircuitBreaker: nil, + LoadBalancer: nil, + }, + }, + }, + } + + for _, c := range cases { + provider := &Marathon{ + Domain: "docker.localhost", + marathonClient: &fakeClient{ + applicationsError: c.applicationsError, + applications: c.applications, + tasksError: c.tasksError, + tasks: c.tasks, + }, + } + actualConfig := provider.loadMarathonConfig() + if c.expectedNil { + if actualConfig != nil { + t.Fatalf("Should have been nil, got %v", actualConfig) + } + } else { + // Compare backends + if !reflect.DeepEqual(actualConfig.Backends, c.expectedBackends) { + t.Fatalf("expected %#v, got %#v", c.expectedBackends, actualConfig.Backends) + } + if !reflect.DeepEqual(actualConfig.Frontends, c.expectedFrontends) { + t.Fatalf("expected %#v, got %#v", c.expectedFrontends, actualConfig.Frontends) + } + } + } +} + +func TestMarathonTaskFilter(t *testing.T) { + cases := []struct { + task marathon.Task + applications *marathon.Applications + expected bool + }{ + { + task: marathon.Task{}, + applications: &marathon.Applications{}, + expected: false, + }, + { + task: marathon.Task{ + AppID: "test", + Ports: []int{80}, + }, + applications: &marathon.Applications{}, + expected: false, + }, + { + task: marathon.Task{ + AppID: "test", + Ports: []int{80}, + }, + applications: &marathon.Applications{ + Apps: []marathon.Application{ + { + ID: "foo", + }, + }, + }, + expected: false, + }, + { + task: marathon.Task{ + AppID: "foo", + Ports: []int{80}, + }, + applications: &marathon.Applications{ + Apps: []marathon.Application{ + { + ID: "foo", + Ports: []int{80, 443}, + }, + }, + }, + expected: false, + }, + { + task: marathon.Task{ + AppID: "foo", + Ports: []int{80}, + }, + applications: &marathon.Applications{ + Apps: []marathon.Application{ + { + ID: "foo", + Ports: []int{80}, + Labels: map[string]string{ + "traefik.enable": "false", + }, + }, + }, + }, + expected: false, + }, + { + task: marathon.Task{ + AppID: "foo", + Ports: []int{80}, + }, + applications: &marathon.Applications{ + Apps: []marathon.Application{ + { + ID: "foo", + Ports: []int{80}, + HealthChecks: []*marathon.HealthCheck{ + marathon.NewDefaultHealthCheck(), + }, + }, + }, + }, + expected: false, + }, + { + task: marathon.Task{ + AppID: "foo", + Ports: []int{80}, + HealthCheckResult: []*marathon.HealthCheckResult{ + { + Alive: false, + }, + }, + }, + applications: &marathon.Applications{ + Apps: []marathon.Application{ + { + ID: "foo", + Ports: []int{80}, + HealthChecks: []*marathon.HealthCheck{ + marathon.NewDefaultHealthCheck(), + }, + }, + }, + }, + expected: false, + }, + { + task: marathon.Task{ + AppID: "foo", + Ports: []int{80}, + HealthCheckResult: []*marathon.HealthCheckResult{ + { + Alive: true, + }, + { + Alive: false, + }, + }, + }, + applications: &marathon.Applications{ + Apps: []marathon.Application{ + { + ID: "foo", + Ports: []int{80}, + HealthChecks: []*marathon.HealthCheck{ + marathon.NewDefaultHealthCheck(), + }, + }, + }, + }, + expected: false, + }, + { + task: marathon.Task{ + AppID: "foo", + Ports: []int{80}, + }, + applications: &marathon.Applications{ + Apps: []marathon.Application{ + { + ID: "foo", + Ports: []int{80}, + }, + }, + }, + expected: true, + }, + { + task: marathon.Task{ + AppID: "foo", + Ports: []int{80}, + HealthCheckResult: []*marathon.HealthCheckResult{ + { + Alive: true, + }, + }, + }, + applications: &marathon.Applications{ + Apps: []marathon.Application{ + { + ID: "foo", + Ports: []int{80}, + HealthChecks: []*marathon.HealthCheck{ + marathon.NewDefaultHealthCheck(), + }, + }, + }, + }, + expected: true, + }, + } + + for _, c := range cases { + actual := taskFilter(c.task, c.applications) + if actual != c.expected { + t.Fatalf("expected %v, got %v", c.expected, actual) + } + } +} + +func TestMarathonApplicationFilter(t *testing.T) { + cases := []struct { + application marathon.Application + filteredTasks []marathon.Task + expected bool + }{ + { + application: marathon.Application{}, + filteredTasks: []marathon.Task{}, + expected: false, + }, + { + application: marathon.Application{ + ID: "test", + }, + filteredTasks: []marathon.Task{}, + expected: false, + }, + { + application: marathon.Application{ + ID: "foo", + }, + filteredTasks: []marathon.Task{ + { + AppID: "bar", + }, + }, + expected: false, + }, + { + application: marathon.Application{ + ID: "foo", + }, + filteredTasks: []marathon.Task{ + { + AppID: "foo", + }, + }, + expected: true, + }, + } + + for _, c := range cases { + actual := applicationFilter(c.application, c.filteredTasks) + if actual != c.expected { + t.Fatalf("expected %v, got %v", c.expected, actual) + } + } +} + +func TestMarathonGetPort(t *testing.T) { + provider := &Marathon{} + + cases := []struct { + task marathon.Task + expected string + }{ + { + task: marathon.Task{}, + expected: "", + }, + { + task: marathon.Task{ + Ports: []int{80}, + }, + expected: "80", + }, + { + task: marathon.Task{ + Ports: []int{80, 443}, + }, + expected: "80", + }, + } + + for _, c := range cases { + actual := provider.getPort(c.task) + if actual != c.expected { + t.Fatalf("expected %q, got %q", c.expected, actual) + } + } +} + +func TestMarathonGetWeigh(t *testing.T) { + provider := &Marathon{} + + applications := []struct { + applications []marathon.Application + task marathon.Task + expected string + }{ + { + applications: []marathon.Application{}, + task: marathon.Task{}, + expected: "0", + }, + { + applications: []marathon.Application{ + { + ID: "test1", + Labels: map[string]string{ + "traefik.weight": "10", + }, + }, + }, + task: marathon.Task{ + AppID: "test2", + }, + expected: "0", + }, + { + applications: []marathon.Application{ + { + ID: "test", + Labels: map[string]string{ + "traefik.test": "10", + }, + }, + }, + task: marathon.Task{ + AppID: "test", + }, + expected: "0", + }, + { + applications: []marathon.Application{ + { + ID: "test", + Labels: map[string]string{ + "traefik.weight": "10", + }, + }, + }, + task: marathon.Task{ + AppID: "test", + }, + expected: "10", + }, + } + + for _, a := range applications { + actual := provider.getWeight(a.task, a.applications) + if actual != a.expected { + t.Fatalf("expected %q, got %q", a.expected, actual) + } + } +} + +func TestMarathonGetDomain(t *testing.T) { + provider := &Marathon{ + Domain: "docker.localhost", + } + + applications := []struct { + application marathon.Application + expected string + }{ + { + application: marathon.Application{}, + expected: "docker.localhost", + }, + { + application: marathon.Application{ + Labels: map[string]string{ + "traefik.domain": "foo.bar", + }, + }, + expected: "foo.bar", + }, + } + + for _, a := range applications { + actual := provider.getDomain(a.application) + if actual != a.expected { + t.Fatalf("expected %q, got %q", a.expected, actual) + } + } +} + +func TestMarathonGetProtocol(t *testing.T) { + provider := &Marathon{} + + applications := []struct { + applications []marathon.Application + task marathon.Task + expected string + }{ + { + applications: []marathon.Application{}, + task: marathon.Task{}, + expected: "http", + }, + { + applications: []marathon.Application{ + { + ID: "test1", + Labels: map[string]string{ + "traefik.protocol": "https", + }, + }, + }, + task: marathon.Task{ + AppID: "test2", + }, + expected: "http", + }, + { + applications: []marathon.Application{ + { + ID: "test", + Labels: map[string]string{ + "traefik.foo": "bar", + }, + }, + }, + task: marathon.Task{ + AppID: "test", + }, + expected: "http", + }, + { + applications: []marathon.Application{ + { + ID: "test", + Labels: map[string]string{ + "traefik.protocol": "https", + }, + }, + }, + task: marathon.Task{ + AppID: "test", + }, + expected: "https", + }, + } + + for _, a := range applications { + actual := provider.getProtocol(a.task, a.applications) + if actual != a.expected { + t.Fatalf("expected %q, got %q", a.expected, actual) + } + } +} + +func TestMarathonGetPassHostHeader(t *testing.T) { + provider := &Marathon{} + + applications := []struct { + application marathon.Application + expected string + }{ + { + application: marathon.Application{}, + expected: "false", + }, + { + application: marathon.Application{ + Labels: map[string]string{ + "traefik.frontend.passHostHeader": "true", + }, + }, + expected: "true", + }, + } + + for _, a := range applications { + actual := provider.getPassHostHeader(a.application) + if actual != a.expected { + t.Fatalf("expected %q, got %q", a.expected, actual) + } + } +} + +func TestMarathonGetFrontendValue(t *testing.T) { + provider := &Marathon{ + Domain: "docker.localhost", + } + + applications := []struct { + application marathon.Application + expected string + }{ + { + application: marathon.Application{}, + expected: ".docker.localhost", + }, + { + application: marathon.Application{ + ID: "test", + }, + expected: "test.docker.localhost", + }, + { + application: marathon.Application{ + Labels: map[string]string{ + "traefik.frontend.value": "foo.bar", + }, + }, + expected: "foo.bar", + }, + } + + for _, a := range applications { + actual := provider.getFrontendValue(a.application) + if actual != a.expected { + t.Fatalf("expected %q, got %q", a.expected, actual) + } + } +} + +func TestMarathonGetFrontendRule(t *testing.T) { + provider := &Marathon{} + + applications := []struct { + application marathon.Application + expected string + }{ + { + application: marathon.Application{}, + expected: "Host", + }, + { + application: marathon.Application{ + Labels: map[string]string{ + "traefik.frontend.rule": "Header", + }, + }, + expected: "Header", + }, + } + + for _, a := range applications { + actual := provider.getFrontendRule(a.application) + if actual != a.expected { + t.Fatalf("expected %q, got %q", a.expected, actual) + } + } +} diff --git a/provider/provider.go b/provider/provider.go index f1001a953..d2ceb5ebc 100644 --- a/provider/provider.go +++ b/provider/provider.go @@ -1,6 +1,15 @@ package provider -import "github.com/emilevauge/traefik/types" +import ( + "bytes" + "io/ioutil" + "strings" + "text/template" + + "github.com/BurntSushi/toml" + "github.com/emilevauge/traefik/autogen" + "github.com/emilevauge/traefik/types" +) // Provider defines methods of a provider. type Provider interface { @@ -8,3 +17,51 @@ type Provider interface { // using the given configuration channel. Provide(configurationChan chan<- types.ConfigMessage) error } + +type baseProvider struct { + Watch bool + Filename string +} + +func (p *baseProvider) getConfiguration(defaultTemplateFile string, funcMap template.FuncMap, templateObjects interface{}) (*types.Configuration, error) { + var ( + buf []byte + err error + ) + configuration := new(types.Configuration) + tmpl := template.New(p.Filename).Funcs(funcMap) + if len(p.Filename) > 0 { + buf, err = ioutil.ReadFile(p.Filename) + if err != nil { + return nil, err + } + } else { + buf, err = autogen.Asset(defaultTemplateFile) + if err != nil { + return nil, err + } + } + _, err = tmpl.Parse(string(buf)) + if err != nil { + return nil, err + } + + var buffer bytes.Buffer + err = tmpl.Execute(&buffer, templateObjects) + if err != nil { + return nil, err + } + + if _, err := toml.Decode(buffer.String(), configuration); err != nil { + return nil, err + } + return configuration, nil +} + +func replace(s1 string, s2 string, s3 string) string { + return strings.Replace(s3, s1, s2, -1) +} + +func getEscapedName(name string) string { + return strings.Replace(name, "/", "", -1) +} diff --git a/provider/provider_test.go b/provider/provider_test.go new file mode 100644 index 000000000..d9578fee2 --- /dev/null +++ b/provider/provider_test.go @@ -0,0 +1,170 @@ +package provider + +import ( + "io/ioutil" + "os" + "strings" + "testing" + "text/template" +) + +type myProvider struct { + baseProvider +} + +func (p *myProvider) Foo() string { + return "bar" +} + +func TestConfigurationErrors(t *testing.T) { + templateErrorFile, err := ioutil.TempFile("", "provider-configuration-error") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(templateErrorFile.Name()) + data := []byte("Not a valid template {{ Bar }}") + err = ioutil.WriteFile(templateErrorFile.Name(), data, 0700) + if err != nil { + t.Fatal(err) + } + + templateInvalidTOMLFile, err := ioutil.TempFile("", "provider-configuration-error") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(templateInvalidTOMLFile.Name()) + data = []byte(`Hello {{ .Name }} +{{ Foo }}`) + err = ioutil.WriteFile(templateInvalidTOMLFile.Name(), data, 0700) + if err != nil { + t.Fatal(err) + } + + invalids := []struct { + provider *myProvider + defaultTemplate string + expectedError string + funcMap template.FuncMap + templateObjects interface{} + }{ + { + provider: &myProvider{ + baseProvider{ + Filename: "/non/existent/template.tmpl", + }, + }, + expectedError: "open /non/existent/template.tmpl: no such file or directory", + }, + { + provider: &myProvider{}, + defaultTemplate: "non/existent/template.tmpl", + expectedError: "Asset non/existent/template.tmpl not found", + }, + { + provider: &myProvider{ + baseProvider{ + Filename: templateErrorFile.Name(), + }, + }, + expectedError: `function "Bar" not defined`, + }, + { + provider: &myProvider{ + baseProvider{ + Filename: templateInvalidTOMLFile.Name(), + }, + }, + expectedError: "Near line 1, key 'Hello': Near line 1: Expected key separator '=', but got '<' instead", + funcMap: template.FuncMap{ + "Foo": func() string { + return "bar" + }, + }, + templateObjects: struct{ Name string }{Name: "bar"}, + }, + } + + for _, invalid := range invalids { + configuration, err := invalid.provider.getConfiguration(invalid.defaultTemplate, invalid.funcMap, nil) + if err == nil || !strings.Contains(err.Error(), invalid.expectedError) { + t.Fatalf("should have generate an error with %q, got %v", invalid.expectedError, err) + } + if configuration != nil { + t.Fatalf("shouldn't have return a configuration object : %v", configuration) + } + } +} + +func TestGetConfiguration(t *testing.T) { + templateFile, err := ioutil.TempFile("", "provider-configuration") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(templateFile.Name()) + data := []byte(`[backends] + [backends.backend1] + [backends.backend1.circuitbreaker] + expression = "NetworkErrorRatio() > 0.5" + [backends.backend1.servers.server1] + url = "http://172.17.0.2:80" + weight = 10 + [backends.backend1.servers.server2] + url = "http://172.17.0.3:80" + weight = 1 + +[frontends] + [frontends.frontend1] + backend = "backend1" + passHostHeader = true + [frontends.frontend11.routes.test_2] + rule = "Path" + value = "/test"`) + err = ioutil.WriteFile(templateFile.Name(), data, 0700) + if err != nil { + t.Fatal(err) + } + + provider := &myProvider{ + baseProvider{ + Filename: templateFile.Name(), + }, + } + configuration, err := provider.getConfiguration(templateFile.Name(), nil, nil) + if err != nil { + t.Fatalf("Shouldn't have error out, got %v", err) + } + if configuration == nil { + t.Fatalf("Configuration should not be nil, but was") + } +} + +func TestReplace(t *testing.T) { + cases := []struct { + str string + expected string + }{ + { + str: "", + expected: "", + }, + { + str: "foo", + expected: "bar", + }, + { + str: "foo foo", + expected: "bar bar", + }, + { + str: "somethingfoo", + expected: "somethingbar", + }, + } + + for _, c := range cases { + actual := replace("foo", "bar", c.str) + if actual != c.expected { + t.Fatalf("expected %q, got %q, for %q", c.expected, actual, c.str) + } + } +} diff --git a/provider/zk.go b/provider/zk.go index 1b5079526..2f379aa7a 100644 --- a/provider/zk.go +++ b/provider/zk.go @@ -1,19 +1,20 @@ package provider -import "github.com/emilevauge/traefik/types" +import ( + "github.com/docker/libkv/store" + "github.com/docker/libkv/store/zookeeper" + "github.com/emilevauge/traefik/types" +) // Zookepper holds configurations of the Zookepper provider. type Zookepper struct { - Watch bool - Endpoint string - Prefix string - Filename string - KvProvider *Kv + Kv } // Provide allows the provider to provide configurations to traefik // using the given configuration channel. func (provider *Zookepper) Provide(configurationChan chan<- types.ConfigMessage) error { - provider.KvProvider = NewZkProvider(provider) - return provider.KvProvider.provide(configurationChan) + provider.StoreType = store.ZK + zookeeper.Register() + return provider.provide(configurationChan) }