920 lines
22 KiB
Go
920 lines
22 KiB
Go
// Package openai provides middleware for partial compatibility with the OpenAI REST API.
|
|
package openai
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"math/rand"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
"github.com/ollama/ollama/types/model"
|
|
)
|
|
|
|
// Error is the body of the "error" object carried by an ErrorResponse.
// NewError only fills Message and Type, so Param and Code encode as null
// unless a caller sets them explicitly.
type Error struct {
	Message string  `json:"message"`
	Type    string  `json:"type"`
	Param   any     `json:"param"`
	Code    *string `json:"code"`
}
|
|
|
|
type ErrorResponse struct {
|
|
Error Error `json:"error"`
|
|
}
|
|
|
|
type Message struct {
|
|
Role string `json:"role"`
|
|
Content any `json:"content"`
|
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
}
|
|
|
|
type Choice struct {
|
|
Index int `json:"index"`
|
|
Message Message `json:"message"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
}
|
|
|
|
type ChunkChoice struct {
|
|
Index int `json:"index"`
|
|
Delta Message `json:"delta"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
}
|
|
|
|
// CompleteChunkChoice is a single text-completion choice. It is shared by
// Completion and CompletionChunk.
// FinishReason is a pointer so that an absent reason encodes as null.
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}
|
|
|
|
// Usage reports token counts for a completion. The converters in this file
// set TotalTokens to PromptTokens + CompletionTokens.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
|
|
|
|
// ResponseFormat is the "response_format" object of a chat request.
// fromChatRequest only acts on the type "json_object".
type ResponseFormat struct {
	Type string `json:"type"`
}
|
|
|
|
// EmbedRequest is the body of an OpenAI-style embeddings request.
// Input is kept as decoded JSON; EmbeddingsMiddleware validates it and
// forwards it unchanged.
type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}
|
|
|
|
type ChatCompletionRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []Message `json:"messages"`
|
|
Stream bool `json:"stream"`
|
|
MaxTokens *int `json:"max_tokens"`
|
|
Seed *int `json:"seed"`
|
|
Stop any `json:"stop"`
|
|
Temperature *float64 `json:"temperature"`
|
|
FrequencyPenalty *float64 `json:"frequency_penalty"`
|
|
PresencePenalty *float64 `json:"presence_penalty"`
|
|
TopP *float64 `json:"top_p"`
|
|
ResponseFormat *ResponseFormat `json:"response_format"`
|
|
Tools []api.Tool `json:"tools"`
|
|
}
|
|
|
|
type ChatCompletion struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
Choices []Choice `json:"choices"`
|
|
Usage Usage `json:"usage,omitempty"`
|
|
}
|
|
|
|
type ChatCompletionChunk struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
Choices []ChunkChoice `json:"choices"`
|
|
}
|
|
|
|
// CompletionRequest is the body of an OpenAI-style text completion request.
//
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	Suffix string `json:"suffix"`
	Stream bool   `json:"stream"`

	MaxTokens *int `json:"max_tokens"`
	Seed      *int `json:"seed"`
	// Stop is read by fromCompleteRequest as either a string or a list of strings.
	Stop        any      `json:"stop"`
	Temperature *float32 `json:"temperature"`

	// These are plain values, so an omitted field and an explicit 0 are
	// indistinguishable.
	FrequencyPenalty float32 `json:"frequency_penalty"`
	PresencePenalty  float32 `json:"presence_penalty"`
	TopP             float32 `json:"top_p"`
}
|
|
|
|
type Completion struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
Choices []CompleteChunkChoice `json:"choices"`
|
|
Usage Usage `json:"usage,omitempty"`
|
|
}
|
|
|
|
type CompletionChunk struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Choices []CompleteChunkChoice `json:"choices"`
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
}
|
|
|
|
// ToolCall is a function call in OpenAI wire format.
type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name string `json:"name"`
		// Arguments holds the call arguments as a JSON-encoded string,
		// not as a nested object.
		Arguments string `json:"arguments"`
	} `json:"function"`
}
|
|
|
|
// Model describes one model in OpenAI wire format, as returned by the list
// and retrieve endpoints.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}
|
|
|
|
// Embedding is one vector in an EmbeddingList; Index is its position in the
// list.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}
|
|
|
|
type ListCompletion struct {
|
|
Object string `json:"object"`
|
|
Data []Model `json:"data"`
|
|
}
|
|
|
|
type EmbeddingList struct {
|
|
Object string `json:"object"`
|
|
Data []Embedding `json:"data"`
|
|
Model string `json:"model"`
|
|
Usage EmbeddingUsage `json:"usage,omitempty"`
|
|
}
|
|
|
|
// EmbeddingUsage reports token counts for an embeddings request.
// toEmbeddingList sets both fields to the prompt token count.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
|
|
|
|
func NewError(code int, message string) ErrorResponse {
|
|
var etype string
|
|
switch code {
|
|
case http.StatusBadRequest:
|
|
etype = "invalid_request_error"
|
|
case http.StatusNotFound:
|
|
etype = "not_found_error"
|
|
default:
|
|
etype = "api_error"
|
|
}
|
|
|
|
return ErrorResponse{Error{Type: etype, Message: message}}
|
|
}
|
|
|
|
// toolCallId returns a tool call identifier of the form "call_" followed by
// eight characters drawn from [a-z0-9] using math/rand.
func toolCallId() string {
	const (
		prefix   = "call_"
		alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
		suffixN  = 8
	)

	var sb strings.Builder
	sb.Grow(len(prefix) + suffixN)
	sb.WriteString(prefix)
	for i := 0; i < suffixN; i++ {
		sb.WriteByte(alphabet[rand.Intn(len(alphabet))])
	}
	// The alphabet is already lower-case, so no case folding is needed.
	return sb.String()
}
|
|
|
|
func toToolCalls(tc []api.ToolCall) []ToolCall {
|
|
toolCalls := make([]ToolCall, len(tc))
|
|
for i, tc := range tc {
|
|
toolCalls[i].ID = toolCallId()
|
|
toolCalls[i].Type = "function"
|
|
toolCalls[i].Function.Name = tc.Function.Name
|
|
toolCalls[i].Index = tc.Function.Index
|
|
|
|
args, err := json.Marshal(tc.Function.Arguments)
|
|
if err != nil {
|
|
slog.Error("could not marshall function arguments to json", "error", err)
|
|
continue
|
|
}
|
|
|
|
toolCalls[i].Function.Arguments = string(args)
|
|
}
|
|
return toolCalls
|
|
}
|
|
|
|
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|
toolCalls := toToolCalls(r.Message.ToolCalls)
|
|
return ChatCompletion{
|
|
Id: id,
|
|
Object: "chat.completion",
|
|
Created: r.CreatedAt.Unix(),
|
|
Model: r.Model,
|
|
SystemFingerprint: "fp_ollama",
|
|
Choices: []Choice{{
|
|
Index: 0,
|
|
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls},
|
|
FinishReason: func(reason string) *string {
|
|
if len(toolCalls) > 0 {
|
|
reason = "tool_calls"
|
|
}
|
|
if len(reason) > 0 {
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(r.DoneReason),
|
|
}},
|
|
Usage: Usage{
|
|
PromptTokens: r.PromptEvalCount,
|
|
CompletionTokens: r.EvalCount,
|
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
},
|
|
}
|
|
}
|
|
|
|
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
|
toolCalls := toToolCalls(r.Message.ToolCalls)
|
|
return ChatCompletionChunk{
|
|
Id: id,
|
|
Object: "chat.completion.chunk",
|
|
Created: time.Now().Unix(),
|
|
Model: r.Model,
|
|
SystemFingerprint: "fp_ollama",
|
|
Choices: []ChunkChoice{{
|
|
Index: 0,
|
|
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
|
|
FinishReason: func(reason string) *string {
|
|
if len(reason) > 0 {
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(r.DoneReason),
|
|
}},
|
|
}
|
|
}
|
|
|
|
func toCompletion(id string, r api.GenerateResponse) Completion {
|
|
return Completion{
|
|
Id: id,
|
|
Object: "text_completion",
|
|
Created: r.CreatedAt.Unix(),
|
|
Model: r.Model,
|
|
SystemFingerprint: "fp_ollama",
|
|
Choices: []CompleteChunkChoice{{
|
|
Text: r.Response,
|
|
Index: 0,
|
|
FinishReason: func(reason string) *string {
|
|
if len(reason) > 0 {
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(r.DoneReason),
|
|
}},
|
|
Usage: Usage{
|
|
PromptTokens: r.PromptEvalCount,
|
|
CompletionTokens: r.EvalCount,
|
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
},
|
|
}
|
|
}
|
|
|
|
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
|
return CompletionChunk{
|
|
Id: id,
|
|
Object: "text_completion",
|
|
Created: time.Now().Unix(),
|
|
Model: r.Model,
|
|
SystemFingerprint: "fp_ollama",
|
|
Choices: []CompleteChunkChoice{{
|
|
Text: r.Response,
|
|
Index: 0,
|
|
FinishReason: func(reason string) *string {
|
|
if len(reason) > 0 {
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(r.DoneReason),
|
|
}},
|
|
}
|
|
}
|
|
|
|
func toListCompletion(r api.ListResponse) ListCompletion {
|
|
var data []Model
|
|
for _, m := range r.Models {
|
|
data = append(data, Model{
|
|
Id: m.Name,
|
|
Object: "model",
|
|
Created: m.ModifiedAt.Unix(),
|
|
OwnedBy: model.ParseName(m.Name).Namespace,
|
|
})
|
|
}
|
|
|
|
return ListCompletion{
|
|
Object: "list",
|
|
Data: data,
|
|
}
|
|
}
|
|
|
|
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
|
if r.Embeddings != nil {
|
|
var data []Embedding
|
|
for i, e := range r.Embeddings {
|
|
data = append(data, Embedding{
|
|
Object: "embedding",
|
|
Embedding: e,
|
|
Index: i,
|
|
})
|
|
}
|
|
|
|
return EmbeddingList{
|
|
Object: "list",
|
|
Data: data,
|
|
Model: model,
|
|
Usage: EmbeddingUsage{
|
|
PromptTokens: r.PromptEvalCount,
|
|
TotalTokens: r.PromptEvalCount,
|
|
},
|
|
}
|
|
}
|
|
|
|
return EmbeddingList{}
|
|
}
|
|
|
|
func toModel(r api.ShowResponse, m string) Model {
|
|
return Model{
|
|
Id: m,
|
|
Object: "model",
|
|
Created: r.ModifiedAt.Unix(),
|
|
OwnedBy: model.ParseName(m).Namespace,
|
|
}
|
|
}
|
|
|
|
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|
var messages []api.Message
|
|
for _, msg := range r.Messages {
|
|
switch content := msg.Content.(type) {
|
|
case string:
|
|
messages = append(messages, api.Message{Role: msg.Role, Content: content})
|
|
case []any:
|
|
for _, c := range content {
|
|
data, ok := c.(map[string]any)
|
|
if !ok {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
switch data["type"] {
|
|
case "text":
|
|
text, ok := data["text"].(string)
|
|
if !ok {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
messages = append(messages, api.Message{Role: msg.Role, Content: text})
|
|
case "image_url":
|
|
var url string
|
|
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
|
if url, ok = urlMap["url"].(string); !ok {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
} else {
|
|
if url, ok = data["image_url"].(string); !ok {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
}
|
|
|
|
types := []string{"jpeg", "jpg", "png"}
|
|
valid := false
|
|
for _, t := range types {
|
|
prefix := "data:image/" + t + ";base64,"
|
|
if strings.HasPrefix(url, prefix) {
|
|
url = strings.TrimPrefix(url, prefix)
|
|
valid = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !valid {
|
|
return nil, errors.New("invalid image input")
|
|
}
|
|
|
|
img, err := base64.StdEncoding.DecodeString(url)
|
|
if err != nil {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
|
|
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
|
default:
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
}
|
|
default:
|
|
if msg.ToolCalls == nil {
|
|
return nil, fmt.Errorf("invalid message content type: %T", content)
|
|
}
|
|
|
|
toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
|
|
for i, tc := range msg.ToolCalls {
|
|
toolCalls[i].Function.Name = tc.Function.Name
|
|
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
|
|
if err != nil {
|
|
return nil, errors.New("invalid tool call arguments")
|
|
}
|
|
}
|
|
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
|
|
}
|
|
}
|
|
|
|
options := make(map[string]interface{})
|
|
|
|
switch stop := r.Stop.(type) {
|
|
case string:
|
|
options["stop"] = []string{stop}
|
|
case []any:
|
|
var stops []string
|
|
for _, s := range stop {
|
|
if str, ok := s.(string); ok {
|
|
stops = append(stops, str)
|
|
}
|
|
}
|
|
options["stop"] = stops
|
|
}
|
|
|
|
if r.MaxTokens != nil {
|
|
options["num_predict"] = *r.MaxTokens
|
|
}
|
|
|
|
if r.Temperature != nil {
|
|
options["temperature"] = *r.Temperature
|
|
} else {
|
|
options["temperature"] = 1.0
|
|
}
|
|
|
|
if r.Seed != nil {
|
|
options["seed"] = *r.Seed
|
|
}
|
|
|
|
if r.FrequencyPenalty != nil {
|
|
options["frequency_penalty"] = *r.FrequencyPenalty
|
|
}
|
|
|
|
if r.PresencePenalty != nil {
|
|
options["presence_penalty"] = *r.PresencePenalty
|
|
}
|
|
|
|
if r.TopP != nil {
|
|
options["top_p"] = *r.TopP
|
|
} else {
|
|
options["top_p"] = 1.0
|
|
}
|
|
|
|
var format string
|
|
if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
|
|
format = "json"
|
|
}
|
|
|
|
return &api.ChatRequest{
|
|
Model: r.Model,
|
|
Messages: messages,
|
|
Format: format,
|
|
Options: options,
|
|
Stream: &r.Stream,
|
|
Tools: r.Tools,
|
|
}, nil
|
|
}
|
|
|
|
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
|
options := make(map[string]any)
|
|
|
|
switch stop := r.Stop.(type) {
|
|
case string:
|
|
options["stop"] = []string{stop}
|
|
case []any:
|
|
var stops []string
|
|
for _, s := range stop {
|
|
if str, ok := s.(string); ok {
|
|
stops = append(stops, str)
|
|
} else {
|
|
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
|
|
}
|
|
}
|
|
options["stop"] = stops
|
|
}
|
|
|
|
if r.MaxTokens != nil {
|
|
options["num_predict"] = *r.MaxTokens
|
|
}
|
|
|
|
if r.Temperature != nil {
|
|
options["temperature"] = *r.Temperature
|
|
} else {
|
|
options["temperature"] = 1.0
|
|
}
|
|
|
|
if r.Seed != nil {
|
|
options["seed"] = *r.Seed
|
|
}
|
|
|
|
options["frequency_penalty"] = r.FrequencyPenalty
|
|
|
|
options["presence_penalty"] = r.PresencePenalty
|
|
|
|
if r.TopP != 0.0 {
|
|
options["top_p"] = r.TopP
|
|
} else {
|
|
options["top_p"] = 1.0
|
|
}
|
|
|
|
return api.GenerateRequest{
|
|
Model: r.Model,
|
|
Prompt: r.Prompt,
|
|
Options: options,
|
|
Stream: &r.Stream,
|
|
Suffix: r.Suffix,
|
|
}, nil
|
|
}
|
|
|
|
type BaseWriter struct {
|
|
gin.ResponseWriter
|
|
}
|
|
|
|
type ChatWriter struct {
|
|
stream bool
|
|
id string
|
|
BaseWriter
|
|
}
|
|
|
|
type CompleteWriter struct {
|
|
stream bool
|
|
id string
|
|
BaseWriter
|
|
}
|
|
|
|
type ListWriter struct {
|
|
BaseWriter
|
|
}
|
|
|
|
type RetrieveWriter struct {
|
|
BaseWriter
|
|
model string
|
|
}
|
|
|
|
type EmbedWriter struct {
|
|
BaseWriter
|
|
model string
|
|
}
|
|
|
|
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
|
var serr api.StatusError
|
|
err := json.Unmarshal(data, &serr)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|
var chatResponse api.ChatResponse
|
|
err := json.Unmarshal(data, &chatResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// chat chunk
|
|
if w.stream {
|
|
d, err := json.Marshal(toChunk(w.id, chatResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if chatResponse.Done {
|
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
// chat completion
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *ChatWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
|
var generateResponse api.GenerateResponse
|
|
err := json.Unmarshal(data, &generateResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// completion chunk
|
|
if w.stream {
|
|
d, err := json.Marshal(toCompleteChunk(w.id, generateResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if generateResponse.Done {
|
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
// completion
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
|
var listResponse api.ListResponse
|
|
err := json.Unmarshal(data, &listResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *ListWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
|
var showResponse api.ShowResponse
|
|
err := json.Unmarshal(data, &showResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// retrieve completion
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
|
var embedResponse api.EmbedResponse
|
|
err := json.Unmarshal(data, &embedResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func ListMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
w := &ListWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
}
|
|
|
|
c.Writer = w
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func RetrieveMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var b bytes.Buffer
|
|
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
// response writer
|
|
w := &RetrieveWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
model: c.Param("model"),
|
|
}
|
|
|
|
c.Writer = w
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func CompletionsMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var req CompletionRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
genReq, err := fromCompleteRequest(req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
w := &CompleteWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
stream: req.Stream,
|
|
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
|
}
|
|
|
|
c.Writer = w
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func EmbeddingsMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var req EmbedRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if req.Input == "" {
|
|
req.Input = []string{""}
|
|
}
|
|
|
|
if req.Input == nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
|
return
|
|
}
|
|
|
|
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
|
return
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
w := &EmbedWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
model: req.Model,
|
|
}
|
|
|
|
c.Writer = w
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func ChatMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var req ChatCompletionRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if len(req.Messages) == 0 {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
|
|
return
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
|
|
chatReq, err := fromChatRequest(req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
w := &ChatWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
stream: req.Stream,
|
|
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
|
}
|
|
|
|
c.Writer = w
|
|
|
|
c.Next()
|
|
}
|
|
}
|