diff --git a/docs/openai.md b/docs/openai.md
new file mode 100644
index 00000000..e1f1919d
--- /dev/null
+++ b/docs/openai.md
@@ -0,0 +1,140 @@
+# OpenAI compatibility
+
+Ollama provides experimental compatibility with parts of the [OpenAI API](https://platform.openai.com/docs/api-reference) to help connect existing applications to Ollama.
+
+> **Note:** OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/jmorganca/ollama/blob/main/docs/api.md).
+
+## Usage
+
+### OpenAI Python library
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    base_url='http://localhost:11434/v1/',
+
+    # required but ignored
+    api_key='ollama',
+)
+
+chat_completion = client.chat.completions.create(
+    messages=[
+        {
+            'role': 'user',
+            'content': 'Say this is a test',
+        }
+    ],
+    model='llama2',
+)
+```
+
+### OpenAI JavaScript library
+
+```javascript
+import OpenAI from 'openai'
+
+const openai = new OpenAI({
+  baseURL: 'http://localhost:11434/v1/',
+
+  // required but ignored
+  apiKey: 'ollama',
+})
+
+const chatCompletion = await openai.chat.completions.create({
+  messages: [{ role: 'user', content: 'Say this is a test' }],
+  model: 'llama2',
+})
+```
+
+### `curl`
+
+```
+curl http://localhost:11434/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "llama2",
+        "messages": [
+            {
+                "role": "system",
+                "content": "You are a helpful assistant."
+            },
+            {
+                "role": "user",
+                "content": "Hello!"
+            }
+        ]
+    }'
+```
+
+## Endpoints
+
+### `/v1/chat/completions`
+
+#### Supported features
+
+- [x] Chat completions
+- [x] Streaming
+- [x] JSON mode
+- [x] Reproducible outputs
+- [ ] Vision
+- [ ] Function calling
+- [ ] Logprobs
+
+#### Supported request fields
+
+- [x] `model`
+- [x] `messages`
+  - [x] Text `content`
+  - [ ] Array of `content` parts
+- [x] `frequency_penalty`
+- [x] `presence_penalty`
+- [x] `response_format`
+- [x] `seed`
+- [x] `stop`
+- [x] `stream`
+- [x] `temperature`
+- [x] `top_p`
+- [x] `max_tokens`
+- [ ] `logit_bias`
+- [ ] `tools`
+- [ ] `tool_choice`
+- [ ] `user`
+
+#### Notes
+
+- Setting `seed` will always set `temperature` to `0`
+- `finish_reason` will always be `stop`
+- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
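+
+For example, `seed` and JSON mode can be combined through the request fields listed above. The following is a minimal, illustrative sketch using the OpenAI Python library; the prompt and the exact JSON the model returns are only examples, and it assumes `llama2` has already been pulled:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url='http://localhost:11434/v1/', api_key='ollama')
+
+completion = client.chat.completions.create(
+    model='llama2',
+    messages=[
+        {'role': 'user', 'content': 'List three colors as a JSON object under the key "colors".'}
+    ],
+    # JSON mode: constrain the model to emit valid JSON
+    response_format={'type': 'json_object'},
+    # a fixed seed makes the output reproducible (temperature is forced to 0)
+    seed=101,
+)
+
+print(completion.choices[0].message.content)
+```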
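+
+Streaming is requested the same way as with the OpenAI API: pass `stream=True` and iterate over the chunks. A minimal sketch with the Python library, again assuming `llama2` is available locally:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url='http://localhost:11434/v1/', api_key='ollama')
+
+stream = client.chat.completions.create(
+    model='llama2',
+    messages=[{'role': 'user', 'content': 'Say this is a test'}],
+    stream=True,
+)
+
+# each chunk carries an incremental delta; the final chunk reports finish_reason "stop"
+for chunk in stream:
+    print(chunk.choices[0].delta.content or '', end='', flush=True)
+```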
+
+## Models
+
+Before using a model, pull it locally with `ollama pull`:
+
+```shell
+ollama pull llama2
+```
+
+### Default model names
+
+For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
+
+```
+ollama cp llama2 gpt-3.5-turbo
+```
+
+Afterwards, this new model name can be specified in the `model` field:
+
+```shell
+curl http://localhost:11434/v1/chat/completions \
+    -H "Content-Type: application/json" \
+    -d '{
+        "model": "gpt-3.5-turbo",
+        "messages": [
+            {
+                "role": "user",
+                "content": "Hello!"
+            }
+        ]
+    }'
+```
diff --git a/openai/openai.go b/openai/openai.go
new file mode 100644
index 00000000..4f495569
--- /dev/null
+++ b/openai/openai.go
@@ -0,0 +1,322 @@
+// openai package provides middleware for partial compatibility with the OpenAI REST API
+package openai
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"math/rand"
+	"net/http"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/jmorganca/ollama/api"
+)
+
+type Error struct {
+	Message string      `json:"message"`
+	Type    string      `json:"type"`
+	Param   interface{} `json:"param"`
+	Code    *string     `json:"code"`
+}
+
+type ErrorResponse struct {
+	Error Error `json:"error"`
+}
+
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type Choice struct {
+	Index        int     `json:"index"`
+	Message      Message `json:"message"`
+	FinishReason *string `json:"finish_reason"`
+}
+
+type ChunkChoice struct {
+	Index        int     `json:"index"`
+	Delta        Message `json:"delta"`
+	FinishReason *string `json:"finish_reason"`
+}
+
+type Usage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens      int `json:"total_tokens"`
+}
+
+type ResponseFormat struct {
+	Type string `json:"type"`
+}
+
+type ChatCompletionRequest struct {
+	Model            string          `json:"model"`
+	Messages         []Message       `json:"messages"`
+	Stream           bool            `json:"stream"`
+	MaxTokens        *int            `json:"max_tokens"`
+	Seed             *int            `json:"seed"`
+	Stop             any             `json:"stop"`
+	Temperature      *float64        `json:"temperature"`
+	FrequencyPenalty *float64        `json:"frequency_penalty"`
+	PresencePenalty  *float64        `json:"presence_penalty"`
+	TopP             *float64        `json:"top_p"`
+	ResponseFormat   *ResponseFormat `json:"response_format"`
+}
+
+type ChatCompletion struct {
+	Id                string   `json:"id"`
+	Object            string   `json:"object"`
+	Created           int64    `json:"created"`
+	Model             string   `json:"model"`
+	SystemFingerprint string   `json:"system_fingerprint"`
+	Choices           []Choice `json:"choices"`
+	Usage             Usage    `json:"usage,omitempty"`
+}
+
+type ChatCompletionChunk struct {
+	Id                string        `json:"id"`
+	Object            string        `json:"object"`
+	Created           int64         `json:"created"`
+	Model             string        `json:"model"`
+	SystemFingerprint string        `json:"system_fingerprint"`
+	Choices           []ChunkChoice `json:"choices"`
+}
+
+func NewError(code int, message string) ErrorResponse {
+	var etype string
+	switch code {
+	case http.StatusBadRequest:
+		etype = "invalid_request_error"
+	case http.StatusNotFound:
+		etype = "not_found_error"
+	default:
+		etype = "api_error"
+	}
+
+	return ErrorResponse{Error{Type: etype, Message: message}}
+}
+
+func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
+	return ChatCompletion{
+		Id:                id,
+		Object:            "chat.completion",
+		Created:           r.CreatedAt.Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []Choice{{
+			Index:   0,
+			Message: Message{Role: r.Message.Role, Content: r.Message.Content},
+			FinishReason: func(done bool) *string {
+				if done {
+					reason := "stop"
+					return &reason
+				}
+				return nil
+			}(r.Done),
+		}},
+		Usage: Usage{
+			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
+			PromptTokens:     r.PromptEvalCount,
+			CompletionTokens: r.EvalCount,
+			TotalTokens:      r.PromptEvalCount + r.EvalCount,
+		},
+	}
+}
+
+func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
+	return ChatCompletionChunk{
+		Id:                id,
+		Object:            "chat.completion.chunk",
+		Created:           time.Now().Unix(),
+		Model:             r.Model,
+		SystemFingerprint: "fp_ollama",
+		Choices: []ChunkChoice{
+			{
+				Index: 0,
+				Delta: Message{Role: "assistant", Content: r.Message.Content},
+				FinishReason: func(done bool) *string {
+					if done {
+						reason := "stop"
+						return &reason
+					}
+					return nil
+				}(r.Done),
+			},
+		},
+	}
+}
+
+// fromRequest converts an OpenAI-style chat completion request into an Ollama api.ChatRequest.
+func fromRequest(r ChatCompletionRequest) api.ChatRequest {
+	var messages []api.Message
+	for _, msg := range r.Messages {
+		messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
+	}
+
+	options := make(map[string]interface{})
+
+	switch stop := r.Stop.(type) {
+	case string:
+		options["stop"] = []string{stop}
+	case []interface{}:
+		var stops []string
+		for _, s := range stop {
+			if str, ok := s.(string); ok {
+				stops = append(stops, str)
+			}
+		}
+		options["stop"] = stops
+	}
+
+	if r.MaxTokens != nil {
+		options["num_predict"] = *r.MaxTokens
+	}
+
+	if r.Temperature != nil {
+		options["temperature"] = *r.Temperature * 2.0
+	} else {
+		options["temperature"] = 1.0
+	}
+
+	if r.Seed != nil {
+		options["seed"] = *r.Seed
+
+		// temperature=0 is required for reproducible outputs
+		options["temperature"] = 0.0
+	}
+
+	if r.FrequencyPenalty != nil {
+		options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
+	}
+
+	if r.PresencePenalty != nil {
+		options["presence_penalty"] = *r.PresencePenalty * 2.0
+	}
+
+	if r.TopP != nil {
+		options["top_p"] = *r.TopP
+	} else {
+		options["top_p"] = 1.0
+	}
+
+	var format string
+	if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
+		format = "json"
+	}
+
+	return api.ChatRequest{
+		Model:    r.Model,
+		Messages: messages,
+		Format:   format,
+		Options:  options,
+		Stream:   &r.Stream,
+	}
+}
+
+// writer wraps gin's ResponseWriter and rewrites Ollama chat responses into
+// OpenAI-compatible JSON, or server-sent event chunks when streaming.
+type writer struct {
+	stream bool
+	id     string
+	gin.ResponseWriter
+}
+
+func (w *writer) writeError(code int, data []byte) (int, error) {
+	var serr api.StatusError
+	err := json.Unmarshal(data, &serr)
+	if err != nil {
+		return 0, err
+	}
+
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *writer) writeResponse(data []byte) (int, error) {
+	var chatResponse api.ChatResponse
+	err := json.Unmarshal(data, &chatResponse)
+	if err != nil {
+		return 0, err
+	}
+
+	// chat chunk
+	if w.stream {
+		d, err := json.Marshal(toChunk(w.id, chatResponse))
+		if err != nil {
+			return 0, err
+		}
+
+		w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
+		_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+		if err != nil {
+			return 0, err
+		}
+
+		if chatResponse.Done {
+			_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
+			if err != nil {
+				return 0, err
+			}
+		}
+
+		return len(data), nil
+	}
+
+	// chat completion
+	w.ResponseWriter.Header().Set("Content-Type", "application/json")
+	err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
+	if err != nil {
+		return 0, err
+	}
+
+	return len(data), nil
+}
+
+func (w *writer) Write(data []byte) (int, error) {
+	code := w.ResponseWriter.Status()
+	if code != http.StatusOK {
+		return w.writeError(code, data)
+	}
+
+	return w.writeResponse(data)
+}
+
+// Middleware rewrites an incoming OpenAI-style chat completion request into the
+// Ollama chat request format, then wraps the response writer so the downstream
+// handler's output is returned in OpenAI format.
+func Middleware() gin.HandlerFunc {
+	return func(c *gin.Context) {
+		var req ChatCompletionRequest
+		err := c.ShouldBindJSON(&req)
+		if err != nil {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
+			return
+		}
+
+		if len(req.Messages) == 0 {
+			c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
+			return
+		}
+
+		var b bytes.Buffer
+		if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
+			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
+			return
+		}
+
+		c.Request.Body = io.NopCloser(&b)
+
+		w := &writer{
+			ResponseWriter: c.Writer,
+			stream:         req.Stream,
+			id:             fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+		}
+
+		c.Writer = w
+
+		c.Next()
+	}
+}
diff --git a/server/routes.go b/server/routes.go
index 7d1f9dfb..7be4c126 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -26,6 +26,7 @@ import (
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/gpu"
 	"github.com/jmorganca/ollama/llm"
+	"github.com/jmorganca/ollama/openai"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/version"
 )
@@ -935,6 +936,9 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.POST("/api/blobs/:digest", CreateBlobHandler)
 	r.HEAD("/api/blobs/:digest", HeadBlobHandler)
 
+	// Compatibility endpoints
+	r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
+
 	for _, method := range []string{http.MethodGet, http.MethodHead} {
 		r.Handle(method, "/", func(c *gin.Context) {
 			c.String(http.StatusOK, "Ollama is running")