// openai package provides middleware for partial compatibility with the OpenAI REST API package openai import ( "bytes" "encoding/base64" "encoding/json" "fmt" "io" "log/slog" "math/rand" "net/http" "strings" "time" "github.com/gin-gonic/gin" "github.com/ollama/ollama/api" "github.com/ollama/ollama/types/model" ) type Error struct { Message string `json:"message"` Type string `json:"type"` Param interface{} `json:"param"` Code *string `json:"code"` } type ErrorResponse struct { Error Error `json:"error"` } type Message struct { Role string `json:"role"` Content any `json:"content"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` } type Choice struct { Index int `json:"index"` Message Message `json:"message"` FinishReason *string `json:"finish_reason"` } type ChunkChoice struct { Index int `json:"index"` Delta Message `json:"delta"` FinishReason *string `json:"finish_reason"` } type CompleteChunkChoice struct { Text string `json:"text"` Index int `json:"index"` FinishReason *string `json:"finish_reason"` } type Usage struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` TotalTokens int `json:"total_tokens"` } type ResponseFormat struct { Type string `json:"type"` } type EmbedRequest struct { Input any `json:"input"` Model string `json:"model"` } type ChatCompletionRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` Stream bool `json:"stream"` MaxTokens *int `json:"max_tokens"` Seed *int `json:"seed"` Stop any `json:"stop"` Temperature *float64 `json:"temperature"` FrequencyPenalty *float64 `json:"frequency_penalty"` PresencePenalty *float64 `json:"presence_penalty_penalty"` TopP *float64 `json:"top_p"` ResponseFormat *ResponseFormat `json:"response_format"` Tools []api.Tool `json:"tools"` } type ChatCompletion struct { Id string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` Model string `json:"model"` SystemFingerprint string `json:"system_fingerprint"` Choices []Choice `json:"choices"` Usage Usage `json:"usage,omitempty"` } type ChatCompletionChunk struct { Id string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` Model string `json:"model"` SystemFingerprint string `json:"system_fingerprint"` Choices []ChunkChoice `json:"choices"` } // TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int type CompletionRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` FrequencyPenalty float32 `json:"frequency_penalty"` MaxTokens *int `json:"max_tokens"` PresencePenalty float32 `json:"presence_penalty"` Seed *int `json:"seed"` Stop any `json:"stop"` Stream bool `json:"stream"` Temperature *float32 `json:"temperature"` TopP float32 `json:"top_p"` Suffix string `json:"suffix"` } type Completion struct { Id string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` Model string `json:"model"` SystemFingerprint string `json:"system_fingerprint"` Choices []CompleteChunkChoice `json:"choices"` Usage Usage `json:"usage,omitempty"` } type CompletionChunk struct { Id string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` Choices []CompleteChunkChoice `json:"choices"` Model string `json:"model"` SystemFingerprint string `json:"system_fingerprint"` } type ToolCall struct { ID string `json:"id"` Type string `json:"type"` Function struct { Name string `json:"name"` Arguments string `json:"arguments"` } `json:"function"` } type Model struct { Id string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` OwnedBy string `json:"owned_by"` } type Embedding struct { Object string `json:"object"` Embedding []float32 `json:"embedding"` Index int `json:"index"` } type ListCompletion struct { Object string `json:"object"` Data []Model `json:"data"` } type EmbeddingList struct { Object string `json:"object"` Data []Embedding `json:"data"` Model string `json:"model"` } func NewError(code int, message string) ErrorResponse { var etype string switch code { case http.StatusBadRequest: etype = "invalid_request_error" case http.StatusNotFound: etype = "not_found_error" default: etype = "api_error" } return ErrorResponse{Error{Type: etype, Message: message}} } func toolCallId() string { const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789" b := make([]byte, 8) for i := range b { b[i] = letterBytes[rand.Intn(len(letterBytes))] } return "call_" + strings.ToLower(string(b)) } func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { toolCalls := make([]ToolCall, len(r.Message.ToolCalls)) for i, tc := range r.Message.ToolCalls { toolCalls[i].ID = toolCallId() toolCalls[i].Type = "function" toolCalls[i].Function.Name = tc.Function.Name args, err := json.Marshal(tc.Function.Arguments) if err != nil { slog.Error("could not marshall function arguments to json", "error", err) continue } toolCalls[i].Function.Arguments = string(args) } return ChatCompletion{ Id: id, Object: "chat.completion", Created: r.CreatedAt.Unix(), Model: r.Model, SystemFingerprint: "fp_ollama", Choices: []Choice{{ Index: 0, Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls}, FinishReason: func(reason string) *string { if len(toolCalls) > 0 { reason = "tool_calls" } if len(reason) > 0 { return &reason } return nil }(r.DoneReason), }}, Usage: Usage{ PromptTokens: r.PromptEvalCount, CompletionTokens: r.EvalCount, TotalTokens: r.PromptEvalCount + r.EvalCount, }, } } func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { return ChatCompletionChunk{ Id: id, Object: "chat.completion.chunk", Created: time.Now().Unix(), Model: r.Model, SystemFingerprint: "fp_ollama", Choices: []ChunkChoice{{ Index: 0, Delta: Message{Role: "assistant", Content: r.Message.Content}, FinishReason: func(reason string) *string { if len(reason) > 0 { return &reason } return nil }(r.DoneReason), }}, } } func toCompletion(id string, r api.GenerateResponse) Completion { return Completion{ Id: id, Object: "text_completion", Created: r.CreatedAt.Unix(), Model: r.Model, SystemFingerprint: "fp_ollama", Choices: []CompleteChunkChoice{{ Text: r.Response, Index: 0, FinishReason: func(reason string) *string { if len(reason) > 0 { return &reason } return nil }(r.DoneReason), }}, Usage: Usage{ PromptTokens: r.PromptEvalCount, CompletionTokens: r.EvalCount, TotalTokens: r.PromptEvalCount + r.EvalCount, }, } } func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk { return CompletionChunk{ Id: id, Object: "text_completion", Created: time.Now().Unix(), Model: r.Model, SystemFingerprint: "fp_ollama", Choices: []CompleteChunkChoice{{ Text: r.Response, Index: 0, FinishReason: func(reason string) *string { if len(reason) > 0 { return &reason } return nil }(r.DoneReason), }}, } } func toListCompletion(r api.ListResponse) ListCompletion { var data []Model for _, m := range r.Models { data = append(data, Model{ Id: m.Name, Object: "model", Created: m.ModifiedAt.Unix(), OwnedBy: model.ParseName(m.Name).Namespace, }) } return ListCompletion{ Object: "list", Data: data, } } func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { if r.Embeddings != nil { var data []Embedding for i, e := range r.Embeddings { data = append(data, Embedding{ Object: "embedding", Embedding: e, Index: i, }) } return EmbeddingList{ Object: "list", Data: data, Model: model, } } return EmbeddingList{} } func toModel(r api.ShowResponse, m string) Model { return Model{ Id: m, Object: "model", Created: r.ModifiedAt.Unix(), OwnedBy: model.ParseName(m).Namespace, } } func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { var messages []api.Message for _, msg := range r.Messages { switch content := msg.Content.(type) { case string: messages = append(messages, api.Message{Role: msg.Role, Content: content}) case []any: for _, c := range content { data, ok := c.(map[string]any) if !ok { return nil, fmt.Errorf("invalid message format") } switch data["type"] { case "text": text, ok := data["text"].(string) if !ok { return nil, fmt.Errorf("invalid message format") } messages = append(messages, api.Message{Role: msg.Role, Content: text}) case "image_url": var url string if urlMap, ok := data["image_url"].(map[string]any); ok { if url, ok = urlMap["url"].(string); !ok { return nil, fmt.Errorf("invalid message format") } } else { if url, ok = data["image_url"].(string); !ok { return nil, fmt.Errorf("invalid message format") } } types := []string{"jpeg", "jpg", "png"} valid := false for _, t := range types { prefix := "data:image/" + t + ";base64," if strings.HasPrefix(url, prefix) { url = strings.TrimPrefix(url, prefix) valid = true break } } if !valid { return nil, fmt.Errorf("invalid image input") } img, err := base64.StdEncoding.DecodeString(url) if err != nil { return nil, fmt.Errorf("invalid message format") } messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}}) default: return nil, fmt.Errorf("invalid message format") } } default: if msg.ToolCalls == nil { return nil, fmt.Errorf("invalid message content type: %T", content) } toolCalls := make([]api.ToolCall, len(msg.ToolCalls)) for i, tc := range msg.ToolCalls { toolCalls[i].Function.Name = tc.Function.Name err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments) if err != nil { return nil, fmt.Errorf("invalid tool call arguments") } } messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls}) } } options := make(map[string]interface{}) switch stop := r.Stop.(type) { case string: options["stop"] = []string{stop} case []any: var stops []string for _, s := range stop { if str, ok := s.(string); ok { stops = append(stops, str) } } options["stop"] = stops } if r.MaxTokens != nil { options["num_predict"] = *r.MaxTokens } if r.Temperature != nil { options["temperature"] = *r.Temperature * 2.0 } else { options["temperature"] = 1.0 } if r.Seed != nil { options["seed"] = *r.Seed } if r.FrequencyPenalty != nil { options["frequency_penalty"] = *r.FrequencyPenalty * 2.0 } if r.PresencePenalty != nil { options["presence_penalty"] = *r.PresencePenalty * 2.0 } if r.TopP != nil { options["top_p"] = *r.TopP } else { options["top_p"] = 1.0 } var format string if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" { format = "json" } return &api.ChatRequest{ Model: r.Model, Messages: messages, Format: format, Options: options, Stream: &r.Stream, Tools: r.Tools, }, nil } func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { options := make(map[string]any) switch stop := r.Stop.(type) { case string: options["stop"] = []string{stop} case []any: var stops []string for _, s := range stop { if str, ok := s.(string); ok { stops = append(stops, str) } else { return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s) } } options["stop"] = stops } if r.MaxTokens != nil { options["num_predict"] = *r.MaxTokens } if r.Temperature != nil { options["temperature"] = *r.Temperature * 2.0 } else { options["temperature"] = 1.0 } if r.Seed != nil { options["seed"] = *r.Seed } options["frequency_penalty"] = r.FrequencyPenalty * 2.0 options["presence_penalty"] = r.PresencePenalty * 2.0 if r.TopP != 0.0 { options["top_p"] = r.TopP } else { options["top_p"] = 1.0 } return api.GenerateRequest{ Model: r.Model, Prompt: r.Prompt, Options: options, Stream: &r.Stream, Suffix: r.Suffix, }, nil } type BaseWriter struct { gin.ResponseWriter } type ChatWriter struct { stream bool id string BaseWriter } type CompleteWriter struct { stream bool id string BaseWriter } type ListWriter struct { BaseWriter } type RetrieveWriter struct { BaseWriter model string } type EmbedWriter struct { BaseWriter model string } func (w *BaseWriter) writeError(code int, data []byte) (int, error) { var serr api.StatusError err := json.Unmarshal(data, &serr) if err != nil { return 0, err } w.ResponseWriter.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error())) if err != nil { return 0, err } return len(data), nil } func (w *ChatWriter) writeResponse(data []byte) (int, error) { var chatResponse api.ChatResponse err := json.Unmarshal(data, &chatResponse) if err != nil { return 0, err } // chat chunk if w.stream { d, err := json.Marshal(toChunk(w.id, chatResponse)) if err != nil { return 0, err } w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) if err != nil { return 0, err } if chatResponse.Done { _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) if err != nil { return 0, err } } return len(data), nil } // chat completion w.ResponseWriter.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse)) if err != nil { return 0, err } return len(data), nil } func (w *ChatWriter) Write(data []byte) (int, error) { code := w.ResponseWriter.Status() if code != http.StatusOK { return w.writeError(code, data) } return w.writeResponse(data) } func (w *CompleteWriter) writeResponse(data []byte) (int, error) { var generateResponse api.GenerateResponse err := json.Unmarshal(data, &generateResponse) if err != nil { return 0, err } // completion chunk if w.stream { d, err := json.Marshal(toCompleteChunk(w.id, generateResponse)) if err != nil { return 0, err } w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) if err != nil { return 0, err } if generateResponse.Done { _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) if err != nil { return 0, err } } return len(data), nil } // completion w.ResponseWriter.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse)) if err != nil { return 0, err } return len(data), nil } func (w *CompleteWriter) Write(data []byte) (int, error) { code := w.ResponseWriter.Status() if code != http.StatusOK { return w.writeError(code, data) } return w.writeResponse(data) } func (w *ListWriter) writeResponse(data []byte) (int, error) { var listResponse api.ListResponse err := json.Unmarshal(data, &listResponse) if err != nil { return 0, err } w.ResponseWriter.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse)) if err != nil { return 0, err } return len(data), nil } func (w *ListWriter) Write(data []byte) (int, error) { code := w.ResponseWriter.Status() if code != http.StatusOK { return w.writeError(code, data) } return w.writeResponse(data) } func (w *RetrieveWriter) writeResponse(data []byte) (int, error) { var showResponse api.ShowResponse err := json.Unmarshal(data, &showResponse) if err != nil { return 0, err } // retrieve completion w.ResponseWriter.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model)) if err != nil { return 0, err } return len(data), nil } func (w *RetrieveWriter) Write(data []byte) (int, error) { code := w.ResponseWriter.Status() if code != http.StatusOK { return w.writeError(code, data) } return w.writeResponse(data) } func (w *EmbedWriter) writeResponse(data []byte) (int, error) { var embedResponse api.EmbedResponse err := json.Unmarshal(data, &embedResponse) if err != nil { return 0, err } w.ResponseWriter.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse)) if err != nil { return 0, err } return len(data), nil } func (w *EmbedWriter) Write(data []byte) (int, error) { code := w.ResponseWriter.Status() if code != http.StatusOK { return w.writeError(code, data) } return w.writeResponse(data) } func ListMiddleware() gin.HandlerFunc { return func(c *gin.Context) { w := &ListWriter{ BaseWriter: BaseWriter{ResponseWriter: c.Writer}, } c.Writer = w c.Next() } } func RetrieveMiddleware() gin.HandlerFunc { return func(c *gin.Context) { var b bytes.Buffer if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) return } c.Request.Body = io.NopCloser(&b) // response writer w := &RetrieveWriter{ BaseWriter: BaseWriter{ResponseWriter: c.Writer}, model: c.Param("model"), } c.Writer = w c.Next() } } func CompletionsMiddleware() gin.HandlerFunc { return func(c *gin.Context) { var req CompletionRequest err := c.ShouldBindJSON(&req) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) return } var b bytes.Buffer genReq, err := fromCompleteRequest(req) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) return } if err := json.NewEncoder(&b).Encode(genReq); err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) return } c.Request.Body = io.NopCloser(&b) w := &CompleteWriter{ BaseWriter: BaseWriter{ResponseWriter: c.Writer}, stream: req.Stream, id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), } c.Writer = w c.Next() } } func EmbeddingsMiddleware() gin.HandlerFunc { return func(c *gin.Context) { var req EmbedRequest err := c.ShouldBindJSON(&req) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) return } if req.Input == "" { req.Input = []string{""} } if req.Input == nil { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) return } if v, ok := req.Input.([]any); ok && len(v) == 0 { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input")) return } var b bytes.Buffer if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) return } c.Request.Body = io.NopCloser(&b) w := &EmbedWriter{ BaseWriter: BaseWriter{ResponseWriter: c.Writer}, model: req.Model, } c.Writer = w c.Next() } } func ChatMiddleware() gin.HandlerFunc { return func(c *gin.Context) { var req ChatCompletionRequest err := c.ShouldBindJSON(&req) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) return } if len(req.Messages) == 0 { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'")) return } var b bytes.Buffer chatReq, err := fromChatRequest(req) if err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error())) return } if err := json.NewEncoder(&b).Encode(chatReq); err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) return } c.Request.Body = io.NopCloser(&b) w := &ChatWriter{ BaseWriter: BaseWriter{ResponseWriter: c.Writer}, stream: req.Stream, id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), } c.Writer = w c.Next() } }