OpenAI: Simplify input output in testing (#5858)

* simplify input output

* direct comp

* in line image

* rm error pointer type

* update response testing

* lint
This commit is contained in:
royjhan 2024-08-12 13:33:34 -04:00 committed by GitHub
parent 1dc3ef3aa9
commit 01d544d373
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -7,12 +7,12 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@ -20,14 +20,9 @@ import (
const ( const (
prefix = `data:image/jpeg;base64,` prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
imageURL = prefix + image
) )
func prepareRequest(req *http.Request, body any) { var False = false
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
}
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc { func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
@ -43,134 +38,136 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
func TestChatMiddleware(t *testing.T) { func TestChatMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
Name string name string
Setup func(t *testing.T, req *http.Request) body string
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) req api.ChatRequest
err ErrorResponse
} }
var capturedRequest *api.ChatRequest var capturedRequest *api.ChatRequest
testCases := []testCase{ testCases := []testCase{
{ {
Name: "chat handler", name: "chat handler",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := ChatCompletionRequest{ "model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model", Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}}, Messages: []api.Message{
} {
prepareRequest(req, body) Role: "user",
Content: "Hello",
}, },
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { },
if resp.Code != http.StatusOK { Options: map[string]any{
t.Fatalf("expected 200, got %d", resp.Code) "temperature": 1.0,
} "top_p": 1.0,
},
if req.Messages[0].Role != "user" { Stream: &False,
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
}, },
}, },
{ {
Name: "chat handler with image content", name: "chat handler with image content",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := ChatCompletionRequest{ "model": "test-model",
Model: "test-model", "messages": [
Messages: []Message{
{ {
Role: "user", Content: []map[string]any{ "role": "user",
{"type": "text", "text": "Hello"}, "content": [
{"type": "image_url", "image_url": map[string]string{"url": imageURL}}, {
"type": "text",
"text": "Hello"
},
{
"type": "image_url",
"image_url": {
"url": "` + prefix + image + `"
}
}
]
}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
{
Role: "user",
Images: []api.ImageData{
func() []byte {
img, _ := base64.StdEncoding.DecodeString(image)
return img
}(),
}, },
}, },
}, },
} Options: map[string]any{
prepareRequest(req, body) "temperature": 1.0,
"top_p": 1.0,
}, },
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { Stream: &False,
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
if req.Messages[1].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
}
if !bytes.Equal(req.Messages[1].Images[0], img) {
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
}
}, },
}, },
{ {
Name: "chat handler with tools", name: "chat handler with tools",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := ChatCompletionRequest{ "model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model", Model: "test-model",
Messages: []Message{ Messages: []api.Message{
{Role: "user", Content: "What's the weather like in Paris Today?"}, {
{Role: "assistant", ToolCalls: []ToolCall{{ Role: "user",
ID: "id", Content: "What's the weather like in Paris Today?",
Type: "function", },
Function: struct { {
Name string `json:"name"` Role: "assistant",
Arguments string `json:"arguments"` ToolCalls: []api.ToolCall{
}{ {
Function: api.ToolCallFunction{
Name: "get_current_weather", Name: "get_current_weather",
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}", Arguments: map[string]interface{}{
"location": "Paris, France",
"format": "celsius",
}, },
}}},
}, },
}
prepareRequest(req, body)
}, },
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { },
if resp.Code != 200 { },
t.Fatalf("expected 200, got %d", resp.Code) },
} Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
}
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
}
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
}
},
},
{ {
Name: "chat handler error forwarding", name: "chat handler error forwarding",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := ChatCompletionRequest{ "model": "test-model",
Model: "test-model", "messages": [
Messages: []Message{{Role: "user", Content: 2}}, {"role": "user", "content": 2}
} ]
prepareRequest(req, body) }`,
err: ErrorResponse{
Error: Error{
Message: "invalid message content type: float64",
Type: "invalid_request_error",
}, },
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid message content type") {
t.Fatalf("error was not forwarded")
}
}, },
}, },
} }
@ -185,16 +182,26 @@ func TestChatMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/chat", endpoint) router.Handle(http.MethodPost, "/api/chat", endpoint)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil) req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
tc.Setup(t, req)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp) var errResp ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil capturedRequest = nil
}) })
} }
@ -202,71 +209,52 @@ func TestChatMiddleware(t *testing.T) {
func TestCompletionsMiddleware(t *testing.T) { func TestCompletionsMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
Name string name string
Setup func(t *testing.T, req *http.Request) body string
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) req api.GenerateRequest
err ErrorResponse
} }
var capturedRequest *api.GenerateRequest var capturedRequest *api.GenerateRequest
testCases := []testCase{ testCases := []testCase{
{ {
Name: "completions handler", name: "completions handler",
Setup: func(t *testing.T, req *http.Request) { body: `{
temp := float32(0.8) "model": "test-model",
body := CompletionRequest{ "prompt": "Hello",
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model", Model: "test-model",
Prompt: "Hello", Prompt: "Hello",
Temperature: &temp, Options: map[string]any{
Stop: []string{"\n", "stop"}, "frequency_penalty": 0.0,
Suffix: "suffix", "presence_penalty": 0.0,
} "temperature": 1.6,
prepareRequest(req, body) "top_p": 1.0,
"stop": []any{"\n", "stop"},
}, },
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) { Suffix: "suffix",
if req.Prompt != "Hello" { Stream: &False,
t.Fatalf("expected 'Hello', got %s", req.Prompt)
}
if req.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
}
stopTokens, ok := req.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
if req.Suffix != "suffix" {
t.Fatalf("expected 'suffix', got %s", req.Suffix)
}
}, },
}, },
{ {
Name: "completions handler error forwarding", name: "completions handler error forwarding",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := CompletionRequest{ "model": "test-model",
Model: "test-model", "prompt": "Hello",
Prompt: "Hello", "temperature": null,
Temperature: nil, "stop": [1, 2],
Stop: []int{1, 2}, "suffix": "suffix"
Suffix: "suffix", }`,
} err: ErrorResponse{
prepareRequest(req, body) Error: Error{
Message: "invalid type for 'stop' field: float64",
Type: "invalid_request_error",
}, },
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
t.Fatalf("error was not forwarded")
}
}, },
}, },
} }
@ -281,15 +269,27 @@ func TestCompletionsMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/generate", endpoint) router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil) req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
tc.Setup(t, req)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp) var errResp ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil capturedRequest = nil
}) })
@ -298,78 +298,47 @@ func TestCompletionsMiddleware(t *testing.T) {
func TestEmbeddingsMiddleware(t *testing.T) { func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
Name string name string
Setup func(t *testing.T, req *http.Request) body string
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) req api.EmbedRequest
err ErrorResponse
} }
var capturedRequest *api.EmbedRequest var capturedRequest *api.EmbedRequest
testCases := []testCase{ testCases := []testCase{
{ {
Name: "embed handler single input", name: "embed handler single input",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := EmbedRequest{ "input": "Hello",
"model": "test-model"
}`,
req: api.EmbedRequest{
Input: "Hello", Input: "Hello",
Model: "test-model", Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if req.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Input)
}
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
}, },
}, },
{ {
Name: "embed handler batch input", name: "embed handler batch input",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := EmbedRequest{ "input": ["Hello", "World"],
Input: []string{"Hello", "World"}, "model": "test-model"
}`,
req: api.EmbedRequest{
Input: []any{"Hello", "World"},
Model: "test-model", Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
input, ok := req.Input.([]any)
if !ok {
t.Fatalf("expected input to be a list")
}
if input[0].(string) != "Hello" {
t.Fatalf("expected 'Hello', got %s", input[0])
}
if input[1].(string) != "World" {
t.Fatalf("expected 'World', got %s", input[1])
}
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
}, },
}, },
{ {
Name: "embed handler error forwarding", name: "embed handler error forwarding",
Setup: func(t *testing.T, req *http.Request) { body: `{
body := EmbedRequest{ "model": "test-model"
Model: "test-model", }`,
} err: ErrorResponse{
prepareRequest(req, body) Error: Error{
Message: "invalid input",
Type: "invalid_request_error",
}, },
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid input") {
t.Fatalf("error was not forwarded")
}
}, },
}, },
} }
@ -384,116 +353,167 @@ func TestEmbeddingsMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/embed", endpoint) router.Handle(http.MethodPost, "/api/embed", endpoint)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil) req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
tc.Setup(t, req)
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp) var errResp ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil capturedRequest = nil
}) })
} }
} }
func TestMiddlewareResponses(t *testing.T) { func TestListMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
Name string name string
Method string endpoint func(c *gin.Context)
Path string resp string
TestPath string
Handler func() gin.HandlerFunc
Endpoint func(c *gin.Context)
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *httptest.ResponseRecorder)
} }
testCases := []testCase{ testCases := []testCase{
{ {
Name: "list handler", name: "list handler",
Method: http.MethodGet, endpoint: func(c *gin.Context) {
Path: "/api/tags",
TestPath: "/api/tags",
Handler: ListMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{ c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{ Models: []api.ListModelResponse{
{ {
Name: "Test Model", Name: "test-model",
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
}, },
}, },
}) })
}, },
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { resp: `{
var listResp ListCompletion "object": "list",
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { "data": [
t.Fatal(err) {
"id": "test-model",
"object": "model",
"created": 1686935002,
"owned_by": "library"
} }
]
if listResp.Object != "list" { }`,
t.Fatalf("expected list, got %s", listResp.Object)
}
if len(listResp.Data) != 1 {
t.Fatalf("expected 1, got %d", len(listResp.Data))
}
if listResp.Data[0].Id != "Test Model" {
t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
}
},
}, },
{ {
Name: "retrieve model", name: "list handler empty output",
Method: http.MethodGet, endpoint: func(c *gin.Context) {
Path: "/api/show/:model", c.JSON(http.StatusOK, api.ListResponse{})
TestPath: "/api/show/test-model",
Handler: RetrieveMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
var retrieveResp Model
if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
t.Fatal(err)
}
if retrieveResp.Object != "model" {
t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
}
if retrieveResp.Id != "test-model" {
t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
}
}, },
resp: `{
"object": "list",
"data": null
}`,
}, },
} }
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
router := gin.New()
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) { router := gin.New()
router = gin.New() router.Use(ListMiddleware())
router.Use(tc.Handler()) router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
router.Handle(tc.Method, tc.Path, tc.Endpoint) req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code) var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
tc.Expected(t, resp) err = json.Unmarshal(resp.Body.Bytes(), &actual)
}) if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
}
}
func TestRetrieveMiddleware(t *testing.T) {
type testCase struct {
name string
endpoint func(c *gin.Context)
resp string
}
testCases := []testCase{
{
name: "retrieve handler",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
})
},
resp: `{
"id":"test-model",
"object":"model",
"created":1686935002,
"owned_by":"library"}
`,
},
{
name: "retrieve handler error forwarding",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
},
resp: `{
"error": {
"code": null,
"message": "model not found",
"param": null,
"type": "api_error"
}
}`,
},
}
gin.SetMode(gin.TestMode)
for _, tc := range testCases {
router := gin.New()
router.Use(RetrieveMiddleware())
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
} }
} }