diff --git a/openai/openai_test.go b/openai/openai_test.go index 4d21382c..39e8dc58 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -3,7 +3,6 @@ package openai import ( "bytes" "encoding/json" - "fmt" "io" "net/http" "net/http/httptest" @@ -16,49 +15,33 @@ import ( "github.com/stretchr/testify/assert" ) -func TestMiddleware(t *testing.T) { +func TestMiddlewareRequests(t *testing.T) { type testCase struct { Name string Method string Path string - TestPath string Handler func() gin.HandlerFunc - Endpoint func(c *gin.Context) Setup func(t *testing.T, req *http.Request) - Expected func(t *testing.T, resp *httptest.ResponseRecorder) + Expected func(t *testing.T, req *http.Request) + } + + var capturedRequest *http.Request + + captureRequestMiddleware := func() gin.HandlerFunc { + return func(c *gin.Context) { + bodyBytes, _ := io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + capturedRequest = c.Request + c.Next() + } } testCases := []testCase{ { - Name: "chat handler", - Method: http.MethodPost, - Path: "/api/chat", - TestPath: "/api/chat", - Handler: ChatMiddleware, - Endpoint: func(c *gin.Context) { - var chatReq api.ChatRequest - if err := c.ShouldBindJSON(&chatReq); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) - return - } - - userMessage := chatReq.Messages[0].Content - var assistantMessage string - - switch userMessage { - case "Hello": - assistantMessage = "Hello!" - default: - assistantMessage = "I'm not sure how to respond to that." - } - - c.JSON(http.StatusOK, api.ChatResponse{ - Message: api.Message{ - Role: "assistant", - Content: assistantMessage, - }, - }) - }, + Name: "chat handler", + Method: http.MethodPost, + Path: "/api/chat", + Handler: ChatMiddleware, Setup: func(t *testing.T, req *http.Request) { body := ChatCompletionRequest{ Model: "test-model", @@ -70,88 +53,26 @@ func TestMiddleware(t *testing.T) { req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) req.Header.Set("Content-Type", "application/json") }, - Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { - assert.Equal(t, http.StatusOK, resp.Code) - - var chatResp ChatCompletion - if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { + Expected: func(t *testing.T, req *http.Request) { + var chatReq api.ChatRequest + if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil { t.Fatal(err) } - if chatResp.Object != "chat.completion" { - t.Fatalf("expected chat.completion, got %s", chatResp.Object) + if chatReq.Messages[0].Role != "user" { + t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role) } - if chatResp.Choices[0].Message.Content != "Hello!" { - t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content) + if chatReq.Messages[0].Content != "Hello" { + t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content) } }, }, { - Name: "completions handler", - Method: http.MethodPost, - Path: "/api/generate", - TestPath: "/api/generate", - Handler: CompletionsMiddleware, - Endpoint: func(c *gin.Context) { - c.JSON(http.StatusOK, api.GenerateResponse{ - Response: "Hello!", - }) - }, - Setup: func(t *testing.T, req *http.Request) { - body := CompletionRequest{ - Model: "test-model", - Prompt: "Hello", - } - - bodyBytes, _ := json.Marshal(body) - - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") - }, - Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { - assert.Equal(t, http.StatusOK, resp.Code) - var completionResp Completion - if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil { - t.Fatal(err) - } - - if completionResp.Object != "text_completion" { - t.Fatalf("expected text_completion, got %s", completionResp.Object) - } - - if completionResp.Choices[0].Text != "Hello!" { - t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text) - } - }, - }, - { - Name: "completions handler with params", - Method: http.MethodPost, - Path: "/api/generate", - TestPath: "/api/generate", - Handler: CompletionsMiddleware, - Endpoint: func(c *gin.Context) { - var generateReq api.GenerateRequest - if err := c.ShouldBindJSON(&generateReq); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"}) - return - } - - temperature := generateReq.Options["temperature"].(float64) - var assistantMessage string - - switch temperature { - case 1.6: - assistantMessage = "Received temperature of 1.6" - default: - assistantMessage = fmt.Sprintf("Received temperature of %f", temperature) - } - - c.JSON(http.StatusOK, api.GenerateResponse{ - Response: assistantMessage, - }) - }, + Name: "completions handler", + Method: http.MethodPost, + Path: "/api/generate", + Handler: CompletionsMiddleware, Setup: func(t *testing.T, req *http.Request) { temp := float32(0.8) body := CompletionRequest{ @@ -165,24 +86,65 @@ func TestMiddleware(t *testing.T) { req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) req.Header.Set("Content-Type", "application/json") }, - Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { - assert.Equal(t, http.StatusOK, resp.Code) - var completionResp Completion - if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil { + Expected: func(t *testing.T, req *http.Request) { + var genReq api.GenerateRequest + if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil { t.Fatal(err) } - if completionResp.Object != "text_completion" { - t.Fatalf("expected text_completion, got %s", completionResp.Object) + if genReq.Prompt != "Hello" { + t.Fatalf("expected 'Hello', got %s", genReq.Prompt) } - if completionResp.Choices[0].Text != "Received temperature of 1.6" { - t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text) + if genReq.Options["temperature"] != 1.6 { + t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"]) } }, }, + } + + gin.SetMode(gin.TestMode) + router := gin.New() + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + for _, tc := range testCases { + t.Run(tc.Name, func(t *testing.T) { + router = gin.New() + router.Use(captureRequestMiddleware()) + router.Use(tc.Handler()) + router.Handle(tc.Method, tc.Path, endpoint) + req, _ := http.NewRequest(tc.Method, tc.Path, nil) + + if tc.Setup != nil { + tc.Setup(t, req) + } + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + tc.Expected(t, capturedRequest) + }) + } +} + +func TestMiddlewareResponses(t *testing.T) { + type testCase struct { + Name string + Method string + Path string + TestPath string + Handler func() gin.HandlerFunc + Endpoint func(c *gin.Context) + Setup func(t *testing.T, req *http.Request) + Expected func(t *testing.T, resp *httptest.ResponseRecorder) + } + + testCases := []testCase{ { - Name: "completions handler with error", + Name: "completions handler error forwarding", Method: http.MethodPost, Path: "/api/generate", TestPath: "/api/generate",