package forwardedheaders

import (
	"crypto/tls"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestServeHTTP builds a request from each table entry, runs it through the
// handler returned by NewXForwarded, and then checks the request headers
// listed in expectedHeaders.
//
// Table fields:
//   - desc: subtest name.
//   - insecure, trustedIps: passed as the first two arguments of NewXForwarded.
//   - incomingHeaders: values added to the request headers before serving.
//   - remoteAddr: assigned to req.RemoteAddr.
//   - expectedHeaders: header name -> value expected from req.Header.Get
//     once ServeHTTP has returned.
//   - tls: when true, req.TLS is set to an empty tls.ConnectionState.
//   - websocket: when true, the connection/upgrade headers are set to
//     "upgrade"/"websocket".
//   - host: when non-empty, used as both req.Host and the handler's hostname.
func TestServeHTTP(t *testing.T) {
	testCases := []struct {
		desc            string
		insecure        bool
		trustedIps      []string
		incomingHeaders map[string][]string
		remoteAddr      string
		expectedHeaders map[string]string
		tls             bool
		websocket       bool
		host            string
	}{
		{
			desc:            "all Empty",
			insecure:        true,
			trustedIps:      nil,
			remoteAddr:      "",
			incomingHeaders: map[string][]string{},
			expectedHeaders: map[string]string{
				xForwardedFor:               "",
				xForwardedURI:               "",
				xForwardedMethod:            "",
				xForwardedTLSClientCert:     "",
				xForwardedTLSClientCertInfo: "",
			},
		},
		{
			desc:       "insecure true with incoming X-Forwarded headers",
			insecure:   true,
			trustedIps: nil,
			remoteAddr: "",
			incomingHeaders: map[string][]string{
				xForwardedFor:               {"10.0.1.0, 10.0.1.12"},
				xForwardedURI:               {"/bar"},
				xForwardedMethod:            {"GET"},
				xForwardedTLSClientCert:     {"Cert"},
				xForwardedTLSClientCertInfo: {"CertInfo"},
			},
			expectedHeaders: map[string]string{
				xForwardedFor:               "10.0.1.0, 10.0.1.12",
				xForwardedURI:               "/bar",
				xForwardedMethod:            "GET",
				xForwardedTLSClientCert:     "Cert",
				xForwardedTLSClientCertInfo: "CertInfo",
			},
		},
		{
			desc:       "insecure false with incoming X-Forwarded headers",
			insecure:   false,
			trustedIps: nil,
			remoteAddr: "",
			incomingHeaders: map[string][]string{
				xForwardedFor:               {"10.0.1.0, 10.0.1.12"},
				xForwardedURI:               {"/bar"},
				xForwardedMethod:            {"GET"},
				xForwardedTLSClientCert:     {"Cert"},
				xForwardedTLSClientCertInfo: {"CertInfo"},
			},
			expectedHeaders: map[string]string{
				xForwardedFor:               "",
				xForwardedURI:               "",
				xForwardedMethod:            "",
				xForwardedTLSClientCert:     "",
				xForwardedTLSClientCertInfo: "",
			},
		},
		{
			desc:       "insecure false with incoming X-Forwarded headers and valid Trusted Ips",
			insecure:   false,
			trustedIps: []string{"10.0.1.100"},
			remoteAddr: "10.0.1.100:80",
			incomingHeaders: map[string][]string{
				xForwardedFor:               {"10.0.1.0, 10.0.1.12"},
				xForwardedURI:               {"/bar"},
				xForwardedMethod:            {"GET"},
				xForwardedTLSClientCert:     {"Cert"},
				xForwardedTLSClientCertInfo: {"CertInfo"},
			},
			expectedHeaders: map[string]string{
				xForwardedFor:               "10.0.1.0, 10.0.1.12",
				xForwardedURI:               "/bar",
				xForwardedMethod:            "GET",
				xForwardedTLSClientCert:     "Cert",
				xForwardedTLSClientCertInfo: "CertInfo",
			},
		},
		{
			desc:       "insecure false with incoming X-Forwarded headers and invalid Trusted Ips",
			insecure:   false,
			trustedIps: []string{"10.0.1.100"},
			remoteAddr: "10.0.1.101:80",
			incomingHeaders: map[string][]string{
				xForwardedFor:               {"10.0.1.0, 10.0.1.12"},
				xForwardedURI:               {"/bar"},
				xForwardedMethod:            {"GET"},
				xForwardedTLSClientCert:     {"Cert"},
				xForwardedTLSClientCertInfo: {"CertInfo"},
			},
			expectedHeaders: map[string]string{
				xForwardedFor:               "",
				xForwardedURI:               "",
				xForwardedMethod:            "",
				xForwardedTLSClientCert:     "",
				xForwardedTLSClientCertInfo: "",
			},
		},
		{
			desc:       "insecure false with incoming X-Forwarded headers and valid Trusted Ips CIDR",
			insecure:   false,
			trustedIps: []string{"1.2.3.4/24"},
			remoteAddr: "1.2.3.156:80",
			incomingHeaders: map[string][]string{
				xForwardedFor:               {"10.0.1.0, 10.0.1.12"},
				xForwardedURI:               {"/bar"},
				xForwardedMethod:            {"GET"},
				xForwardedTLSClientCert:     {"Cert"},
				xForwardedTLSClientCertInfo: {"CertInfo"},
			},
			expectedHeaders: map[string]string{
				xForwardedFor:               "10.0.1.0, 10.0.1.12",
				xForwardedURI:               "/bar",
				xForwardedMethod:            "GET",
				xForwardedTLSClientCert:     "Cert",
				xForwardedTLSClientCertInfo: "CertInfo",
			},
		},
		{
			desc:       "insecure false with incoming X-Forwarded headers and invalid Trusted Ips CIDR",
			insecure:   false,
			trustedIps: []string{"1.2.3.4/24"},
			remoteAddr: "10.0.1.101:80",
			incomingHeaders: map[string][]string{
				xForwardedFor:               {"10.0.1.0, 10.0.1.12"},
				xForwardedURI:               {"/bar"},
				xForwardedMethod:            {"GET"},
				xForwardedTLSClientCert:     {"Cert"},
				xForwardedTLSClientCertInfo: {"CertInfo"},
			},
			expectedHeaders: map[string]string{
				xForwardedFor:               "",
				xForwardedURI:               "",
				xForwardedMethod:            "",
				xForwardedTLSClientCert:     "",
				xForwardedTLSClientCertInfo: "",
			},
		},
		{
			desc:     "xForwardedFor with multiple header(s) values",
			insecure: true,
			incomingHeaders: map[string][]string{
				xForwardedFor: {
					"10.0.0.4, 10.0.0.3",
					"10.0.0.2, 10.0.0.1",
					"10.0.0.0",
				},
			},
			expectedHeaders: map[string]string{
				xForwardedFor: "10.0.0.4, 10.0.0.3, 10.0.0.2, 10.0.0.1, 10.0.0.0",
			},
		},
		{
			desc:       "xRealIP populated from remote address",
			remoteAddr: "10.0.1.101:80",
			expectedHeaders: map[string]string{
				xRealIP: "10.0.1.101",
			},
		},
		{
			desc:       "xRealIP was already populated from previous headers",
			insecure:   true,
			remoteAddr: "10.0.1.101:80",
			incomingHeaders: map[string][]string{
				xRealIP: {"10.0.1.12"},
			},
			expectedHeaders: map[string]string{
				xRealIP: "10.0.1.12",
			},
		},
		{
			desc: "xForwardedProto with no tls",
			tls:  false,
			expectedHeaders: map[string]string{
				xForwardedProto: "http",
			},
		},
		{
			desc: "xForwardedProto with tls",
			tls:  true,
			expectedHeaders: map[string]string{
				xForwardedProto: "https",
			},
		},
		{
			desc:      "xForwardedProto with websocket",
			tls:       false,
			websocket: true,
			expectedHeaders: map[string]string{
				xForwardedProto: "ws",
			},
		},
		{
			desc:      "xForwardedProto with websocket and tls",
			tls:       true,
			websocket: true,
			expectedHeaders: map[string]string{
				xForwardedProto: "wss",
			},
		},
		{
			desc:      "xForwardedProto with websocket and tls and already x-forwarded-proto with wss",
			tls:       true,
			websocket: true,
			incomingHeaders: map[string][]string{
				xForwardedProto: {"wss"},
			},
			expectedHeaders: map[string]string{
				xForwardedProto: "wss",
			},
		},
		{
			desc: "xForwardedPort with explicit port",
			host: "foo.com:8080",
			expectedHeaders: map[string]string{
				xForwardedPort: "8080",
			},
		},
		{
			desc: "xForwardedPort with implicit tls port from proto header",
			// setting insecure just so our initial xForwardedProto does not get cleaned
			insecure: true,
			incomingHeaders: map[string][]string{
				xForwardedProto: {"https"},
			},
			expectedHeaders: map[string]string{
				xForwardedProto: "https",
				xForwardedPort:  "443",
			},
		},
		{
			desc: "xForwardedPort with implicit tls port from TLS in req",
			tls:  true,
			expectedHeaders: map[string]string{
				xForwardedPort: "443",
			},
		},
		{
			desc: "xForwardedHost from req host",
			host: "foo.com:8080",
			expectedHeaders: map[string]string{
				xForwardedHost: "foo.com:8080",
			},
		},
		{
			desc: "xForwardedServer from req XForwarded",
			host: "foo.com:8080",
			expectedHeaders: map[string]string{
				xForwardedServer: "foo.com:8080",
			},
		},
	}

	for _, test := range testCases {
		// Per-iteration copy so the parallel subtest closure does not share
		// the loop variable (required before Go 1.22).
		test := test
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			req, err := http.NewRequest(http.MethodGet, "", nil)
			require.NoError(t, err)

			req.RemoteAddr = test.remoteAddr

			if test.tls {
				// A non-nil TLS state is all that is needed to mark the request as TLS.
				req.TLS = &tls.ConnectionState{}
			}

			if test.websocket {
				req.Header.Set(connection, "upgrade")
				req.Header.Set(upgrade, "websocket")
			}

			if test.host != "" {
				req.Host = test.host
			}

			// Add (not Set) so that several values for one header name are
			// all kept on the request.
			for k, values := range test.incomingHeaders {
				for _, value := range values {
					req.Header.Add(k, value)
				}
			}

			m, err := NewXForwarded(test.insecure, test.trustedIps,
				http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
			require.NoError(t, err)

			if test.host != "" {
				m.hostname = test.host
			}

			// Use a real recorder rather than a nil ResponseWriter so the test
			// cannot panic if anything in the chain ever writes a response.
			m.ServeHTTP(httptest.NewRecorder(), req)

			for k, v := range test.expectedHeaders {
				// Name the header in the failure message: several expected
				// values are identical, so the values alone are ambiguous.
				assert.Equalf(t, v, req.Header.Get(k), "unexpected value for header %q", k)
			}
		})
	}
}

