package metrics import ( "net/http" "net/http/httptest" "reflect" "strings" "testing" "github.com/go-kit/kit/metrics" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" ) // CollectingCounter is a metrics.Counter implementation that enables access to the CounterValue and LastLabelValues. type CollectingCounter struct { CounterValue float64 LastLabelValues []string } // With is there to satisfy the metrics.Counter interface. func (c *CollectingCounter) With(labelValues ...string) metrics.Counter { c.LastLabelValues = labelValues return c } // Add is there to satisfy the metrics.Counter interface. func (c *CollectingCounter) Add(delta float64) { c.CounterValue += delta } func TestMetricsRetryListener(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) retryMetrics := newCollectingRetryMetrics() retryListener := NewRetryListener(retryMetrics, "serviceName") retryListener.Retried(req, 1) retryListener.Retried(req, 2) wantCounterValue := float64(2) if retryMetrics.retriesCounter.CounterValue != wantCounterValue { t.Errorf("got counter value of %f, want %f", retryMetrics.retriesCounter.CounterValue, wantCounterValue) } wantLabelValues := []string{"service", "serviceName"} if !reflect.DeepEqual(retryMetrics.retriesCounter.LastLabelValues, wantLabelValues) { t.Errorf("wrong label values %v used, want %v", retryMetrics.retriesCounter.LastLabelValues, wantLabelValues) } } // collectingRetryMetrics is an implementation of the retryMetrics interface that can be used inside tests to collect the times Add() was called. type collectingRetryMetrics struct { retriesCounter *CollectingCounter } func newCollectingRetryMetrics() *collectingRetryMetrics { return &collectingRetryMetrics{retriesCounter: &CollectingCounter{}} } func (m *collectingRetryMetrics) ServiceRetriesCounter() metrics.Counter { return m.retriesCounter } type rwWithCloseNotify struct { *httptest.ResponseRecorder } func (r *rwWithCloseNotify) CloseNotify() <-chan bool { panic("implement me") } func TestCloseNotifier(t *testing.T) { testCases := []struct { rw http.ResponseWriter desc string implementsCloseNotifier bool }{ { rw: httptest.NewRecorder(), desc: "does not implement CloseNotifier", implementsCloseNotifier: false, }, { rw: &rwWithCloseNotify{httptest.NewRecorder()}, desc: "implements CloseNotifier", implementsCloseNotifier: true, }, } for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() _, ok := test.rw.(http.CloseNotifier) assert.Equal(t, test.implementsCloseNotifier, ok) rw := newResponseRecorder(test.rw) _, impl := rw.(http.CloseNotifier) assert.Equal(t, test.implementsCloseNotifier, impl) }) } } func Test_getMethod(t *testing.T) { testCases := []struct { method string expected string }{ { method: http.MethodGet, expected: http.MethodGet, }, { method: strings.ToLower(http.MethodGet), expected: "EXTENSION_METHOD", }, { method: "THIS_IS_NOT_A_VALID_METHOD", expected: "EXTENSION_METHOD", }, } for _, test := range testCases { test := test t.Run(test.method, func(t *testing.T) { t.Parallel() request := httptest.NewRequest(test.method, "http://example.com", nil) assert.Equal(t, test.expected, getMethod(request)) }) } } func Test_getRequestProtocol(t *testing.T) { testCases := []struct { desc string headers http.Header expected string }{ { desc: "default", expected: protoHTTP, }, { desc: "websocket", headers: http.Header{ "Connection": []string{"upgrade"}, "Upgrade": []string{"websocket"}, }, expected: protoWebsocket, }, { desc: "SSE", headers: http.Header{ "Accept": []string{"text/event-stream"}, }, expected: protoSSE, }, { desc: "grpc web", headers: http.Header{ "Content-Type": []string{"application/grpc-web-text"}, }, expected: protoGRPCWeb, }, { desc: "grpc", headers: http.Header{ "Content-Type": []string{"application/grpc-text"}, }, expected: protoGRPC, }, } for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() req := httptest.NewRequest(http.MethodGet, "https://localhost", http.NoBody) req.Header = test.headers protocol := getRequestProtocol(req) assert.Equal(t, test.expected, protocol) }) } } func Test_grpcStatusCode(t *testing.T) { testCases := []struct { desc string status string expected codes.Code }{ { desc: "invalid number", status: "foo", expected: codes.Unknown, }, { desc: "number", status: "1", expected: codes.Canceled, }, { desc: "invalid string", status: `"foo"`, expected: codes.Unknown, }, { desc: "string", status: `"OK"`, expected: codes.OK, }, } for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { t.Parallel() rw := httptest.NewRecorder() rw.Header().Set("Grpc-Status", test.status) code := grpcStatusCode(rw) assert.EqualValues(t, test.expected, code) }) } }