- update chat docs
- add messages chat endpoint
- remove deprecated context and template generate parameters from docs
- context and template are still supported for the time being and will continue to work as expected
- add partial response to chat history
This commit is contained in:
Bruce MacDonald 2023-12-04 18:01:06 -05:00 committed by GitHub
parent 0cca1486dd
commit 7a0899d62d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 667 additions and 256 deletions

View file

@ -221,6 +221,19 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
}) })
} }
// ChatResponseFunc is the callback type accepted by Client.Chat. It receives
// one decoded ChatResponse per call and may return an error.
type ChatResponseFunc func(ChatResponse) error
// Chat issues an HTTP POST to "/api/chat" through c.stream, passing req as the
// request data. Every chunk of bytes that the stream hands back is decoded as
// JSON into a ChatResponse and forwarded to fn. A JSON decoding failure, or
// any error returned by fn, is returned from the per-chunk callback.
func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
	handleChunk := func(raw []byte) error {
		var chunk ChatResponse
		err := json.Unmarshal(raw, &chunk)
		if err != nil {
			return err
		}
		return fn(chunk)
	}

	return c.stream(ctx, http.MethodPost, "/api/chat", req, handleChunk)
}
type PullProgressFunc func(ProgressResponse) error type PullProgressFunc func(ProgressResponse) error
func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {

View file

@ -36,7 +36,7 @@ type GenerateRequest struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
System string `json:"system"` System string `json:"system"`
Template string `json:"template"` Template string `json:"template"`
Context []int `json:"context,omitempty"` Context []int `json:"context,omitempty"` // DEPRECATED: context is deprecated, use the /chat endpoint instead for chat history
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"` Raw bool `json:"raw,omitempty"`
Format string `json:"format"` Format string `json:"format"`
@ -44,6 +44,41 @@ type GenerateRequest struct {
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
// ChatRequest is the JSON request body that Client.Chat posts to the
// /api/chat endpoint.
type ChatRequest struct {
// Model is the name of the model to chat with.
Model string `json:"model"`
// Messages is the list of chat messages sent with the request.
Messages []Message `json:"messages"`
// Template is sent under the "template" key; the API docs describe it as a
// prompt template override.
Template string `json:"template"`
// Stream is a pointer so that a nil value is omitted from the JSON entirely
// (omitempty) rather than being sent as false.
Stream *bool `json:"stream,omitempty"`
// Format is the requested response format; the API docs list "json" as the
// only accepted value.
Format string `json:"format"`
// Options carries additional, free-form model parameters.
Options map[string]interface{} `json:"options"`
}
// Message is a single chat message: the role of its author plus the text
// content of the message.
type Message struct {
Role string `json:"role"` // one of ["system", "user", "assistant"]
// Content is the text of the message.
Content string `json:"content"`
}
// ChatResponse is the JSON object that Client.Chat decodes for each chunk
// returned by the /api/chat endpoint.
type ChatResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
// Message is a pointer with omitempty, so it is absent from the JSON when
// nil; consumers must nil-check it before reading Role or Content.
Message *Message `json:"message,omitempty"`
// Done is serialized under "done"; the API docs show it as true on the
// final response of a stream.
Done bool `json:"done"`
Context []int `json:"context,omitempty"`

// EvalMetrics is embedded without a tag, so encoding/json places its
// fields at the top level of this object.
EvalMetrics
}
// EvalMetrics groups the timing and token-count statistics embedded in
// ChatResponse. Every field uses omitempty, so zero values are left out of
// the JSON. Durations are time.Duration values, which encode as integer
// nanoseconds.
type EvalMetrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
// PromptEvalCount and PromptEvalDuration describe prompt evaluation.
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
// EvalCount and EvalDuration describe response generation.
EvalCount int `json:"eval_count,omitempty"`
EvalDuration time.Duration `json:"eval_duration,omitempty"`
}
// Options specified in GenerateRequest, if you add a new option here add it to the API docs also // Options specified in GenerateRequest, if you add a new option here add it to the API docs also
type Options struct { type Options struct {
Runner Runner
@ -173,39 +208,34 @@ type GenerateResponse struct {
Done bool `json:"done"` Done bool `json:"done"`
Context []int `json:"context,omitempty"` Context []int `json:"context,omitempty"`
TotalDuration time.Duration `json:"total_duration,omitempty"` EvalMetrics
LoadDuration time.Duration `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration time.Duration `json:"eval_duration,omitempty"`
} }
func (r *GenerateResponse) Summary() { func (m *EvalMetrics) Summary() {
if r.TotalDuration > 0 { if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration) fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
} }
if r.LoadDuration > 0 { if m.LoadDuration > 0 {
fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration) fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
} }
if r.PromptEvalCount > 0 { if m.PromptEvalCount > 0 {
fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount) fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", m.PromptEvalCount)
} }
if r.PromptEvalDuration > 0 { if m.PromptEvalDuration > 0 {
fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration) fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", m.PromptEvalDuration)
fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds()) fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds())
} }
if r.EvalCount > 0 { if m.EvalCount > 0 {
fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount) fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", m.EvalCount)
} }
if r.EvalDuration > 0 { if m.EvalDuration > 0 {
fmt.Fprintf(os.Stderr, "eval duration: %s\n", r.EvalDuration) fmt.Fprintf(os.Stderr, "eval duration: %s\n", m.EvalDuration)
fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds()) fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds())
} }
} }

View file

@ -159,7 +159,54 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
return RunGenerate(cmd, args) interactive := true
opts := runOptions{
Model: name,
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
}
format, err := cmd.Flags().GetString("format")
if err != nil {
return err
}
opts.Format = format
prompts := args[1:]
// prepend stdin to the prompt if provided
if !term.IsTerminal(int(os.Stdin.Fd())) {
in, err := io.ReadAll(os.Stdin)
if err != nil {
return err
}
prompts = append([]string{string(in)}, prompts...)
opts.WordWrap = false
interactive = false
}
msg := api.Message{
Role: "user",
Content: strings.Join(prompts, " "),
}
opts.Messages = append(opts.Messages, msg)
if len(prompts) > 0 {
interactive = false
}
nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil {
return err
}
opts.WordWrap = !nowrap
if !interactive {
_, err := chat(cmd, opts)
return err
}
return chatInteractive(cmd, opts)
} }
func PushHandler(cmd *cobra.Command, args []string) error { func PushHandler(cmd *cobra.Command, args []string) error {
@ -411,83 +458,26 @@ func PullHandler(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func RunGenerate(cmd *cobra.Command, args []string) error { type runOptions struct {
interactive := true
opts := generateOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
}
format, err := cmd.Flags().GetString("format")
if err != nil {
return err
}
opts.Format = format
prompts := args[1:]
// prepend stdin to the prompt if provided
if !term.IsTerminal(int(os.Stdin.Fd())) {
in, err := io.ReadAll(os.Stdin)
if err != nil {
return err
}
prompts = append([]string{string(in)}, prompts...)
opts.WordWrap = false
interactive = false
}
opts.Prompt = strings.Join(prompts, " ")
if len(prompts) > 0 {
interactive = false
}
nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil {
return err
}
opts.WordWrap = !nowrap
if !interactive {
return generate(cmd, opts)
}
return generateInteractive(cmd, opts)
}
type generateContextKey string
type generateOptions struct {
Model string Model string
Prompt string Messages []api.Message
WordWrap bool WordWrap bool
Format string Format string
System string
Template string Template string
Options map[string]interface{} Options map[string]interface{}
} }
func generate(cmd *cobra.Command, opts generateOptions) error { func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
return err return nil, err
} }
p := progress.NewProgress(os.Stderr) p := progress.NewProgress(os.Stderr)
defer p.StopAndClear() defer p.StopAndClear()
spinner := progress.NewSpinner("") spinner := progress.NewSpinner("")
p.Add("", spinner) p.Add("", spinner)
var latest api.GenerateResponse
generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
if !ok {
generateContext = []int{}
}
termWidth, _, err := term.GetSize(int(os.Stdout.Fd())) termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil { if err != nil {
opts.WordWrap = false opts.WordWrap = false
@ -506,24 +496,24 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
var currentLineLength int var currentLineLength int
var wordBuffer string var wordBuffer string
var latest api.ChatResponse
var fullResponse strings.Builder
var role string
request := api.GenerateRequest{ fn := func(response api.ChatResponse) error {
Model: opts.Model,
Prompt: opts.Prompt,
Context: generateContext,
Format: opts.Format,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
}
fn := func(response api.GenerateResponse) error {
p.StopAndClear() p.StopAndClear()
latest = response latest = response
if response.Message == nil {
// warm-up response or done
return nil
}
role = response.Message.Role
content := response.Message.Content
fullResponse.WriteString(content)
termWidth, _, _ = term.GetSize(int(os.Stdout.Fd())) termWidth, _, _ = term.GetSize(int(os.Stdout.Fd()))
if opts.WordWrap && termWidth >= 10 { if opts.WordWrap && termWidth >= 10 {
for _, ch := range response.Response { for _, ch := range content {
if currentLineLength+1 > termWidth-5 { if currentLineLength+1 > termWidth-5 {
if len(wordBuffer) > termWidth-10 { if len(wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", wordBuffer, ch) fmt.Printf("%s%c", wordBuffer, ch)
@ -551,7 +541,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
} }
} }
} else { } else {
fmt.Printf("%s%s", wordBuffer, response.Response) fmt.Printf("%s%s", wordBuffer, content)
if len(wordBuffer) > 0 { if len(wordBuffer) > 0 {
wordBuffer = "" wordBuffer = ""
} }
@ -560,35 +550,35 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
return nil return nil
} }
if err := client.Generate(cancelCtx, &request, fn); err != nil { req := &api.ChatRequest{
Model: opts.Model,
Messages: opts.Messages,
Format: opts.Format,
Template: opts.Template,
Options: opts.Options,
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
return nil return nil, nil
} }
return err return nil, err
}
if opts.Prompt != "" {
fmt.Println()
fmt.Println()
} }
if !latest.Done { if len(opts.Messages) > 0 {
return nil fmt.Println()
fmt.Println()
} }
verbose, err := cmd.Flags().GetBool("verbose") verbose, err := cmd.Flags().GetBool("verbose")
if err != nil { if err != nil {
return err return nil, err
} }
if verbose { if verbose {
latest.Summary() latest.Summary()
} }
ctx := cmd.Context() return &api.Message{Role: role, Content: fullResponse.String()}, nil
ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)
return nil
} }
type MultilineState int type MultilineState int
@ -600,13 +590,10 @@ const (
MultilineTemplate MultilineTemplate
) )
func generateInteractive(cmd *cobra.Command, opts generateOptions) error { func chatInteractive(cmd *cobra.Command, opts runOptions) error {
// load the model // load the model
loadOpts := generateOptions{ loadOpts := runOptions{Model: opts.Model}
Model: opts.Model, if _, err := chat(cmd, loadOpts); err != nil {
Prompt: "",
}
if err := generate(cmd, loadOpts); err != nil {
return err return err
} }
@ -677,7 +664,9 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
defer fmt.Printf(readline.EndBracketedPaste) defer fmt.Printf(readline.EndBracketedPaste)
var multiline MultilineState var multiline MultilineState
var prompt string var content string
var systemContent string
opts.Messages = make([]api.Message, 0)
for { for {
line, err := scanner.Readline() line, err := scanner.Readline()
@ -691,7 +680,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
} }
scanner.Prompt.UseAlt = false scanner.Prompt.UseAlt = false
prompt = "" content = ""
continue continue
case err != nil: case err != nil:
@ -699,37 +688,37 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
} }
switch { switch {
case strings.HasPrefix(prompt, `"""`): case strings.HasPrefix(content, `"""`):
// if the prompt so far starts with """ then we're in multiline mode // if the prompt so far starts with """ then we're in multiline mode
// and we need to keep reading until we find a line that ends with """ // and we need to keep reading until we find a line that ends with """
cut, found := strings.CutSuffix(line, `"""`) cut, found := strings.CutSuffix(line, `"""`)
prompt += cut + "\n" content += cut + "\n"
if !found { if !found {
continue continue
} }
prompt = strings.TrimPrefix(prompt, `"""`) content = strings.TrimPrefix(content, `"""`)
scanner.Prompt.UseAlt = false scanner.Prompt.UseAlt = false
switch multiline { switch multiline {
case MultilineSystem: case MultilineSystem:
opts.System = prompt systemContent = content
prompt = "" content = ""
fmt.Println("Set system template.\n") fmt.Println("Set system template.\n")
case MultilineTemplate: case MultilineTemplate:
opts.Template = prompt opts.Template = content
prompt = "" content = ""
fmt.Println("Set model template.\n") fmt.Println("Set model template.\n")
} }
multiline = MultilineNone multiline = MultilineNone
case strings.HasPrefix(line, `"""`) && len(prompt) == 0: case strings.HasPrefix(line, `"""`) && len(content) == 0:
scanner.Prompt.UseAlt = true scanner.Prompt.UseAlt = true
multiline = MultilinePrompt multiline = MultilinePrompt
prompt += line + "\n" content += line + "\n"
continue continue
case scanner.Pasting: case scanner.Pasting:
prompt += line + "\n" content += line + "\n"
continue continue
case strings.HasPrefix(line, "/list"): case strings.HasPrefix(line, "/list"):
args := strings.Fields(line) args := strings.Fields(line)
@ -791,17 +780,17 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
line = strings.TrimPrefix(line, `"""`) line = strings.TrimPrefix(line, `"""`)
if strings.HasPrefix(args[2], `"""`) { if strings.HasPrefix(args[2], `"""`) {
cut, found := strings.CutSuffix(line, `"""`) cut, found := strings.CutSuffix(line, `"""`)
prompt += cut + "\n" content += cut + "\n"
if found { if found {
opts.System = prompt systemContent = content
if args[1] == "system" { if args[1] == "system" {
fmt.Println("Set system template.\n") fmt.Println("Set system template.\n")
} else { } else {
fmt.Println("Set prompt template.\n") fmt.Println("Set prompt template.\n")
} }
prompt = "" content = ""
} else { } else {
prompt = `"""` + prompt content = `"""` + content
if args[1] == "system" { if args[1] == "system" {
multiline = MultilineSystem multiline = MultilineSystem
} else { } else {
@ -810,7 +799,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
scanner.Prompt.UseAlt = true scanner.Prompt.UseAlt = true
} }
} else { } else {
opts.System = line systemContent = line
fmt.Println("Set system template.\n") fmt.Println("Set system template.\n")
} }
default: default:
@ -858,8 +847,8 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
} }
case "system": case "system":
switch { switch {
case opts.System != "": case systemContent != "":
fmt.Println(opts.System + "\n") fmt.Println(systemContent + "\n")
case resp.System != "": case resp.System != "":
fmt.Println(resp.System + "\n") fmt.Println(resp.System + "\n")
default: default:
@ -899,16 +888,23 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0]) fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
continue continue
default: default:
prompt += line content += line
} }
if len(prompt) > 0 && multiline == MultilineNone { if len(content) > 0 && multiline == MultilineNone {
opts.Prompt = prompt if systemContent != "" {
if err := generate(cmd, opts); err != nil { opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: systemContent})
}
opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: content})
assistant, err := chat(cmd, opts)
if err != nil {
return err return err
} }
if assistant != nil {
opts.Messages = append(opts.Messages, *assistant)
}
prompt = "" content = ""
} }
} }
} }

View file

@ -24,7 +24,7 @@ All durations are returned in nanoseconds.
### Streaming responses ### Streaming responses
Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character. Certain endpoints stream responses as JSON objects.
## Generate a completion ## Generate a completion
@ -32,10 +32,12 @@ Certain endpoints stream responses as JSON objects delineated with the newline (
POST /api/generate POST /api/generate
``` ```
Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request. Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
### Parameters ### Parameters
`model` is required.
- `model`: (required) the [model name](#model-names) - `model`: (required) the [model name](#model-names)
- `prompt`: the prompt to generate a response for - `prompt`: the prompt to generate a response for
@ -43,11 +45,10 @@ Advanced parameters (optional):
- `format`: the format to return a response in. Currently the only accepted value is `json` - `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `system`: system prompt to (overrides what is defined in the `Modelfile`)
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`) - `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory - `system`: system prompt to (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself. - `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API.
### JSON mode ### JSON mode
@ -57,7 +58,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur
### Examples ### Examples
#### Request #### Request (Prompt)
```shell ```shell
curl http://localhost:11434/api/generate -d '{ curl http://localhost:11434/api/generate -d '{
@ -89,7 +90,7 @@ The final response in the stream also includes additional data about the generat
- `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt - `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
- `eval_count`: number of tokens the response - `eval_count`: number of tokens the response
- `eval_duration`: time in nanoseconds spent generating the response - `eval_duration`: time in nanoseconds spent generating the response
- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory - `context`: deprecated, an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
- `response`: empty if the response was streamed, if not streamed, this will contain the full response - `response`: empty if the response was streamed, if not streamed, this will contain the full response
To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`. To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`.
@ -114,6 +115,8 @@ To calculate how fast the response is generated in tokens per second (token/s),
#### Request (No streaming) #### Request (No streaming)
A response can be received in one reply when streaming is off.
```shell ```shell
curl http://localhost:11434/api/generate -d '{ curl http://localhost:11434/api/generate -d '{
"model": "llama2", "model": "llama2",
@ -144,9 +147,9 @@ If `stream` is set to `false`, the response will be a single JSON object:
} }
``` ```
#### Request (Raw mode) #### Request (Raw Mode)
In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context. In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting.
```shell ```shell
curl http://localhost:11434/api/generate -d '{ curl http://localhost:11434/api/generate -d '{
@ -164,6 +167,7 @@ curl http://localhost:11434/api/generate -d '{
"model": "mistral", "model": "mistral",
"created_at": "2023-11-03T15:36:02.583064Z", "created_at": "2023-11-03T15:36:02.583064Z",
"response": " The sky appears blue because of a phenomenon called Rayleigh scattering.", "response": " The sky appears blue because of a phenomenon called Rayleigh scattering.",
"context": [1, 2, 3],
"done": true, "done": true,
"total_duration": 14648695333, "total_duration": 14648695333,
"load_duration": 3302671417, "load_duration": 3302671417,
@ -275,7 +279,6 @@ curl http://localhost:11434/api/generate -d '{
"model": "llama2", "model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z", "created_at": "2023-08-04T19:22:45.499127Z",
"response": "The sky is blue because it is the color of the sky.", "response": "The sky is blue because it is the color of the sky.",
"context": [1, 2, 3],
"done": true, "done": true,
"total_duration": 5589157167, "total_duration": 5589157167,
"load_duration": 3013701500, "load_duration": 3013701500,
@ -288,6 +291,135 @@ curl http://localhost:11434/api/generate -d '{
} }
``` ```
## Send Chat Messages
```shell
POST /api/chat
```
Generate the next message in a chat with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
### Parameters
`model` is required.
- `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory
Advanced parameters (optional):
- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
### Examples
#### Request
Send a chat message with a streaming response.
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"messages": [
{
"role": "user",
"content": "why is the sky blue?"
}
]
}'
```
#### Response
A stream of JSON objects is returned:
```json
{
"model": "llama2",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": {
"role": "assistant",
"content": "The"
},
"done": false
}
```
Final response:
```json
{
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
"sample_count": 114,
"sample_duration": 81442000,
"prompt_eval_count": 46,
"prompt_eval_duration": 1160282000,
"eval_count": 113,
"eval_duration": 1325948000
}
```
#### Request (With History)
Send a chat message with a conversation history.
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama2",
"messages": [
{
"role": "user",
"content": "why is the sky blue?"
},
{
"role": "assistant",
"content": "due to rayleigh scattering."
},
{
"role": "user",
"content": "how is that different than mie scattering?"
}
]
}'
```
#### Response
A stream of JSON objects is returned:
```json
{
"model": "llama2",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": {
"role": "assistant",
"content": "The"
},
"done": false
}
```
Final response:
```json
{
"model": "llama2",
"created_at": "2023-08-04T19:22:45.499127Z",
"done": true,
"total_duration": 5589157167,
"load_duration": 3013701500,
"sample_count": 114,
"sample_duration": 81442000,
"prompt_eval_count": 46,
"prompt_eval_duration": 1160282000,
"eval_count": 113,
"eval_duration": 1325948000
}
```
## Create a Model ## Create a Model
```shell ```shell

View file

@ -531,21 +531,31 @@ type prediction struct {
const maxBufferSize = 512 * format.KiloByte const maxBufferSize = 512 * format.KiloByte
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error { type PredictRequest struct {
prevConvo, err := llm.Decode(ctx, prevContext) Model string
if err != nil { Prompt string
return err Format string
} CheckpointStart time.Time
CheckpointLoaded time.Time
}
// Remove leading spaces from prevConvo if present type PredictResponse struct {
prevConvo = strings.TrimPrefix(prevConvo, " ") Model string
CreatedAt time.Time
var nextContext strings.Builder TotalDuration time.Duration
nextContext.WriteString(prevConvo) LoadDuration time.Duration
nextContext.WriteString(prompt) Content string
Done bool
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
Context []int
}
func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(PredictResponse)) error {
request := map[string]any{ request := map[string]any{
"prompt": nextContext.String(), "prompt": predict.Prompt,
"stream": true, "stream": true,
"n_predict": llm.NumPredict, "n_predict": llm.NumPredict,
"n_keep": llm.NumKeep, "n_keep": llm.NumKeep,
@ -567,7 +577,7 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
"stop": llm.Stop, "stop": llm.Stop,
} }
if format == "json" { if predict.Format == "json" {
request["grammar"] = jsonGrammar request["grammar"] = jsonGrammar
} }
@ -624,25 +634,25 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
} }
if p.Content != "" { if p.Content != "" {
fn(api.GenerateResponse{Response: p.Content}) fn(PredictResponse{
nextContext.WriteString(p.Content) Model: predict.Model,
CreatedAt: time.Now().UTC(),
Content: p.Content,
})
} }
if p.Stop { if p.Stop {
embd, err := llm.Encode(ctx, nextContext.String()) fn(PredictResponse{
if err != nil { Model: predict.Model,
return fmt.Errorf("encoding context: %v", err) CreatedAt: time.Now().UTC(),
} TotalDuration: time.Since(predict.CheckpointStart),
fn(api.GenerateResponse{
Done: true, Done: true,
Context: embd,
PromptEvalCount: p.Timings.PromptN, PromptEvalCount: p.Timings.PromptN,
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS), PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
EvalCount: p.Timings.PredictedN, EvalCount: p.Timings.PredictedN,
EvalDuration: parseDurationMs(p.Timings.PredictedMS), EvalDuration: parseDurationMs(p.Timings.PredictedMS),
}) })
return nil return nil
} }
} }

