OpenAI: /v1/models and /v1/models/{model} compatibility (#5007)
* OpenAI v1 models
* Refactor Writers
* Add Test
  Co-Authored-By: Attila Kerekes
* Credit Co-Author
  Co-Authored-By: Attila Kerekes <439392+keriati@users.noreply.github.com>
* Empty List Testing
* Use Namespace for Ownedby
* Update Test
* Add back envconfig
* v1/models docs
* Use ModelName Parser
* Test Names
* Remove Docs
* Clean Up
* Test name
  Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
* Add Middleware for Chat and List
* Testing Cleanup
* Test with Fatal
* Add functionality to chat test
* OpenAI: /v1/models/{model} compatibility (#5028)
* Retrieve Model
* OpenAI Delete Model
* Retrieve Middleware
* Remove Delete from Branch
* Update Test
* Middleware Test File
* Function name
* Cleanup
* Test Update
* Test Update

Co-authored-by: Attila Kerekes <439392+keriati@users.noreply.github.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
parent 422dcc3856
commit 996bb1b85e
6 changed files with 387 additions and 14 deletions
@@ -345,6 +345,13 @@ type ProcessModelResponse struct {
 	SizeVRAM int64 `json:"size_vram"`
 }
 
+type RetrieveModelResponse struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
 type TokenResponse struct {
 	Token string `json:"token"`
 }
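The new RetrieveModelResponse type gives Go consumers a typed view of the OpenAI-style model object returned by /v1/models/{model}. A minimal client sketch, assuming a local server on the default port and an already-pulled model named "llama3" (both illustrative, not part of this diff):

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/ollama/ollama/api"
)

func main() {
	// Hypothetical model name; substitute any model you have pulled.
	resp, err := http.Get("http://localhost:11434/v1/models/llama3")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var m api.RetrieveModelResponse
	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
		panic(err)
	}

	// Fields mirror OpenAI's model object: id, object, created, owned_by.
	fmt.Println(m.Id, m.Object, m.Created, m.OwnedBy)
}
```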
@@ -65,6 +65,7 @@ curl http://localhost:11434/v1/chat/completions \
         }
       ]
     }'
 
 ```
 
 ## Endpoints
openai/openai.go (163 changes)
@@ -12,6 +12,7 @@ import (
 
 	"github.com/gin-gonic/gin"
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/types/model"
 )
 
 type Error struct {
@@ -85,6 +86,18 @@ type ChatCompletionChunk struct {
 	Choices []ChunkChoice `json:"choices"`
 }
 
+type Model struct {
+	Id      string `json:"id"`
+	Object  string `json:"object"`
+	Created int64  `json:"created"`
+	OwnedBy string `json:"owned_by"`
+}
+
+type ListCompletion struct {
+	Object string  `json:"object"`
+	Data   []Model `json:"data"`
+}
+
 func NewError(code int, message string) ErrorResponse {
 	var etype string
 	switch code {
@@ -145,7 +158,33 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
 	}
 }
 
-func fromRequest(r ChatCompletionRequest) api.ChatRequest {
+func toListCompletion(r api.ListResponse) ListCompletion {
+	var data []Model
+	for _, m := range r.Models {
+		data = append(data, Model{
+			Id:      m.Name,
+			Object:  "model",
+			Created: m.ModifiedAt.Unix(),
+			OwnedBy: model.ParseName(m.Name).Namespace,
+		})
+	}
+
+	return ListCompletion{
+		Object: "list",
+		Data:   data,
+	}
+}
+
+func toModel(r api.ShowResponse, m string) Model {
+	return Model{
+		Id:      m,
+		Object:  "model",
+		Created: r.ModifiedAt.Unix(),
+		OwnedBy: model.ParseName(m).Namespace,
+	}
+}
+
+func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
 	var messages []api.Message
 	for _, msg := range r.Messages {
 		messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
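Both converters derive owned_by from the namespace of the parsed model name via model.ParseName. A hedged sketch of what that yields: the unqualified name matches the route test later in this diff, which expects the default "library" namespace; the qualified name is a hypothetical example of a name that carries its own namespace.

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/types/model"
)

func main() {
	// Unqualified names fall back to the default namespace (asserted as
	// "library" in the routes test added by this commit).
	fmt.Println(model.ParseName("test-model:latest").Namespace)

	// A fully qualified name keeps its own namespace segment
	// (hypothetical example name).
	fmt.Println(model.ParseName("example/test-model:latest").Namespace)
}
```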
@@ -208,13 +247,26 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
 	}
 }
 
-type writer struct {
-	stream bool
-	id     string
+type BaseWriter struct {
 	gin.ResponseWriter
 }
 
-func (w *writer) writeError(code int, data []byte) (int, error) {
+type ChatWriter struct {
+	stream bool
+	id     string
+	BaseWriter
+}
+
+type ListWriter struct {
+	BaseWriter
+}
+
+type RetrieveWriter struct {
+	BaseWriter
+	model string
+}
+
+func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
 	var serr api.StatusError
 	err := json.Unmarshal(data, &serr)
 	if err != nil {
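One detail worth noting in this refactor: ChatWriter, ListWriter and RetrieveWriter each keep their own Write method in the hunks below, even though the bodies are identical. Go embedding is composition rather than inheritance, so a Write hoisted into BaseWriter could never dispatch to the concrete writer's writeResponse. A self-contained sketch of that behaviour (the names here are illustrative, not from the diff):

```go
package main

import "fmt"

type base struct{}

func (base) describe() string { return "base" }

// A method defined on the embedded type only ever sees the embedded type:
// report always calls base.describe, never an outer type's override.
func (b base) report() string { return "report: " + b.describe() }

type child struct{ base }

func (child) describe() string { return "child" }

func main() {
	c := child{}
	fmt.Println(c.describe()) // "child": the outer method shadows the embedded one
	fmt.Println(c.report())   // "report: base": no virtual dispatch through embedding
}
```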
@@ -230,7 +282,7 @@ func (w *writer) writeError(code int, data []byte) (int, error) {
 	return len(data), nil
 }
 
-func (w *writer) writeResponse(data []byte) (int, error) {
+func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 	var chatResponse api.ChatResponse
 	err := json.Unmarshal(data, &chatResponse)
 	if err != nil {
@@ -270,7 +322,7 @@ func (w *writer) writeResponse(data []byte) (int, error) {
 	return len(data), nil
 }
 
-func (w *writer) Write(data []byte) (int, error) {
+func (w *ChatWriter) Write(data []byte) (int, error) {
 	code := w.ResponseWriter.Status()
 	if code != http.StatusOK {
 		return w.writeError(code, data)
@@ -279,7 +331,92 @@ func (w *writer) Write(data []byte) (int, error) {
 	return w.writeResponse(data)
 }
 
-func Middleware() gin.HandlerFunc {
+func (w *ListWriter) writeResponse(data []byte) (int, error) {
+	var listResponse api.ListResponse
+	err := json.Unmarshal(data, &listResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *ListWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
+	var showResponse api.ShowResponse
+	err := json.Unmarshal(data, &showResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// retrieve completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *RetrieveWriter) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+func ListMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		w := &ListWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func RetrieveMiddleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		// response writer
+		w := &RetrieveWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+			model:      c.Param("model"),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
+
+func ChatMiddleware() gin.HandlerFunc {
 	return func(c *gin.Context) {
 		var req ChatCompletionRequest
 		err := c.ShouldBindJSON(&req)
@@ -294,17 +431,17 @@ func Middleware() gin.HandlerFunc {
 		}
 
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
+		if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
 			return
 		}
 
 		c.Request.Body = io.NopCloser(&b)
 
-		w := &writer{
-			ResponseWriter: c.Writer,
+		w := &ChatWriter{
+			BaseWriter: BaseWriter{ResponseWriter: c.Writer},
 			stream: req.Stream,
 			id:     fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
 		}
 
 		c.Writer = w
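Stepping back, all three middlewares share the same shape: swap c.Writer for a wrapping writer before c.Next(), let the existing native handler write ollama's own JSON, and have the wrapper re-encode it on the way out (or turn a non-200 body into an OpenAI-style error via writeError). A generic, hypothetical sketch of that pattern, with a trivial upper-casing transform standing in for the OpenAI conversion:

```go
package main

import (
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
)

type transcodeWriter struct {
	gin.ResponseWriter
}

// Write intercepts whatever the downstream handler writes and rewrites it
// before it reaches the client, mirroring what ChatWriter/ListWriter/
// RetrieveWriter do with the native response.
func (w *transcodeWriter) Write(data []byte) (int, error) {
	if w.ResponseWriter.Status() != http.StatusOK {
		return w.ResponseWriter.Write(data) // pass errors through untouched here
	}
	_, err := w.ResponseWriter.Write([]byte(strings.ToUpper(string(data))))
	return len(data), err
}

func TranscodeMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Writer = &transcodeWriter{ResponseWriter: c.Writer} // swap the writer in
		c.Next()                                              // run the native handler
	}
}

func main() {
	r := gin.New()
	r.GET("/hello", TranscodeMiddleware(), func(c *gin.Context) {
		c.String(http.StatusOK, "hello from the native handler")
	})
	r.Run(":8080") // curl localhost:8080/hello -> upper-cased body
}
```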
openai/openai_test.go (new file, 170 lines)
@@ -0,0 +1,170 @@
package openai

import (
	"bytes"
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/ollama/ollama/api"
	"github.com/stretchr/testify/assert"
)

func TestMiddleware(t *testing.T) {
	type testCase struct {
		Name     string
		Method   string
		Path     string
		TestPath string
		Handler  func() gin.HandlerFunc
		Endpoint func(c *gin.Context)
		Setup    func(t *testing.T, req *http.Request)
		Expected func(t *testing.T, resp *httptest.ResponseRecorder)
	}

	testCases := []testCase{
		{
			Name:     "chat handler",
			Method:   http.MethodPost,
			Path:     "/api/chat",
			TestPath: "/api/chat",
			Handler:  ChatMiddleware,
			Endpoint: func(c *gin.Context) {
				var chatReq api.ChatRequest
				if err := c.ShouldBindJSON(&chatReq); err != nil {
					c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
					return
				}

				userMessage := chatReq.Messages[0].Content
				var assistantMessage string

				switch userMessage {
				case "Hello":
					assistantMessage = "Hello!"
				default:
					assistantMessage = "I'm not sure how to respond to that."
				}

				c.JSON(http.StatusOK, api.ChatResponse{
					Message: api.Message{
						Role:    "assistant",
						Content: assistantMessage,
					},
				})
			},
			Setup: func(t *testing.T, req *http.Request) {
				body := ChatCompletionRequest{
					Model:    "test-model",
					Messages: []Message{{Role: "user", Content: "Hello"}},
				}

				bodyBytes, _ := json.Marshal(body)

				req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
				req.Header.Set("Content-Type", "application/json")
			},
			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
				var chatResp ChatCompletion
				if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
					t.Fatal(err)
				}

				if chatResp.Object != "chat.completion" {
					t.Fatalf("expected chat.completion, got %s", chatResp.Object)
				}

				if chatResp.Choices[0].Message.Content != "Hello!" {
					t.Fatalf("expected Hello!, got %s", chatResp.Choices[0].Message.Content)
				}
			},
		},
		{
			Name:     "list handler",
			Method:   http.MethodGet,
			Path:     "/api/tags",
			TestPath: "/api/tags",
			Handler:  ListMiddleware,
			Endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ListResponse{
					Models: []api.ListModelResponse{
						{
							Name: "Test Model",
						},
					},
				})
			},
			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
				var listResp ListCompletion
				if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
					t.Fatal(err)
				}

				if listResp.Object != "list" {
					t.Fatalf("expected list, got %s", listResp.Object)
				}

				if len(listResp.Data) != 1 {
					t.Fatalf("expected 1, got %d", len(listResp.Data))
				}

				if listResp.Data[0].Id != "Test Model" {
					t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
				}
			},
		},
		{
			Name:     "retrieve model",
			Method:   http.MethodGet,
			Path:     "/api/show/:model",
			TestPath: "/api/show/test-model",
			Handler:  RetrieveMiddleware,
			Endpoint: func(c *gin.Context) {
				c.JSON(http.StatusOK, api.ShowResponse{
					ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
				})
			},
			Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
				var retrieveResp Model
				if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
					t.Fatal(err)
				}

				if retrieveResp.Object != "model" {
					t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
				}

				if retrieveResp.Id != "test-model" {
					t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
				}
			},
		},
	}

	gin.SetMode(gin.TestMode)
	router := gin.New()

	for _, tc := range testCases {
		t.Run(tc.Name, func(t *testing.T) {
			router = gin.New()
			router.Use(tc.Handler())
			router.Handle(tc.Method, tc.Path, tc.Endpoint)
			req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)

			if tc.Setup != nil {
				tc.Setup(t, req)
			}

			resp := httptest.NewRecorder()
			router.ServeHTTP(resp, req)

			assert.Equal(t, http.StatusOK, resp.Code)

			tc.Expected(t, resp)
		})
	}
}
@@ -1039,7 +1039,9 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.GET("/api/ps", s.ProcessHandler)
 
 	// Compatibility endpoints
-	r.POST("/v1/chat/completions", openai.Middleware(), s.ChatHandler)
+	r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
+	r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
+	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
 
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
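With the routes registered, any OpenAI-style client can enumerate models through /v1/models. A minimal sketch using the ListCompletion type from this PR; the localhost URL is an assumption, not part of the diff:

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/ollama/ollama/openai"
)

func main() {
	resp, err := http.Get("http://localhost:11434/v1/models")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var list openai.ListCompletion
	if err := json.NewDecoder(resp.Body).Decode(&list); err != nil {
		panic(err)
	}

	for _, m := range list.Data {
		// id is the ollama model name; owned_by is its namespace.
		fmt.Println(m.Id, m.OwnedBy)
	}
}
```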
@@ -20,6 +20,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -105,6 +106,24 @@ func Test_Routes(t *testing.T) {
 				assert.Empty(t, len(modelList.Models))
 			},
 		},
+		{
+			Name:   "openai empty list",
+			Method: http.MethodGet,
+			Path:   "/v1/models",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var modelList openai.ListCompletion
+				err = json.Unmarshal(body, &modelList)
+				require.NoError(t, err)
+
+				assert.Equal(t, "list", modelList.Object)
+				assert.Empty(t, modelList.Data)
+			},
+		},
 		{
 			Name:   "Tags Handler (yes tags)",
 			Method: http.MethodGet,
@@ -128,6 +147,25 @@ func Test_Routes(t *testing.T) {
 				assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
 			},
 		},
+		{
+			Name:   "openai list models with tags",
+			Method: http.MethodGet,
+			Path:   "/v1/models",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var modelList openai.ListCompletion
+				err = json.Unmarshal(body, &modelList)
+				require.NoError(t, err)
+
+				assert.Len(t, modelList.Data, 1)
+				assert.Equal(t, "test-model:latest", modelList.Data[0].Id)
+				assert.Equal(t, "library", modelList.Data[0].OwnedBy)
+			},
+		},
 		{
 			Name:   "Create Model Handler",
 			Method: http.MethodPost,
@@ -216,6 +254,24 @@ func Test_Routes(t *testing.T) {
 				assert.InDelta(t, 0, showResp.ModelInfo["general.parameter_count"], 1e-9, "Parameter count should be 0")
 			},
 		},
+		{
+			Name:   "openai retrieve model handler",
+			Method: http.MethodGet,
+			Path:   "/v1/models/show-model",
+			Expected: func(t *testing.T, resp *http.Response) {
+				contentType := resp.Header.Get("Content-Type")
+				assert.Equal(t, "application/json", contentType)
+				body, err := io.ReadAll(resp.Body)
+				require.NoError(t, err)
+
+				var retrieveResp api.RetrieveModelResponse
+				err = json.Unmarshal(body, &retrieveResp)
+				require.NoError(t, err)
+
+				assert.Equal(t, "show-model", retrieveResp.Id)
+				assert.Equal(t, "library", retrieveResp.OwnedBy)
+			},
+		},
 	}
 
 	t.Setenv("OLLAMA_MODELS", t.TempDir())