From 10acbb8d921b3b8cd13c3dcd0ad86a5e40730737 Mon Sep 17 00:00:00 2001 From: Ludovic Fernandez Date: Tue, 3 Sep 2019 15:22:05 +0200 Subject: [PATCH] Don't panic with undefined middleware --- pkg/responsemodifiers/response_modifier.go | 5 +++++ .../response_modifier_test.go | 21 +++++++++++++------ pkg/server/middleware/middlewares.go | 6 +++++- pkg/server/middleware/middlewares_test.go | 1 - 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/pkg/responsemodifiers/response_modifier.go b/pkg/responsemodifiers/response_modifier.go index 507627d69..97c83e895 100644 --- a/pkg/responsemodifiers/response_modifier.go +++ b/pkg/responsemodifiers/response_modifier.go @@ -23,6 +23,11 @@ func (f *Builder) Build(ctx context.Context, names []string) func(*http.Response for _, middleName := range names { if conf, ok := f.configs[middleName]; ok { + if conf == nil || conf.Middleware == nil { + getLogger(ctx, middleName, "undefined").Error("Invalid Middleware configuration (ResponseModifier)") + continue + } + if conf.Headers != nil { getLogger(ctx, middleName, "Headers").Debug("Creating Middleware (ResponseModifier)") diff --git a/pkg/responsemodifiers/response_modifier_test.go b/pkg/responsemodifiers/response_modifier_test.go index 79c1f1132..9b3d7806b 100644 --- a/pkg/responsemodifiers/response_modifier_test.go +++ b/pkg/responsemodifiers/response_modifier_test.go @@ -47,7 +47,7 @@ func TestBuilderBuild(t *testing.T) { assertResponse: func(t *testing.T, resp *http.Response) { t.Helper() - assert.Equal(t, resp.Header.Get("X-Foo"), "foo") + assert.Equal(t, "foo", resp.Header.Get("X-Foo")) }, }, { @@ -85,7 +85,7 @@ func TestBuilderBuild(t *testing.T) { assertResponse: func(t *testing.T, resp *http.Response) { t.Helper() - assert.Equal(t, resp.Header.Get("Referrer-Policy"), "no-referrer") + assert.Equal(t, "no-referrer", resp.Header.Get("Referrer-Policy")) }, }, { @@ -107,8 +107,8 @@ func TestBuilderBuild(t *testing.T) { assertResponse: func(t *testing.T, resp *http.Response) { t.Helper() - assert.Equal(t, resp.Header.Get("X-Foo"), "foo") - assert.Equal(t, resp.Header.Get("X-Bar"), "bar") + assert.Equal(t, "foo", resp.Header.Get("X-Foo")) + assert.Equal(t, "bar", resp.Header.Get("X-Bar")) }, }, { @@ -130,7 +130,7 @@ func TestBuilderBuild(t *testing.T) { assertResponse: func(t *testing.T, resp *http.Response) { t.Helper() - assert.Equal(t, resp.Header.Get("X-Foo"), "foo") + assert.Equal(t, "foo", resp.Header.Get("X-Foo")) }, }, { @@ -157,9 +157,18 @@ func TestBuilderBuild(t *testing.T) { assertResponse: func(t *testing.T, resp *http.Response) { t.Helper() - assert.Equal(t, resp.Header.Get("X-Foo"), "foo") + assert.Equal(t, "foo", resp.Header.Get("X-Foo")) }, }, + { + desc: "nil middleware", + middlewares: []string{"foo"}, + buildResponse: stubResponse, + conf: map[string]*dynamic.Middleware{ + "foo": nil, + }, + assertResponse: func(t *testing.T, resp *http.Response) {}, + }, } for _, test := range testCases { diff --git a/pkg/server/middleware/middlewares.go b/pkg/server/middleware/middlewares.go index 3e5d2c154..e93dee542 100644 --- a/pkg/server/middleware/middlewares.go +++ b/pkg/server/middleware/middlewares.go @@ -102,6 +102,10 @@ func checkRecursion(ctx context.Context, middlewareName string) (context.Context // it is the responsibility of the caller to make sure that b.configs[middlewareName].Middleware exists func (b *Builder) buildConstructor(ctx context.Context, middlewareName string) (alice.Constructor, error) { config := b.configs[middlewareName] + if config == nil || config.Middleware == nil { + return nil, fmt.Errorf("invalid middleware %q configuration", middlewareName) + } + var middleware alice.Constructor badConf := errors.New("cannot create middleware: multi-types middleware not supported, consider declaring two different pieces of middleware instead") @@ -314,7 +318,7 @@ func (b *Builder) buildConstructor(ctx context.Context, middlewareName string) ( } if middleware == nil { - return nil, errors.New("middleware does not exist") + return nil, fmt.Errorf("middleware %q does not exist", middlewareName) } return tracing.Wrap(ctx, middleware), nil diff --git a/pkg/server/middleware/middlewares_test.go b/pkg/server/middleware/middlewares_test.go index b747390c5..52a89484b 100644 --- a/pkg/server/middleware/middlewares_test.go +++ b/pkg/server/middleware/middlewares_test.go @@ -369,7 +369,6 @@ func TestBuilder_buildConstructor(t *testing.T) { t.Parallel() constructor, err := middlewaresBuilder.buildConstructor(context.Background(), test.middlewareID) - require.NoError(t, err) middleware, err2 := constructor(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))