OpenAI: Function Based Testing (#5752)
* distinguish error forwarding * more coverage * rm comment
This commit is contained in:
parent
51b2fd299c
commit
c57317cbf0
2 changed files with 279 additions and 183 deletions
|
@ -877,6 +877,7 @@ func ChatMiddleware() gin.HandlerFunc {
|
||||||
chatReq, err := fromChatRequest(req)
|
chatReq, err := fromChatRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||||
|
|
|
@ -20,113 +20,59 @@ const prefix = `data:image/jpeg;base64,`
|
||||||
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||||
const imageURL = prefix + image
|
const imageURL = prefix + image
|
||||||
|
|
||||||
func TestMiddlewareRequests(t *testing.T) {
|
func prepareRequest(req *http.Request, body any) {
|
||||||
type testCase struct {
|
bodyBytes, _ := json.Marshal(body)
|
||||||
Name string
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
Method string
|
req.Header.Set("Content-Type", "application/json")
|
||||||
Path string
|
}
|
||||||
Handler func() gin.HandlerFunc
|
|
||||||
Setup func(t *testing.T, req *http.Request)
|
|
||||||
Expected func(t *testing.T, req *http.Request)
|
|
||||||
}
|
|
||||||
|
|
||||||
var capturedRequest *http.Request
|
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
||||||
|
|
||||||
captureRequestMiddleware := func() gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
capturedRequest = c.Request
|
err := json.Unmarshal(bodyBytes, capturedRequest)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
|
||||||
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.ChatRequest
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
Name: "chat handler",
|
Name: "chat handler",
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/chat",
|
|
||||||
Handler: ChatMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := ChatCompletionRequest{
|
body := ChatCompletionRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||||
}
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||||
var chatReq api.ChatRequest
|
if resp.Code != http.StatusOK {
|
||||||
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
t.Fatalf("expected 200, got %d", resp.Code)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatReq.Messages[0].Role != "user" {
|
if req.Messages[0].Role != "user" {
|
||||||
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatReq.Messages[0].Content != "Hello" {
|
if req.Messages[0].Content != "Hello" {
|
||||||
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "completions handler",
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/generate",
|
|
||||||
Handler: CompletionsMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
temp := float32(0.8)
|
|
||||||
body := CompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Prompt: "Hello",
|
|
||||||
Temperature: &temp,
|
|
||||||
Stop: []string{"\n", "stop"},
|
|
||||||
Suffix: "suffix",
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
|
||||||
var genReq api.GenerateRequest
|
|
||||||
if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if genReq.Prompt != "Hello" {
|
|
||||||
t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
if genReq.Options["temperature"] != 1.6 {
|
|
||||||
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
|
|
||||||
}
|
|
||||||
|
|
||||||
stopTokens, ok := genReq.Options["stop"].([]any)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected stop tokens to be a list")
|
|
||||||
}
|
|
||||||
|
|
||||||
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
|
||||||
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
|
||||||
}
|
|
||||||
|
|
||||||
if genReq.Suffix != "suffix" {
|
|
||||||
t.Fatalf("expected 'suffix', got %s", genReq.Suffix)
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "chat handler with image content",
|
Name: "chat handler with image content",
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/chat",
|
|
||||||
Handler: ChatMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := ChatCompletionRequest{
|
body := ChatCompletionRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
|
@ -139,91 +85,254 @@ func TestMiddlewareRequests(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||||
var chatReq api.ChatRequest
|
if resp.Code != http.StatusOK {
|
||||||
if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
|
t.Fatalf("expected 200, got %d", resp.Code)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatReq.Messages[0].Role != "user" {
|
if req.Messages[0].Role != "user" {
|
||||||
t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
|
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
||||||
}
|
}
|
||||||
|
|
||||||
if chatReq.Messages[0].Content != "Hello" {
|
if req.Messages[0].Content != "Hello" {
|
||||||
t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
|
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
||||||
|
|
||||||
if chatReq.Messages[1].Role != "user" {
|
if req.Messages[1].Role != "user" {
|
||||||
t.Fatalf("expected 'user', got %s", chatReq.Messages[1].Role)
|
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !bytes.Equal(chatReq.Messages[1].Images[0], img) {
|
if !bytes.Equal(req.Messages[1].Images[0], img) {
|
||||||
t.Fatalf("expected image encoding, got %s", chatReq.Messages[1].Images[0])
|
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "chat handler with tools",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := ChatCompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []Message{
|
||||||
|
{Role: "user", Content: "What's the weather like in Paris Today?"},
|
||||||
|
{Role: "assistant", ToolCalls: []ToolCall{{
|
||||||
|
ID: "id",
|
||||||
|
Type: "function",
|
||||||
|
Function: struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
}{
|
||||||
|
Name: "get_current_weather",
|
||||||
|
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
|
||||||
|
},
|
||||||
|
}}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != 200 {
|
||||||
|
t.Fatalf("expected 200, got %d", resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
|
||||||
|
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
|
||||||
|
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
|
||||||
|
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "chat handler error forwarding",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := ChatCompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Messages: []Message{{Role: "user", Content: 2}},
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(resp.Body.String(), "invalid message content type") {
|
||||||
|
t.Fatalf("error was not forwarded")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/chat", endpoint)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
|
||||||
|
|
||||||
|
tc.Setup(t, req)
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
tc.Expected(t, capturedRequest, resp)
|
||||||
|
|
||||||
|
capturedRequest = nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompletionsMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
|
||||||
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.GenerateRequest
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
|
{
|
||||||
|
Name: "completions handler",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
temp := float32(0.8)
|
||||||
|
body := CompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Temperature: &temp,
|
||||||
|
Stop: []string{"\n", "stop"},
|
||||||
|
Suffix: "suffix",
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if req.Prompt != "Hello" {
|
||||||
|
t.Fatalf("expected 'Hello', got %s", req.Prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Options["temperature"] != 1.6 {
|
||||||
|
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
|
||||||
|
}
|
||||||
|
|
||||||
|
stopTokens, ok := req.Options["stop"].([]any)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected stop tokens to be a list")
|
||||||
|
}
|
||||||
|
|
||||||
|
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
||||||
|
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Suffix != "suffix" {
|
||||||
|
t.Fatalf("expected 'suffix', got %s", req.Suffix)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "completions handler error forwarding",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := CompletionRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
Prompt: "Hello",
|
||||||
|
Temperature: nil,
|
||||||
|
Stop: []int{1, 2},
|
||||||
|
Suffix: "suffix",
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", resp.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
|
||||||
|
t.Fatalf("error was not forwarded")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := func(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
|
||||||
|
|
||||||
|
tc.Setup(t, req)
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
tc.Expected(t, capturedRequest, resp)
|
||||||
|
|
||||||
|
capturedRequest = nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEmbeddingsMiddleware(t *testing.T) {
|
||||||
|
type testCase struct {
|
||||||
|
Name string
|
||||||
|
Setup func(t *testing.T, req *http.Request)
|
||||||
|
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
|
||||||
|
}
|
||||||
|
|
||||||
|
var capturedRequest *api.EmbedRequest
|
||||||
|
|
||||||
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
Name: "embed handler single input",
|
Name: "embed handler single input",
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/embed",
|
|
||||||
Handler: EmbeddingsMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := EmbedRequest{
|
body := EmbedRequest{
|
||||||
Input: "Hello",
|
Input: "Hello",
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
}
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
||||||
var embedReq api.EmbedRequest
|
if req.Input != "Hello" {
|
||||||
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
|
t.Fatalf("expected 'Hello', got %s", req.Input)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if embedReq.Input != "Hello" {
|
if req.Model != "test-model" {
|
||||||
t.Fatalf("expected 'Hello', got %s", embedReq.Input)
|
t.Fatalf("expected 'test-model', got %s", req.Model)
|
||||||
}
|
|
||||||
|
|
||||||
if embedReq.Model != "test-model" {
|
|
||||||
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "embed handler batch input",
|
Name: "embed handler batch input",
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/embed",
|
|
||||||
Handler: EmbeddingsMiddleware,
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
body := EmbedRequest{
|
body := EmbedRequest{
|
||||||
Input: []string{"Hello", "World"},
|
Input: []string{"Hello", "World"},
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
}
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, req *http.Request) {
|
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
||||||
var embedReq api.EmbedRequest
|
input, ok := req.Input.([]any)
|
||||||
if err := json.NewDecoder(req.Body).Decode(&embedReq); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
input, ok := embedReq.Input.([]any)
|
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Fatalf("expected input to be a list")
|
t.Fatalf("expected input to be a list")
|
||||||
|
@ -237,36 +346,52 @@ func TestMiddlewareRequests(t *testing.T) {
|
||||||
t.Fatalf("expected 'World', got %s", input[1])
|
t.Fatalf("expected 'World', got %s", input[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
if embedReq.Model != "test-model" {
|
if req.Model != "test-model" {
|
||||||
t.Fatalf("expected 'test-model', got %s", embedReq.Model)
|
t.Fatalf("expected 'test-model', got %s", req.Model)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "embed handler error forwarding",
|
||||||
|
Setup: func(t *testing.T, req *http.Request) {
|
||||||
|
body := EmbedRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
}
|
||||||
|
prepareRequest(req, body)
|
||||||
|
},
|
||||||
|
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("expected 400, got %d", resp.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
gin.SetMode(gin.TestMode)
|
if !strings.Contains(resp.Body.String(), "invalid input") {
|
||||||
router := gin.New()
|
t.Fatalf("error was not forwarded")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
endpoint := func(c *gin.Context) {
|
endpoint := func(c *gin.Context) {
|
||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
router := gin.New()
|
||||||
|
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||||
|
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.Name, func(t *testing.T) {
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
router = gin.New()
|
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
|
||||||
router.Use(captureRequestMiddleware())
|
|
||||||
router.Use(tc.Handler())
|
|
||||||
router.Handle(tc.Method, tc.Path, endpoint)
|
|
||||||
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
|
|
||||||
|
|
||||||
if tc.Setup != nil {
|
|
||||||
tc.Setup(t, req)
|
tc.Setup(t, req)
|
||||||
}
|
|
||||||
|
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
router.ServeHTTP(resp, req)
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
tc.Expected(t, capturedRequest)
|
tc.Expected(t, capturedRequest, resp)
|
||||||
|
|
||||||
|
capturedRequest = nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -284,36 +409,6 @@ func TestMiddlewareResponses(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
|
||||||
Name: "completions handler error forwarding",
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Path: "/api/generate",
|
|
||||||
TestPath: "/api/generate",
|
|
||||||
Handler: CompletionsMiddleware,
|
|
||||||
Endpoint: func(c *gin.Context) {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
|
||||||
},
|
|
||||||
Setup: func(t *testing.T, req *http.Request) {
|
|
||||||
body := CompletionRequest{
|
|
||||||
Model: "test-model",
|
|
||||||
Prompt: "Hello",
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, _ := json.Marshal(body)
|
|
||||||
|
|
||||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
},
|
|
||||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
||||||
if resp.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("expected 400, got %d", resp.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
|
|
||||||
t.Fatalf("error was not forwarded")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
Name: "list handler",
|
Name: "list handler",
|
||||||
Method: http.MethodGet,
|
Method: http.MethodGet,
|
||||||
|
@ -330,8 +425,6 @@ func TestMiddlewareResponses(t *testing.T) {
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||||
assert.Equal(t, http.StatusOK, resp.Code)
|
|
||||||
|
|
||||||
var listResp ListCompletion
|
var listResp ListCompletion
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -395,6 +488,8 @@ func TestMiddlewareResponses(t *testing.T) {
|
||||||
resp := httptest.NewRecorder()
|
resp := httptest.NewRecorder()
|
||||||
router.ServeHTTP(resp, req)
|
router.ServeHTTP(resp, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.Code)
|
||||||
|
|
||||||
tc.Expected(t, resp)
|
tc.Expected(t, resp)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue