2024-07-02 11:50:56 -07:00
|
|
|
package openai
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"encoding/json"
|
2024-07-02 16:01:45 -07:00
|
|
|
"fmt"
|
2024-07-02 11:50:56 -07:00
|
|
|
"io"
|
|
|
|
"net/http"
|
|
|
|
"net/http/httptest"
|
2024-07-02 16:01:45 -07:00
|
|
|
"strings"
|
2024-07-02 11:50:56 -07:00
|
|
|
"testing"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestMiddleware(t *testing.T) {
|
|
|
|
type testCase struct {
|
|
|
|
Name string
|
|
|
|
Method string
|
|
|
|
Path string
|
|
|
|
TestPath string
|
|
|
|
Handler func() gin.HandlerFunc
|
|
|
|
Endpoint func(c *gin.Context)
|
|
|
|
Setup func(t *testing.T, req *http.Request)
|
|
|
|
Expected func(t *testing.T, resp *httptest.ResponseRecorder)
|
|
|
|
}
|
|
|
|
|
|
|
|
testCases := []testCase{
|
|
|
|
{
|
|
|
|
Name: "chat handler",
|
|
|
|
Method: http.MethodPost,
|
|
|
|
Path: "/api/chat",
|
|
|
|
TestPath: "/api/chat",
|
|
|
|
Handler: ChatMiddleware,
|
|
|
|
Endpoint: func(c *gin.Context) {
|
|
|
|
var chatReq api.ChatRequest
|
|
|
|
if err := c.ShouldBindJSON(&chatReq); err != nil {
|
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
userMessage := chatReq.Messages[0].Content
|
|
|
|
var assistantMessage string
|
|
|
|
|
|
|
|
switch userMessage {
|
|
|
|
case "Hello":
|
|
|
|
assistantMessage = "Hello!"
|
|
|
|
default:
|
|
|
|
assistantMessage = "I'm not sure how to respond to that."
|
|
|
|
}
|
|
|
|
|
|
|
|
c.JSON(http.StatusOK, api.ChatResponse{
|
|
|
|
Message: api.Message{
|
|
|
|
Role: "assistant",
|
|
|
|
Content: assistantMessage,
|
|
|
|
},
|
|
|
|
})
|
|
|
|
},
|
|
|
|
Setup: func(t *testing.T, req *http.Request) {
|
|
|
|
body := ChatCompletionRequest{
|
|
|
|
Model: "test-model",
|
|
|
|
Messages: []Message{{Role: "user", Content: "Hello"}},
|
|
|
|
}
|
|
|
|
|
|
|
|
bodyBytes, _ := json.Marshal(body)
|
|
|
|
|
|
|
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
},
|
|
|
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
2024-07-02 16:01:45 -07:00
|
|
|
assert.Equal(t, http.StatusOK, resp.Code)
|
|
|
|
|
2024-07-02 11:50:56 -07:00
|
|
|
var chatResp ChatCompletion
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if chatResp.Object != "chat.completion" {
|
|
|
|
t.Fatalf("expected chat.completion, got %s", chatResp.Object)
|
|
|
|
}
|
|
|
|
|
|
|
|
if chatResp.Choices[0].Message.Content != "Hello!" {
|
|
|
|
t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
|
|
|
|
}
|
|
|
|
},
|
|
|
|
},
|
2024-07-02 16:01:45 -07:00
|
|
|
{
|
|
|
|
Name: "completions handler",
|
|
|
|
Method: http.MethodPost,
|
|
|
|
Path: "/api/generate",
|
|
|
|
TestPath: "/api/generate",
|
|
|
|
Handler: CompletionsMiddleware,
|
|
|
|
Endpoint: func(c *gin.Context) {
|
|
|
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
|
|
|
Response: "Hello!",
|
|
|
|
})
|
|
|
|
},
|
|
|
|
Setup: func(t *testing.T, req *http.Request) {
|
|
|
|
body := CompletionRequest{
|
|
|
|
Model: "test-model",
|
|
|
|
Prompt: "Hello",
|
|
|
|
}
|
|
|
|
|
|
|
|
bodyBytes, _ := json.Marshal(body)
|
|
|
|
|
|
|
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
},
|
|
|
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
|
|
assert.Equal(t, http.StatusOK, resp.Code)
|
|
|
|
var completionResp Completion
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if completionResp.Object != "text_completion" {
|
|
|
|
t.Fatalf("expected text_completion, got %s", completionResp.Object)
|
|
|
|
}
|
|
|
|
|
|
|
|
if completionResp.Choices[0].Text != "Hello!" {
|
|
|
|
t.Fatalf("expected Hello!, got %s", completionResp.Choices[0].Text)
|
|
|
|
}
|
|
|
|
},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
Name: "completions handler with params",
|
|
|
|
Method: http.MethodPost,
|
|
|
|
Path: "/api/generate",
|
|
|
|
TestPath: "/api/generate",
|
|
|
|
Handler: CompletionsMiddleware,
|
|
|
|
Endpoint: func(c *gin.Context) {
|
|
|
|
var generateReq api.GenerateRequest
|
|
|
|
if err := c.ShouldBindJSON(&generateReq); err != nil {
|
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
temperature := generateReq.Options["temperature"].(float64)
|
|
|
|
var assistantMessage string
|
|
|
|
|
|
|
|
switch temperature {
|
|
|
|
case 1.6:
|
|
|
|
assistantMessage = "Received temperature of 1.6"
|
|
|
|
default:
|
|
|
|
assistantMessage = fmt.Sprintf("Received temperature of %f", temperature)
|
|
|
|
}
|
|
|
|
|
|
|
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
|
|
|
Response: assistantMessage,
|
|
|
|
})
|
|
|
|
},
|
|
|
|
Setup: func(t *testing.T, req *http.Request) {
|
|
|
|
temp := float32(0.8)
|
|
|
|
body := CompletionRequest{
|
|
|
|
Model: "test-model",
|
|
|
|
Prompt: "Hello",
|
|
|
|
Temperature: &temp,
|
|
|
|
}
|
|
|
|
|
|
|
|
bodyBytes, _ := json.Marshal(body)
|
|
|
|
|
|
|
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
},
|
|
|
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
|
|
assert.Equal(t, http.StatusOK, resp.Code)
|
|
|
|
var completionResp Completion
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if completionResp.Object != "text_completion" {
|
|
|
|
t.Fatalf("expected text_completion, got %s", completionResp.Object)
|
|
|
|
}
|
|
|
|
|
|
|
|
if completionResp.Choices[0].Text != "Received temperature of 1.6" {
|
|
|
|
t.Fatalf("expected Received temperature of 1.6, got %s", completionResp.Choices[0].Text)
|
|
|
|
}
|
|
|
|
},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
Name: "completions handler with error",
|
|
|
|
Method: http.MethodPost,
|
|
|
|
Path: "/api/generate",
|
|
|
|
TestPath: "/api/generate",
|
|
|
|
Handler: CompletionsMiddleware,
|
|
|
|
Endpoint: func(c *gin.Context) {
|
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
|
|
|
|
},
|
|
|
|
Setup: func(t *testing.T, req *http.Request) {
|
|
|
|
body := CompletionRequest{
|
|
|
|
Model: "test-model",
|
|
|
|
Prompt: "Hello",
|
|
|
|
}
|
|
|
|
|
|
|
|
bodyBytes, _ := json.Marshal(body)
|
|
|
|
|
|
|
|
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
},
|
|
|
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
|
|
if resp.Code != http.StatusBadRequest {
|
|
|
|
t.Fatalf("expected 400, got %d", resp.Code)
|
|
|
|
}
|
|
|
|
|
|
|
|
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
|
|
|
|
t.Fatalf("error was not forwarded")
|
|
|
|
}
|
|
|
|
},
|
|
|
|
},
|
2024-07-02 11:50:56 -07:00
|
|
|
{
|
|
|
|
Name: "list handler",
|
|
|
|
Method: http.MethodGet,
|
|
|
|
Path: "/api/tags",
|
|
|
|
TestPath: "/api/tags",
|
|
|
|
Handler: ListMiddleware,
|
|
|
|
Endpoint: func(c *gin.Context) {
|
|
|
|
c.JSON(http.StatusOK, api.ListResponse{
|
|
|
|
Models: []api.ListModelResponse{
|
|
|
|
{
|
|
|
|
Name: "Test Model",
|
|
|
|
},
|
|
|
|
},
|
|
|
|
})
|
|
|
|
},
|
|
|
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
2024-07-02 16:01:45 -07:00
|
|
|
assert.Equal(t, http.StatusOK, resp.Code)
|
|
|
|
|
2024-07-02 11:50:56 -07:00
|
|
|
var listResp ListCompletion
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if listResp.Object != "list" {
|
|
|
|
t.Fatalf("expected list, got %s", listResp.Object)
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(listResp.Data) != 1 {
|
|
|
|
t.Fatalf("expected 1, got %d", len(listResp.Data))
|
|
|
|
}
|
|
|
|
|
|
|
|
if listResp.Data[0].Id != "Test Model" {
|
|
|
|
t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
|
|
|
|
}
|
|
|
|
},
|
|
|
|
},
|
|
|
|
{
|
|
|
|
Name: "retrieve model",
|
|
|
|
Method: http.MethodGet,
|
|
|
|
Path: "/api/show/:model",
|
|
|
|
TestPath: "/api/show/test-model",
|
|
|
|
Handler: RetrieveMiddleware,
|
|
|
|
Endpoint: func(c *gin.Context) {
|
|
|
|
c.JSON(http.StatusOK, api.ShowResponse{
|
|
|
|
ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
|
|
|
|
})
|
|
|
|
},
|
|
|
|
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
|
|
|
var retrieveResp Model
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if retrieveResp.Object != "model" {
|
|
|
|
t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
|
|
|
|
}
|
|
|
|
|
|
|
|
if retrieveResp.Id != "test-model" {
|
|
|
|
t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
|
|
|
|
}
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
router := gin.New()
|
|
|
|
|
|
|
|
for _, tc := range testCases {
|
|
|
|
t.Run(tc.Name, func(t *testing.T) {
|
|
|
|
router = gin.New()
|
|
|
|
router.Use(tc.Handler())
|
|
|
|
router.Handle(tc.Method, tc.Path, tc.Endpoint)
|
|
|
|
req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
|
|
|
|
|
|
|
|
if tc.Setup != nil {
|
|
|
|
tc.Setup(t, req)
|
|
|
|
}
|
|
|
|
|
|
|
|
resp := httptest.NewRecorder()
|
|
|
|
router.ServeHTTP(resp, req)
|
|
|
|
|
|
|
|
tc.Expected(t, resp)
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|