// Test_isWebsocketRequest feeds isWebsocketRequest requests whose connection
// and upgrade headers carry the values from each table entry, and checks the
// boolean result with the entry's assertion function.
func Test_isWebsocketRequest(t *testing.T) {
	cases := []struct {
		desc      string
		connValue string                    // value set for the connection header
		upgValue  string                    // value set for the upgrade header
		verify    assert.BoolAssertionFunc // assert.True or assert.False
	}{
		{
			desc:      "connection Header multiple values middle",
			connValue: "foo,upgrade,bar",
			upgValue:  "websocket",
			verify:    assert.True,
		},
		{
			desc:      "connection Header multiple values end",
			connValue: "foo,bar,upgrade",
			upgValue:  "websocket",
			verify:    assert.True,
		},
		{
			desc:      "connection Header multiple values begin",
			connValue: "upgrade,foo,bar",
			upgValue:  "websocket",
			verify:    assert.True,
		},
		{
			desc:      "connection Header no upgrade",
			connValue: "foo,bar",
			upgValue:  "websocket",
			verify:    assert.False,
		},
		{
			desc:      "connection Header empty",
			connValue: "",
			upgValue:  "websocket",
			verify:    assert.False,
		},
		{
			desc:      "no header values",
			connValue: "foo,bar",
			upgValue:  "foo,bar",
			verify:    assert.False,
		},
		{
			desc:      "upgrade header multiple values",
			connValue: "upgrade",
			upgValue:  "foo,bar,websocket",
			verify:    assert.True,
		},
	}

	for _, tc := range cases {
		// Shadow the loop variable for the parallel closure below.
		tc := tc
		t.Run(tc.desc, func(t *testing.T) {
			t.Parallel()

			req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
			req.Header.Set(connection, tc.connValue)
			req.Header.Set(upgrade, tc.upgValue)

			tc.verify(t, isWebsocketRequest(req))
		})
	}
}