diff --git a/openai/openai_test.go b/openai/openai_test.go index e08a96c9..c7e9f384 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -7,27 +7,22 @@ import ( "io" "net/http" "net/http/httptest" + "reflect" "strings" "testing" "time" "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" "github.com/ollama/ollama/api" ) const ( - prefix = `data:image/jpeg;base64,` - image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` - imageURL = prefix + image + prefix = `data:image/jpeg;base64,` + image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` ) -func prepareRequest(req *http.Request, body any) { - bodyBytes, _ := json.Marshal(body) - req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - req.Header.Set("Content-Type", "application/json") -} +var False = false func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { return func(c *gin.Context) { @@ -43,134 +38,136 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { func TestChatMiddleware(t *testing.T) { type testCase struct { - Name string - Setup func(t *testing.T, req *http.Request) - Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) + name string + body string + req api.ChatRequest + err ErrorResponse } var capturedRequest *api.ChatRequest testCases := []testCase{ { - Name: "chat handler", - Setup: func(t *testing.T, req *http.Request) { - body := ChatCompletionRequest{ - Model: "test-model", - Messages: []Message{{Role: "user", Content: "Hello"}}, - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { - if resp.Code != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.Code) - } - - if req.Messages[0].Role != "user" { - t.Fatalf("expected 'user', got %s", req.Messages[0].Role) - } - - if req.Messages[0].Content != "Hello" { - t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content) - } + name: "chat handler", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, }, }, { - Name: "chat handler with image content", - Setup: func(t *testing.T, req *http.Request) { - body := ChatCompletionRequest{ - Model: "test-model", - Messages: []Message{ - { - Role: "user", Content: []map[string]any{ - {"type": "text", "text": "Hello"}, - {"type": "image_url", "image_url": map[string]string{"url": imageURL}}, + name: "chat handler with image content", + body: `{ + "model": "test-model", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Hello" + }, + { + "type": "image_url", + "image_url": { + "url": "` + prefix + image + `" + } + } + ] + } + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "Hello", + }, + { + Role: "user", + Images: []api.ImageData{ + func() []byte { + img, _ := base64.StdEncoding.DecodeString(image) + return img + }(), + }, + }, + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, + }, + }, + { + name: "chat handler with tools", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "What's the weather like in Paris Today?"}, + {"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + { + Role: "user", + Content: "What's the weather like in Paris Today?", + }, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]interface{}{ + "location": "Paris, France", + "format": "celsius", + }, + }, }, }, }, - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { - if resp.Code != http.StatusOK { - t.Fatalf("expected 200, got %d", resp.Code) - } - - if req.Messages[0].Role != "user" { - t.Fatalf("expected 'user', got %s", req.Messages[0].Role) - } - - if req.Messages[0].Content != "Hello" { - t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content) - } - - img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):]) - - if req.Messages[1].Role != "user" { - t.Fatalf("expected 'user', got %s", req.Messages[1].Role) - } - - if !bytes.Equal(req.Messages[1].Images[0], img) { - t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0]) - } + }, + Options: map[string]any{ + "temperature": 1.0, + "top_p": 1.0, + }, + Stream: &False, }, }, + { - Name: "chat handler with tools", - Setup: func(t *testing.T, req *http.Request) { - body := ChatCompletionRequest{ - Model: "test-model", - Messages: []Message{ - {Role: "user", Content: "What's the weather like in Paris Today?"}, - {Role: "assistant", ToolCalls: []ToolCall{{ - ID: "id", - Type: "function", - Function: struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - }{ - Name: "get_current_weather", - Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}", - }, - }}}, - }, - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { - if resp.Code != 200 { - t.Fatalf("expected 200, got %d", resp.Code) - } - - if req.Messages[0].Content != "What's the weather like in Paris Today?" { - t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content) - } - - if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" { - t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"]) - } - - if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" { - t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"]) - } - }, - }, - { - Name: "chat handler error forwarding", - Setup: func(t *testing.T, req *http.Request) { - body := ChatCompletionRequest{ - Model: "test-model", - Messages: []Message{{Role: "user", Content: 2}}, - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { - if resp.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.Code) - } - - if !strings.Contains(resp.Body.String(), "invalid message content type") { - t.Fatalf("error was not forwarded") - } + name: "chat handler error forwarding", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": 2} + ] + }`, + err: ErrorResponse{ + Error: Error{ + Message: "invalid message content type: float64", + Type: "invalid_request_error", + }, }, }, } @@ -185,16 +182,26 @@ func TestChatMiddleware(t *testing.T) { router.Handle(http.MethodPost, "/api/chat", endpoint) for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil) - - tc.Setup(t, req) + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - tc.Expected(t, capturedRequest, resp) + var errResp ErrorResponse + if resp.Code != http.StatusOK { + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + } + if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { + t.Fatal("requests did not match") + } + if !reflect.DeepEqual(tc.err, errResp) { + t.Fatal("errors did not match") + } capturedRequest = nil }) } @@ -202,71 +209,52 @@ func TestChatMiddleware(t *testing.T) { func TestCompletionsMiddleware(t *testing.T) { type testCase struct { - Name string - Setup func(t *testing.T, req *http.Request) - Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) + name string + body string + req api.GenerateRequest + err ErrorResponse } var capturedRequest *api.GenerateRequest testCases := []testCase{ { - Name: "completions handler", - Setup: func(t *testing.T, req *http.Request) { - temp := float32(0.8) - body := CompletionRequest{ - Model: "test-model", - Prompt: "Hello", - Temperature: &temp, - Stop: []string{"\n", "stop"}, - Suffix: "suffix", - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) { - if req.Prompt != "Hello" { - t.Fatalf("expected 'Hello', got %s", req.Prompt) - } - - if req.Options["temperature"] != 1.6 { - t.Fatalf("expected 1.6, got %f", req.Options["temperature"]) - } - - stopTokens, ok := req.Options["stop"].([]any) - - if !ok { - t.Fatalf("expected stop tokens to be a list") - } - - if stopTokens[0] != "\n" || stopTokens[1] != "stop" { - t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens) - } - - if req.Suffix != "suffix" { - t.Fatalf("expected 'suffix', got %s", req.Suffix) - } + name: "completions handler", + body: `{ + "model": "test-model", + "prompt": "Hello", + "temperature": 0.8, + "stop": ["\n", "stop"], + "suffix": "suffix" + }`, + req: api.GenerateRequest{ + Model: "test-model", + Prompt: "Hello", + Options: map[string]any{ + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "temperature": 1.6, + "top_p": 1.0, + "stop": []any{"\n", "stop"}, + }, + Suffix: "suffix", + Stream: &False, }, }, { - Name: "completions handler error forwarding", - Setup: func(t *testing.T, req *http.Request) { - body := CompletionRequest{ - Model: "test-model", - Prompt: "Hello", - Temperature: nil, - Stop: []int{1, 2}, - Suffix: "suffix", - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) { - if resp.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.Code) - } - - if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") { - t.Fatalf("error was not forwarded") - } + name: "completions handler error forwarding", + body: `{ + "model": "test-model", + "prompt": "Hello", + "temperature": null, + "stop": [1, 2], + "suffix": "suffix" + }`, + err: ErrorResponse{ + Error: Error{ + Message: "invalid type for 'stop' field: float64", + Type: "invalid_request_error", + }, }, }, } @@ -281,15 +269,27 @@ func TestCompletionsMiddleware(t *testing.T) { router.Handle(http.MethodPost, "/api/generate", endpoint) for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil) - - tc.Setup(t, req) + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - tc.Expected(t, capturedRequest, resp) + var errResp ErrorResponse + if resp.Code != http.StatusOK { + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + } + + if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { + t.Fatal("requests did not match") + } + + if !reflect.DeepEqual(tc.err, errResp) { + t.Fatal("errors did not match") + } capturedRequest = nil }) @@ -298,78 +298,47 @@ func TestCompletionsMiddleware(t *testing.T) { func TestEmbeddingsMiddleware(t *testing.T) { type testCase struct { - Name string - Setup func(t *testing.T, req *http.Request) - Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) + name string + body string + req api.EmbedRequest + err ErrorResponse } var capturedRequest *api.EmbedRequest testCases := []testCase{ { - Name: "embed handler single input", - Setup: func(t *testing.T, req *http.Request) { - body := EmbedRequest{ - Input: "Hello", - Model: "test-model", - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { - if req.Input != "Hello" { - t.Fatalf("expected 'Hello', got %s", req.Input) - } - - if req.Model != "test-model" { - t.Fatalf("expected 'test-model', got %s", req.Model) - } + name: "embed handler single input", + body: `{ + "input": "Hello", + "model": "test-model" + }`, + req: api.EmbedRequest{ + Input: "Hello", + Model: "test-model", }, }, { - Name: "embed handler batch input", - Setup: func(t *testing.T, req *http.Request) { - body := EmbedRequest{ - Input: []string{"Hello", "World"}, - Model: "test-model", - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { - input, ok := req.Input.([]any) - - if !ok { - t.Fatalf("expected input to be a list") - } - - if input[0].(string) != "Hello" { - t.Fatalf("expected 'Hello', got %s", input[0]) - } - - if input[1].(string) != "World" { - t.Fatalf("expected 'World', got %s", input[1]) - } - - if req.Model != "test-model" { - t.Fatalf("expected 'test-model', got %s", req.Model) - } + name: "embed handler batch input", + body: `{ + "input": ["Hello", "World"], + "model": "test-model" + }`, + req: api.EmbedRequest{ + Input: []any{"Hello", "World"}, + Model: "test-model", }, }, { - Name: "embed handler error forwarding", - Setup: func(t *testing.T, req *http.Request) { - body := EmbedRequest{ - Model: "test-model", - } - prepareRequest(req, body) - }, - Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) { - if resp.Code != http.StatusBadRequest { - t.Fatalf("expected 400, got %d", resp.Code) - } - - if !strings.Contains(resp.Body.String(), "invalid input") { - t.Fatalf("error was not forwarded") - } + name: "embed handler error forwarding", + body: `{ + "model": "test-model" + }`, + err: ErrorResponse{ + Error: Error{ + Message: "invalid input", + Type: "invalid_request_error", + }, }, }, } @@ -384,116 +353,167 @@ func TestEmbeddingsMiddleware(t *testing.T) { router.Handle(http.MethodPost, "/api/embed", endpoint) for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil) - - tc.Setup(t, req) + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") resp := httptest.NewRecorder() router.ServeHTTP(resp, req) - tc.Expected(t, capturedRequest, resp) + var errResp ErrorResponse + if resp.Code != http.StatusOK { + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatal(err) + } + } + + if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) { + t.Fatal("requests did not match") + } + + if !reflect.DeepEqual(tc.err, errResp) { + t.Fatal("errors did not match") + } capturedRequest = nil }) } } -func TestMiddlewareResponses(t *testing.T) { +func TestListMiddleware(t *testing.T) { type testCase struct { - Name string - Method string - Path string - TestPath string - Handler func() gin.HandlerFunc - Endpoint func(c *gin.Context) - Setup func(t *testing.T, req *http.Request) - Expected func(t *testing.T, resp *httptest.ResponseRecorder) + name string + endpoint func(c *gin.Context) + resp string } testCases := []testCase{ { - Name: "list handler", - Method: http.MethodGet, - Path: "/api/tags", - TestPath: "/api/tags", - Handler: ListMiddleware, - Endpoint: func(c *gin.Context) { + name: "list handler", + endpoint: func(c *gin.Context) { c.JSON(http.StatusOK, api.ListResponse{ Models: []api.ListModelResponse{ { - Name: "Test Model", + Name: "test-model", + ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), }, }, }) }, - Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { - var listResp ListCompletion - if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { - t.Fatal(err) - } - - if listResp.Object != "list" { - t.Fatalf("expected list, got %s", listResp.Object) - } - - if len(listResp.Data) != 1 { - t.Fatalf("expected 1, got %d", len(listResp.Data)) - } - - if listResp.Data[0].Id != "Test Model" { - t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id) - } - }, + resp: `{ + "object": "list", + "data": [ + { + "id": "test-model", + "object": "model", + "created": 1686935002, + "owned_by": "library" + } + ] + }`, }, { - Name: "retrieve model", - Method: http.MethodGet, - Path: "/api/show/:model", - TestPath: "/api/show/test-model", - Handler: RetrieveMiddleware, - Endpoint: func(c *gin.Context) { - c.JSON(http.StatusOK, api.ShowResponse{ - ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC), - }) - }, - Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { - var retrieveResp Model - if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil { - t.Fatal(err) - } - - if retrieveResp.Object != "model" { - t.Fatalf("Expected object to be model, got %s", retrieveResp.Object) - } - - if retrieveResp.Id != "test-model" { - t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id) - } + name: "list handler empty output", + endpoint: func(c *gin.Context) { + c.JSON(http.StatusOK, api.ListResponse{}) }, + resp: `{ + "object": "list", + "data": null + }`, }, } gin.SetMode(gin.TestMode) - router := gin.New() for _, tc := range testCases { - t.Run(tc.Name, func(t *testing.T) { - router = gin.New() - router.Use(tc.Handler()) - router.Handle(tc.Method, tc.Path, tc.Endpoint) - req, _ := http.NewRequest(tc.Method, tc.TestPath, nil) + router := gin.New() + router.Use(ListMiddleware()) + router.Handle(http.MethodGet, "/api/tags", tc.endpoint) + req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil) - if tc.Setup != nil { - tc.Setup(t, req) - } + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) - resp := httptest.NewRecorder() - router.ServeHTTP(resp, req) + var expected, actual map[string]any + err := json.Unmarshal([]byte(tc.resp), &expected) + if err != nil { + t.Fatalf("failed to unmarshal expected response: %v", err) + } - assert.Equal(t, http.StatusOK, resp.Code) + err = json.Unmarshal(resp.Body.Bytes(), &actual) + if err != nil { + t.Fatalf("failed to unmarshal actual response: %v", err) + } - tc.Expected(t, resp) - }) + if !reflect.DeepEqual(expected, actual) { + t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) + } + } +} + +func TestRetrieveMiddleware(t *testing.T) { + type testCase struct { + name string + endpoint func(c *gin.Context) + resp string + } + + testCases := []testCase{ + { + name: "retrieve handler", + endpoint: func(c *gin.Context) { + c.JSON(http.StatusOK, api.ShowResponse{ + ModifiedAt: time.Unix(int64(1686935002), 0).UTC(), + }) + }, + resp: `{ + "id":"test-model", + "object":"model", + "created":1686935002, + "owned_by":"library"} + `, + }, + { + name: "retrieve handler error forwarding", + endpoint: func(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"}) + }, + resp: `{ + "error": { + "code": null, + "message": "model not found", + "param": null, + "type": "api_error" + } + }`, + }, + } + + gin.SetMode(gin.TestMode) + + for _, tc := range testCases { + router := gin.New() + router.Use(RetrieveMiddleware()) + router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint) + req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil) + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + var expected, actual map[string]any + err := json.Unmarshal([]byte(tc.resp), &expected) + if err != nil { + t.Fatalf("failed to unmarshal expected response: %v", err) + } + + err = json.Unmarshal(resp.Body.Bytes(), &actual) + if err != nil { + t.Fatalf("failed to unmarshal actual response: %v", err) + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) + } } }