traefik/middlewares/redirect/redirect_test.go
2018-02-08 09:30:06 +01:00

182 lines
5 KiB
Go

package redirect
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/tls"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewEntryPointHandler(t *testing.T) {
testCases := []struct {
desc string
entryPoint *configuration.EntryPoint
permanent bool
url string
expectedURL string
expectedStatus int
errorExpected bool
}{
{
desc: "HTTP to HTTPS",
entryPoint: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}},
url: "http://foo:80",
expectedURL: "https://foo:443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS to HTTP",
entryPoint: &configuration.EntryPoint{Address: ":80"},
url: "https://foo:443",
expectedURL: "http://foo:80",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTP",
entryPoint: &configuration.EntryPoint{Address: ":88"},
url: "http://foo:80",
expectedURL: "http://foo:88",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTPS permanent",
entryPoint: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}},
permanent: true,
url: "http://foo:80",
expectedURL: "https://foo:443",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "HTTPS to HTTP permanent",
entryPoint: &configuration.EntryPoint{Address: ":80"},
permanent: true,
url: "https://foo:443",
expectedURL: "http://foo:80",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "invalid address",
entryPoint: &configuration.EntryPoint{Address: ":foo", TLS: &tls.TLS{}},
url: "http://foo:80",
errorExpected: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
handler, err := NewEntryPointHandler(test.entryPoint, test.permanent)
if test.errorExpected {
require.Error(t, err)
} else {
require.NoError(t, err)
recorder := httptest.NewRecorder()
r := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
handler.ServeHTTP(recorder, r, nil)
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
assert.Equal(t, test.expectedStatus, recorder.Code)
}
})
}
}
func TestNewRegexHandler(t *testing.T) {
testCases := []struct {
desc string
regex string
replacement string
permanent bool
url string
expectedURL string
expectedStatus int
errorExpected bool
}{
{
desc: "simple redirection",
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
replacement: "https://${1}bar$2:443$4",
url: "http://foo.com:80",
expectedURL: "https://foobar.com:443",
expectedStatus: http.StatusFound,
},
{
desc: "use request header",
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
replacement: `https://${1}{{ .Request.Header.Get "X-Foo" }}$2:443$4`,
url: "http://foo.com:80",
expectedURL: "https://foobar.com:443",
expectedStatus: http.StatusFound,
},
{
desc: "URL doesn't match regex",
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
replacement: "https://${1}bar$2:443$4",
url: "http://bar.com:80",
expectedStatus: http.StatusOK,
},
{
desc: "invalid rewritten URL",
regex: `^(.*)$`,
replacement: "http://192.168.0.%31/",
url: "http://foo.com:80",
expectedStatus: http.StatusBadGateway,
},
{
desc: "invalid regex",
regex: `^(.*`,
replacement: "$1",
url: "http://foo.com:80",
errorExpected: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
handler, err := NewRegexHandler(test.regex, test.replacement, test.permanent)
if test.errorExpected {
require.Nil(t, handler)
require.Error(t, err)
} else {
require.NotNil(t, handler)
require.NoError(t, err)
recorder := httptest.NewRecorder()
r := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
r.Header.Set("X-Foo", "bar")
next := func(rw http.ResponseWriter, req *http.Request) {}
handler.ServeHTTP(recorder, r, next)
if test.expectedStatus == http.StatusMovedPermanently || test.expectedStatus == http.StatusFound {
assert.Equal(t, test.expectedStatus, recorder.Code)
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
} else {
assert.Equal(t, test.expectedStatus, recorder.Code)
location, err := recorder.Result().Location()
require.Error(t, err, "ghf %v", location)
}
}
})
}
}