View file

@ -14,7 +14,7 @@ import (
) )
type LLM interface { type LLM interface {
Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error Predict(context.Context, PredictRequest, func(PredictResponse)) error
Embedding(context.Context, string) ([]float64, error) Embedding(context.Context, string) ([]float64, error)
Encode(context.Context, string) ([]int, error) Encode(context.Context, string) ([]int, error)
Decode(context.Context, []int) (string, error) Decode(context.Context, []int) (string, error)

View file

@ -47,37 +47,82 @@ type Model struct {
Options map[string]interface{} Options map[string]interface{}
} }
func (m *Model) Prompt(request api.GenerateRequest) (string, error) { type PromptVars struct {
t := m.Template System string
if request.Template != "" { Prompt string
t = request.Template Response string
} }
tmpl, err := template.New("").Parse(t) func (m *Model) Prompt(p PromptVars) (string, error) {
var prompt strings.Builder
tmpl, err := template.New("").Parse(m.Template)
if err != nil { if err != nil {
return "", err return "", err
} }
var vars struct { if p.System == "" {
First bool // use the default system prompt for this model if one is not specified
System string p.System = m.System
Prompt string
}
vars.First = len(request.Context) == 0
vars.System = m.System
vars.Prompt = request.Prompt
if request.System != "" {
vars.System = request.System
} }
var sb strings.Builder var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil { if err := tmpl.Execute(&sb, p); err != nil {
return "", err return "", err
} }
prompt.WriteString(sb.String())
prompt.WriteString(p.Response)
return prompt.String(), nil
}
return sb.String(), nil func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
// build the prompt from the list of messages
var prompt strings.Builder
currentVars := PromptVars{}
writePrompt := func() error {
p, err := m.Prompt(currentVars)
if err != nil {
return err
}
prompt.WriteString(p)
currentVars = PromptVars{}
return nil
}
for _, msg := range msgs {
switch msg.Role {
case "system":
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
currentVars.System = msg.Content
case "user":
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
currentVars.Prompt = msg.Content
case "assistant":
currentVars.Response = msg.Content
if err := writePrompt(); err != nil {
return "", err
}
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", err
}
}
return prompt.String(), nil
} }
type ManifestV2 struct { type ManifestV2 struct {

View file

@ -2,17 +2,15 @@ package server
import ( import (
"testing" "testing"
"github.com/jmorganca/ollama/api"
) )
func TestModelPrompt(t *testing.T) { func TestModelPrompt(t *testing.T) {
var m Model m := Model{
req := api.GenerateRequest{
Template: "a{{ .Prompt }}b", Template: "a{{ .Prompt }}b",
Prompt: "<h1>",
} }
s, err := m.Prompt(req) s, err := m.Prompt(PromptVars{
Prompt: "<h1>",
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -60,17 +60,26 @@ var loaded struct {
var defaultSessionDuration = 5 * time.Minute var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error { func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
model, err := GetModel(modelName)
if err != nil {
return nil, err
}
workDir := c.GetString("workDir")
opts := api.DefaultOptions() opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil { if err := opts.FromMap(model.Options); err != nil {
log.Printf("could not load model options: %v", err) log.Printf("could not load model options: %v", err)
return err return nil, err
} }
if err := opts.FromMap(reqOpts); err != nil { if err := opts.FromMap(reqOpts); err != nil {
return err return nil, err
} }
ctx := c.Request.Context()
// check if the loaded model is still running in a subprocess, in case something unexpected happened // check if the loaded model is still running in a subprocess, in case something unexpected happened
if loaded.runner != nil { if loaded.runner != nil {
if err := loaded.runner.Ping(ctx); err != nil { if err := loaded.runner.Ping(ctx); err != nil {
@ -106,7 +115,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName) err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
} }
return err return nil, err
} }
loaded.Model = model loaded.Model = model
@ -140,7 +149,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
} }
loaded.expireTimer.Reset(sessionDuration) loaded.expireTimer.Reset(sessionDuration)
return nil return model, nil
} }
func GenerateHandler(c *gin.Context) { func GenerateHandler(c *gin.Context) {
@ -173,88 +182,262 @@ func GenerateHandler(c *gin.Context) {
return return
} }
model, err := GetModel(req.Model) sessionDuration := defaultSessionDuration
model, err := load(c, req.Model, req.Options, sessionDuration)
if err != nil { if err != nil {
var pErr *fs.PathError var pErr *fs.PathError
if errors.As(err, &pErr) { switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)}) c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
return case errors.Is(err, api.ErrInvalidOpts):
}
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return default:
}
workDir := c.GetString("workDir")
// TODO: set this duration from the request if specified
sessionDuration := defaultSessionDuration
if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
return return
} }
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
prompt := req.Prompt var prompt string
if !req.Raw { sendContext := false
prompt, err = model.Prompt(req) switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template != "" {
// override the default model template
model.Template = req.Template
}
var rebuild strings.Builder
if req.Context != nil {
// TODO: context is deprecated, at some point the context logic within this conditional should be removed
prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
// Remove leading spaces from prevCtx if present
prevCtx = strings.TrimPrefix(prevCtx, " ")
rebuild.WriteString(prevCtx)
}
p, err := model.Prompt(PromptVars{
System: req.System,
Prompt: req.Prompt,
})
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rebuild.WriteString(p)
prompt = rebuild.String()
sendContext = true
} }
ch := make(chan any) ch := make(chan any)
var generated strings.Builder
go func() { go func() {
defer close(ch) defer close(ch)
// an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" {
ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
return
}
fn := func(r api.GenerateResponse) { fn := func(r llm.PredictResponse) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration) loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration) loaded.expireTimer.Reset(sessionDuration)
r.Model = req.Model // Build up the full response
r.CreatedAt = time.Now().UTC() if _, err := generated.WriteString(r.Content); err != nil {
if r.Done { ch <- gin.H{"error": err.Error()}
r.TotalDuration = time.Since(checkpointStart) return
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} }
if req.Raw { resp := api.GenerateResponse{
// in raw mode the client must manage history on their own Model: r.Model,
r.Context = nil CreatedAt: r.CreatedAt,
Done: r.Done,
Response: r.Content,
EvalMetrics: api.EvalMetrics{
TotalDuration: r.TotalDuration,
LoadDuration: r.LoadDuration,
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
} }
ch <- r if r.Done && sendContext {
embd, err := loaded.runner.Encode(c.Request.Context(), req.Prompt+generated.String())
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
r.Context = embd
} }
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil { ch <- resp
}
// Start prediction
predictReq := llm.PredictRequest{
Model: model.Name,
Prompt: prompt,
Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if req.Stream != nil && !*req.Stream { if req.Stream != nil && !*req.Stream {
var response api.GenerateResponse // Wait for the channel to close
generated := "" var r api.GenerateResponse
var sb strings.Builder
for resp := range ch { for resp := range ch {
if r, ok := resp.(api.GenerateResponse); ok { var ok bool
generated += r.Response if r, ok = resp.(api.GenerateResponse); !ok {
response = r
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sb.WriteString(r.Response)
} }
response.Response = generated r.Response = sb.String()
c.JSON(http.StatusOK, response) c.JSON(http.StatusOK, r)
return
}
streamResponse(c, ch)
}
func ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now()
var req api.ChatRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// validate the request
switch {
case req.Model == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case len(req.Format) > 0 && req.Format != "json":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
return
}
sessionDuration := defaultSessionDuration
model, err := load(c, req.Model, req.Options, sessionDuration)
if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
// an empty request loads the model
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
return
}
checkpointLoaded := time.Now()
if req.Template != "" {
// override the default model template
model.Template = req.Template
}
prompt, err := model.ChatPrompt(req.Messages)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ch := make(chan any)
go func() {
defer close(ch)
fn := func(r llm.PredictResponse) {
// Update model expiration
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
resp := api.ChatResponse{
Model: r.Model,
CreatedAt: r.CreatedAt,
Done: r.Done,
EvalMetrics: api.EvalMetrics{
TotalDuration: r.TotalDuration,
LoadDuration: r.LoadDuration,
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if !r.Done {
resp.Message = &api.Message{Role: "assistant", Content: r.Content}
}
ch <- resp
}
// Start prediction
predictReq := llm.PredictRequest{
Model: model.Name,
Prompt: prompt,
Format: req.Format,
CheckpointStart: checkpointStart,
CheckpointLoaded: checkpointLoaded,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if req.Stream != nil && !*req.Stream {
// Wait for the channel to close
var r api.ChatResponse
var sb strings.Builder
for resp := range ch {
var ok bool
if r, ok = resp.(api.ChatResponse); !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if r.Message != nil {
sb.WriteString(r.Message.Content)
}
}
r.Message = &api.Message{Role: "assistant", Content: sb.String()}
c.JSON(http.StatusOK, r)
return return
} }
@ -281,15 +464,18 @@ func EmbeddingHandler(c *gin.Context) {
return return
} }
model, err := GetModel(req.Model) sessionDuration := defaultSessionDuration
_, err = load(c, req.Model, req.Options, sessionDuration)
if err != nil { if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} }
workDir := c.GetString("workDir")
if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
@ -767,6 +953,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
r.POST("/api/pull", PullModelHandler) r.POST("/api/pull", PullModelHandler)
r.POST("/api/generate", GenerateHandler) r.POST("/api/generate", GenerateHandler)
r.POST("/api/chat", ChatHandler)
r.POST("/api/embeddings", EmbeddingHandler) r.POST("/api/embeddings", EmbeddingHandler)
r.POST("/api/create", CreateModelHandler) r.POST("/api/create", CreateModelHandler)
r.POST("/api/push", PushModelHandler) r.POST("/api/push", PushModelHandler)