Initial OpenAI /v1/chat/completions
API compatibility (#2376)
This commit is contained in:
parent
c9dfa6e571
commit
453f572f83
3 changed files with 466 additions and 0 deletions
140
docs/openai.md
Normal file
140
docs/openai.md
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
# OpenAI compatibility
|
||||||
|
|
||||||
|
Ollama provides experimental compatibility with parts of the [OpenAI API](https://platform.openai.com/docs/api-reference) to help connect existing applications to Ollama.
|
||||||
|
|
||||||
|
> **Note:** OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/jmorganca/ollama/blob/main/docs/api.md).
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### OpenAI Python library
|
||||||
|
|
||||||
|
```python
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
base_url='http://localhost:11434/v1/',
|
||||||
|
|
||||||
|
# required but ignored
|
||||||
|
api_key='ollama',
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_completion = client.chat.completions.create(
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
'role': 'user',
|
||||||
|
'content': 'Say this is a test',
|
||||||
|
}
|
||||||
|
],
|
||||||
|
model='llama2',
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenAI JavaScript library
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
import OpenAI from 'openai'
|
||||||
|
|
||||||
|
const openai = new OpenAI({
|
||||||
|
baseURL: 'http://localhost:11434/v1/',
|
||||||
|
|
||||||
|
// required but ignored
|
||||||
|
apiKey: 'ollama',
|
||||||
|
})
|
||||||
|
|
||||||
|
const chatCompletion = await openai.chat.completions.create({
|
||||||
|
messages: [{ role: 'user', content: 'Say this is a test' }],
|
||||||
|
model: 'llama2',
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### `curl`
|
||||||
|
|
||||||
|
```
|
||||||
|
curl http://localhost:11434/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "llama2",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello!"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
|
||||||
|
### `/v1/chat/completions`
|
||||||
|
|
||||||
|
#### Supported features
|
||||||
|
|
||||||
|
- [x] Chat completions
|
||||||
|
- [x] Streaming
|
||||||
|
- [x] JSON mode
|
||||||
|
- [x] Reproducible outputs
|
||||||
|
- [ ] Vision
|
||||||
|
- [ ] Function calling
|
||||||
|
- [ ] Logprobs
|
||||||
|
|
||||||
|
#### Supported request fields
|
||||||
|
|
||||||
|
- [x] `model`
|
||||||
|
- [x] `messages`
|
||||||
|
- [x] Text `content`
|
||||||
|
- [ ] Array of `content` parts
|
||||||
|
- [x] `frequency_penalty`
|
||||||
|
- [x] `presence_penalty`
|
||||||
|
- [x] `response_format`
|
||||||
|
- [x] `seed`
|
||||||
|
- [x] `stop`
|
||||||
|
- [x] `stream`
|
||||||
|
- [x] `temperature`
|
||||||
|
- [x] `top_p`
|
||||||
|
- [x] `max_tokens`
|
||||||
|
- [ ] `logit_bias`
|
||||||
|
- [ ] `tools`
|
||||||
|
- [ ] `tool_choice`
|
||||||
|
- [ ] `user`
|
||||||
|
|
||||||
|
#### Notes
|
||||||
|
|
||||||
|
- Setting `seed` will always set `temperature` to `0`
|
||||||
|
- `finish_reason` will always be `stop`
|
||||||
|
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
|
||||||
|
|
||||||
|
## Models
|
||||||
|
|
||||||
|
Before using a model, pull it locally `ollama pull`:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
ollama pull llama2
|
||||||
|
```
|
||||||
|
|
||||||
|
### Default model names
|
||||||
|
|
||||||
|
For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
|
||||||
|
|
||||||
|
```
|
||||||
|
ollama cp llama2 gpt-3.5-turbo
|
||||||
|
```
|
||||||
|
|
||||||
|
Afterwards, this new model name can be specified the `model` field:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl http://localhost:11434/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello!"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
322
openai/openai.go
Normal file
322
openai/openai.go
Normal file
|
@ -0,0 +1,322 @@
|
||||||
|
// openai package provides middleware for partial compatibility with the OpenAI REST API
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/jmorganca/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Error struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Param interface{} `json:"param"`
|
||||||
|
Code *string `json:"code"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Error Error `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Choice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Message Message `json:"message"`
|
||||||
|
FinishReason *string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChunkChoice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Delta Message `json:"delta"`
|
||||||
|
FinishReason *string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Usage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResponseFormat struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
MaxTokens *int `json:"max_tokens"`
|
||||||
|
Seed *int `json:"seed"`
|
||||||
|
Stop any `json:"stop"`
|
||||||
|
Temperature *float64 `json:"temperature"`
|
||||||
|
FrequencyPenalty *float64 `json:"frequency_penalty"`
|
||||||
|
PresencePenalty *float64 `json:"presence_penalty_penalty"`
|
||||||
|
TopP *float64 `json:"top_p"`
|
||||||
|
ResponseFormat *ResponseFormat `json:"response_format"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletion struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
Choices []Choice `json:"choices"`
|
||||||
|
Usage Usage `json:"usage,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionChunk struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
SystemFingerprint string `json:"system_fingerprint"`
|
||||||
|
Choices []ChunkChoice `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewError(code int, message string) ErrorResponse {
|
||||||
|
var etype string
|
||||||
|
switch code {
|
||||||
|
case http.StatusBadRequest:
|
||||||
|
etype = "invalid_request_error"
|
||||||
|
case http.StatusNotFound:
|
||||||
|
etype = "not_found_error"
|
||||||
|
default:
|
||||||
|
etype = "api_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
return ErrorResponse{Error{Type: etype, Message: message}}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||||
|
return ChatCompletion{
|
||||||
|
Id: id,
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: r.CreatedAt.Unix(),
|
||||||
|
Model: r.Model,
|
||||||
|
SystemFingerprint: "fp_ollama",
|
||||||
|
Choices: []Choice{{
|
||||||
|
Index: 0,
|
||||||
|
Message: Message{Role: r.Message.Role, Content: r.Message.Content},
|
||||||
|
FinishReason: func(done bool) *string {
|
||||||
|
if done {
|
||||||
|
reason := "stop"
|
||||||
|
return &reason
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}(r.Done),
|
||||||
|
}},
|
||||||
|
Usage: Usage{
|
||||||
|
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
|
||||||
|
PromptTokens: r.PromptEvalCount,
|
||||||
|
CompletionTokens: r.EvalCount,
|
||||||
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
||||||
|
return ChatCompletionChunk{
|
||||||
|
Id: id,
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
Model: r.Model,
|
||||||
|
SystemFingerprint: "fp_ollama",
|
||||||
|
Choices: []ChunkChoice{
|
||||||
|
{
|
||||||
|
Index: 0,
|
||||||
|
Delta: Message{Role: "assistant", Content: r.Message.Content},
|
||||||
|
FinishReason: func(done bool) *string {
|
||||||
|
if done {
|
||||||
|
reason := "stop"
|
||||||
|
return &reason
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}(r.Done),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromRequest(r ChatCompletionRequest) api.ChatRequest {
|
||||||
|
var messages []api.Message
|
||||||
|
for _, msg := range r.Messages {
|
||||||
|
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
|
||||||
|
}
|
||||||
|
|
||||||
|
options := make(map[string]interface{})
|
||||||
|
|
||||||
|
switch stop := r.Stop.(type) {
|
||||||
|
case string:
|
||||||
|
options["stop"] = []string{stop}
|
||||||
|
case []interface{}:
|
||||||
|
var stops []string
|
||||||
|
for _, s := range stop {
|
||||||
|
if str, ok := s.(string); ok {
|
||||||
|
stops = append(stops, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
options["stop"] = stops
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.MaxTokens != nil {
|
||||||
|
options["num_predict"] = *r.MaxTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Temperature != nil {
|
||||||
|
options["temperature"] = *r.Temperature * 2.0
|
||||||
|
} else {
|
||||||
|
options["temperature"] = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Seed != nil {
|
||||||
|
options["seed"] = *r.Seed
|
||||||
|
|
||||||
|
// temperature=0 is required for reproducible outputs
|
||||||
|
options["temperature"] = 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.FrequencyPenalty != nil {
|
||||||
|
options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.PresencePenalty != nil {
|
||||||
|
options["presence_penalty"] = *r.PresencePenalty * 2.0
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.TopP != nil {
|
||||||
|
options["top_p"] = *r.TopP
|
||||||
|
} else {
|
||||||
|
options["top_p"] = 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
var format string
|
||||||
|
if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
|
||||||
|
format = "json"
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.ChatRequest{
|
||||||
|
Model: r.Model,
|
||||||
|
Messages: messages,
|
||||||
|
Format: format,
|
||||||
|
Options: options,
|
||||||
|
Stream: &r.Stream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type writer struct {
|
||||||
|
stream bool
|
||||||
|
id string
|
||||||
|
gin.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *writer) writeError(code int, data []byte) (int, error) {
|
||||||
|
var serr api.StatusError
|
||||||
|
err := json.Unmarshal(data, &serr)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *writer) writeResponse(data []byte) (int, error) {
|
||||||
|
var chatResponse api.ChatResponse
|
||||||
|
err := json.Unmarshal(data, &chatResponse)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// chat chunk
|
||||||
|
if w.stream {
|
||||||
|
d, err := json.Marshal(toChunk(w.id, chatResponse))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if chatResponse.Done {
|
||||||
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// chat completion
|
||||||
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||||
|
err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *writer) Write(data []byte) (int, error) {
|
||||||
|
code := w.ResponseWriter.Status()
|
||||||
|
if code != http.StatusOK {
|
||||||
|
return w.writeError(code, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.writeResponse(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Middleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
var req ChatCompletionRequest
|
||||||
|
err := c.ShouldBindJSON(&req)
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.Messages) == 0 {
|
||||||
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
|
||||||
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Request.Body = io.NopCloser(&b)
|
||||||
|
|
||||||
|
w := &writer{
|
||||||
|
ResponseWriter: c.Writer,
|
||||||
|
stream: req.Stream,
|
||||||
|
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Writer = w
|
||||||
|
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
|
@ -26,6 +26,7 @@ import (
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
"github.com/jmorganca/ollama/gpu"
|
"github.com/jmorganca/ollama/gpu"
|
||||||
"github.com/jmorganca/ollama/llm"
|
"github.com/jmorganca/ollama/llm"
|
||||||
|
"github.com/jmorganca/ollama/openai"
|
||||||
"github.com/jmorganca/ollama/parser"
|
"github.com/jmorganca/ollama/parser"
|
||||||
"github.com/jmorganca/ollama/version"
|
"github.com/jmorganca/ollama/version"
|
||||||
)
|
)
|
||||||
|
@ -935,6 +936,9 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||||
r.POST("/api/blobs/:digest", CreateBlobHandler)
|
r.POST("/api/blobs/:digest", CreateBlobHandler)
|
||||||
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
|
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
|
||||||
|
|
||||||
|
// Compatibility endpoints
|
||||||
|
r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
|
||||||
|
|
||||||
for _, method := range []string{http.MethodGet, http.MethodHead} {
|
for _, method := range []string{http.MethodGet, http.MethodHead} {
|
||||||
r.Handle(method, "/", func(c *gin.Context) {
|
r.Handle(method, "/", func(c *gin.Context) {
|
||||||
c.String(http.StatusOK, "Ollama is running")
|
c.String(http.StatusOK, "Ollama is running")
|
||||||
|
|
Loading…
Reference in a new issue