package openai

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/ollama/ollama/api"
	"github.com/stretchr/testify/assert"
)

// TestMiddleware is a table-driven test. For every case it builds a fresh gin
// router, installs the middleware returned by Handler, registers the stub
// Endpoint at Path, sends a request to TestPath and hands the recorded
// response to Expected.
func TestMiddleware(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string // route pattern registered on the router
		TestPath string // concrete URL the request is sent to
		Handler  func() gin.HandlerFunc
		Endpoint func(c *gin.Context)
		Setup    func(t *testing.T, req *http.Request) // optional; may be nil
		Expected func(t *testing.T, resp *httptest.ResponseRecorder)
	}

	// setJSONBody encodes body as JSON, attaches it to req and sets the JSON
	// content type. A marshaling failure aborts the test instead of silently
	// sending an empty body.
	setJSONBody := func(t *testing.T, req *http.Request, body any) {
		t.Helper()
		bodyBytes, err := json.Marshal(body)
		if err != nil {
			t.Fatalf("marshaling request body: %v", err)
		}
		req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		req.Header.Set("Content-Type", "application/json")
	}

	testCases := []testCase{
		{
			Name:     "chat handler",
			Method:   http.MethodPost,
			Path:     "/api/chat",
			TestPath: "/api/chat",
			Handler:  ChatMiddleware,
			Endpoint: func(c *gin.Context) {
				var chatReq api.ChatRequest
				if err := c.ShouldBindJSON(&chatReq); err != nil {
					c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
					return
				}

				// Guard the index below so a request without messages is
				// reported as a bad request rather than panicking the handler.
				if len(chatReq.Messages) == 0 {
					c.JSON(http.StatusBadRequest, gin.H{"error": "no messages in request"})
					return
				}

				userMessage := chatReq.Messages[0].Content
				var assistantMessage string

				switch userMessage {
				case "Hello":
					assistantMessage = "Hello!"
				default:
					assistantMessage = "I'm not sure how to respond to that."
				}

				c.JSON(http.StatusOK, api.ChatResponse{
					Message: api.Message{
						Role:    "assistant",
						Content: assistantMessage,
					},
				})
			},
			Setup: func(t *testing.T, req *http.Request) {
				setJSONBody(t, req, ChatCompletionRequest{
					Model:    "test-model",
					Messages: []Message{{Role: "user", Content: "Hello"}},
				})
			},
			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
				assert.Equal(t, http.StatusOK, resp.Code)

				var chatResp ChatCompletion
				if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
					t.Fatal(err)
				}

				if chatResp.Object != "chat.completion" {
					t.Fatalf("expected chat.completion, got %s", chatResp.Object)
				}

				if len(chatResp.Choices) == 0 {
					t.Fatal("expected at least one choice, got none")
				}

				if chatResp.Choices[0].Message.Content != "Hello!" {
					t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
				}
			},
		},
		{
			Name:     "completions handler",
			Method:   http.MethodPost,
			Path:     "/api/generate",
			TestPath: "/api/generate",
			Handler:  CompletionsMiddleware,
			Endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.GenerateResponse{
					Response: "Hello!",
				})
			},
			Setup: func(t *testing.T, req *http.Request) {
				setJSONBody(t, req, CompletionRequest{
					Model:  "test-model",
					Prompt: "Hello",
				})
			},
			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
				assert.Equal(t, http.StatusOK, resp.Code)

				var completionResp Completion
				if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
					t.Fatal(err)
				}

				if completionResp.Object != "text_completion" {
					t.Fatalf("expected text_completion, got %s", completionResp.Object)
				}

				if len(completionResp.Choices) == 0 {
					t.Fatal("expected at least one choice, got none")
				}

				if completionResp.Choices[0].Text != "Hello!" {
					t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
				}
			},
		},
		{
			Name:     "completions handler with params",
			Method:   http.MethodPost,
			Path:     "/api/generate",
			TestPath: "/api/generate",
			Handler:  CompletionsMiddleware,
			Endpoint: func(c *gin.Context) {
				var generateReq api.GenerateRequest
				if err := c.ShouldBindJSON(&generateReq); err != nil {
					c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
					return
				}

				// Checked assertion: a missing or non-numeric option becomes a
				// 400 that the expectations below reject, not a handler panic.
				temperature, ok := generateReq.Options["temperature"].(float64)
				if !ok {
					c.JSON(http.StatusBadRequest, gin.H{"error": "temperature option missing or not a number"})
					return
				}

				var assistantMessage string
				switch temperature {
				case 1.6:
					assistantMessage = "Received temperature of 1.6"
				default:
					assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
				}

				c.JSON(http.StatusOK, api.GenerateResponse{
					Response: assistantMessage,
				})
			},
			Setup: func(t *testing.T, req *http.Request) {
				// NOTE(review): the request carries 0.8 while the endpoint is
				// expected to see 1.6; presumably CompletionsMiddleware rescales
				// the temperature. Confirm against the middleware.
				temp := float32(0.8)
				setJSONBody(t, req, CompletionRequest{
					Model:       "test-model",
					Prompt:      "Hello",
					Temperature: &temp,
				})
			},
			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
				assert.Equal(t, http.StatusOK, resp.Code)

				var completionResp Completion
				if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
					t.Fatal(err)
				}

				if completionResp.Object != "text_completion" {
					t.Fatalf("expected text_completion, got %s", completionResp.Object)
				}

				if len(completionResp.Choices) == 0 {
					t.Fatal("expected at least one choice, got none")
				}

				if completionResp.Choices[0].Text != "Received temperature of 1.6" {
					t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
				}
			},
		},
		{
			Name:     "completions handler with error",
			Method:   http.MethodPost,
			Path:     "/api/generate",
			TestPath: "/api/generate",
			Handler:  CompletionsMiddleware,
			Endpoint: func(c *gin.Context) {
				c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
			},
			Setup: func(t *testing.T, req *http.Request) {
				setJSONBody(t, req, CompletionRequest{
					Model:  "test-model",
					Prompt: "Hello",
				})
			},
			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
				if resp.Code != http.StatusBadRequest {
					t.Fatalf("expected 400, got %d", resp.Code)
				}

				if !strings.Contains(resp.Body.String(), `"invalid request"`) {
					t.Fatalf("error was not forwarded")
				}
			},
		},
		{
			Name:     "list handler",
			Method:   http.MethodGet,
			Path:     "/api/tags",
			TestPath: "/api/tags",
			Handler:  ListMiddleware,
			Endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ListResponse{
					Models: []api.ListModelResponse{
						{
							Name: "Test Model",
						},
					},
				})
			},
			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
				assert.Equal(t, http.StatusOK, resp.Code)

				var listResp ListCompletion
				if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
					t.Fatal(err)
				}

				if listResp.Object != "list" {
					t.Fatalf("expected list, got %s", listResp.Object)
				}

				if len(listResp.Data) != 1 {
					t.Fatalf("expected 1, got %d", len(listResp.Data))
				}

				if listResp.Data[0].Id != "Test Model" {
					t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
				}
			},
		},
		{
			Name:     "retrieve model",
			Method:   http.MethodGet,
			Path:     "/api/show/:model",
			TestPath: "/api/show/test-model",
			Handler:  RetrieveMiddleware,
			Endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ShowResponse{
					ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
				})
			},
			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
				var retrieveResp Model
				if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
					t.Fatal(err)
				}

				if retrieveResp.Object != "model" {
					t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
				}

				if retrieveResp.Id != "test-model" {
					t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
				}
			},
		},
	}

	gin.SetMode(gin.TestMode)

	for _, tc := range testCases {
		t.Run(tc.Name, func(t *testing.T) {
			// Each subtest owns its router so middleware and routes from one
			// case cannot leak into another.
			router := gin.New()
			router.Use(tc.Handler())
			router.Handle(tc.Method, tc.Path, tc.Endpoint)

			req, err := http.NewRequest(tc.Method, tc.TestPath, nil)
			if err != nil {
				t.Fatalf("creating request: %v", err)
			}

			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			tc.Expected(t, resp)
		})
	}
}