package connectionheader

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
)

// TestRemover checks the headers that the Remover middleware forwards to the
// next handler. Per the cases below:
//   - the Connection header and every header it names are stripped;
//   - an Upgrade header listed in Connection is kept, and Connection is reduced
//     to "Upgrade";
//   - headers not named in Connection are kept.
//
// The assertions are made on the request the downstream handler receives, and
// the test also verifies that the downstream handler is invoked at all.
func TestRemover(t *testing.T) {
	testCases := []struct {
		desc       string
		reqHeaders map[string]string
		expected   http.Header
	}{
		{
			desc: "simple remove",
			reqHeaders: map[string]string{
				"Foo":            "bar",
				connectionHeader: "foo",
			},
			expected: http.Header{},
		},
		{
			desc: "remove and Upgrade",
			reqHeaders: map[string]string{
				upgradeHeader:    "test",
				"Foo":            "bar",
				connectionHeader: "Upgrade,foo",
			},
			expected: http.Header{
				upgradeHeader:    []string{"test"},
				connectionHeader: []string{"Upgrade"},
			},
		},
		{
			desc: "no remove",
			reqHeaders: map[string]string{
				"Foo":            "bar",
				connectionHeader: "fii",
			},
			expected: http.Header{
				"Foo": []string{"bar"},
			},
		},
	}

	for _, test := range testCases {
		test := test // capture the range variable for the parallel subtest (pre-Go 1.22 semantics)
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			// Record what the downstream handler actually receives. This way the
			// test checks the forwarded request instead of relying on Remover
			// mutating the caller's *http.Request in place. It also catches a
			// Remover that never calls next.
			var (
				called    bool
				forwarded http.Header
			)
			next := http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
				called = true
				forwarded = req.Header
			})

			h := Remover(next)

			req := httptest.NewRequest(http.MethodGet, "https://localhost", nil)
			for k, v := range test.reqHeaders {
				req.Header.Set(k, v)
			}

			h.ServeHTTP(httptest.NewRecorder(), req)

			assert.True(t, called, "next handler was not called")
			assert.Equal(t, test.expected, forwarded)
		})
	}
}