chat api (#991)

- update chat docs
- add messages chat endpoint
- remove deprecated context and template generate parameters from docs - context and template are still supported for the time being and will continue to work as expected
- add partial response to chat history

This commit is contained in:
parent 0cca1486dd
commit 7a0899d62d

9 changed files with 667 additions and 256 deletions
@@ -221,6 +221,19 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn Generate
 	})
 }
 
+type ChatResponseFunc func(ChatResponse) error
+
+func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
+	return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
+		var resp ChatResponse
+		if err := json.Unmarshal(bts, &resp); err != nil {
+			return err
+		}
+
+		return fn(resp)
+	})
+}
+
 type PullProgressFunc func(ProgressResponse) error
 
 func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
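For orientation, here is a minimal sketch of how the new streaming Chat client method above might be called. The model name and prompt are placeholders and error handling is reduced to the essentials; the request and callback types come from the diff above.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.ChatRequest{
		Model: "llama2", // placeholder model name
		Messages: []api.Message{
			{Role: "user", Content: "why is the sky blue?"},
		},
	}

	// The callback runs once per streamed ChatResponse; print each partial message.
	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		if resp.Message != nil {
			fmt.Print(resp.Message.Content)
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println()
}
```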
api/types.go (74)

@@ -36,7 +36,7 @@ type GenerateRequest struct {
 	Prompt   string `json:"prompt"`
 	System   string `json:"system"`
 	Template string `json:"template"`
-	Context  []int  `json:"context,omitempty"`
+	Context  []int  `json:"context,omitempty"` // DEPRECATED: context is deprecated, use the /chat endpoint instead for chat history
 	Stream   *bool  `json:"stream,omitempty"`
 	Raw      bool   `json:"raw,omitempty"`
 	Format   string `json:"format"`

@@ -44,6 +44,41 @@ type GenerateRequest struct {
 	Options map[string]interface{} `json:"options"`
 }
 
+type ChatRequest struct {
+	Model    string    `json:"model"`
+	Messages []Message `json:"messages"`
+	Template string    `json:"template"`
+	Stream   *bool     `json:"stream,omitempty"`
+	Format   string    `json:"format"`
+
+	Options map[string]interface{} `json:"options"`
+}
+
+type Message struct {
+	Role    string `json:"role"` // one of ["system", "user", "assistant"]
+	Content string `json:"content"`
+}
+
+type ChatResponse struct {
+	Model     string    `json:"model"`
+	CreatedAt time.Time `json:"created_at"`
+	Message   *Message  `json:"message,omitempty"`
+
+	Done    bool  `json:"done"`
+	Context []int `json:"context,omitempty"`
+
+	EvalMetrics
+}
+
+type EvalMetrics struct {
+	TotalDuration      time.Duration `json:"total_duration,omitempty"`
+	LoadDuration       time.Duration `json:"load_duration,omitempty"`
+	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
+	EvalCount          int           `json:"eval_count,omitempty"`
+	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
+}
+
 // Options specfied in GenerateRequest, if you add a new option here add it to the API docs also
 type Options struct {
 	Runner

@@ -173,39 +208,34 @@ type GenerateResponse struct {
 	Done    bool  `json:"done"`
 	Context []int `json:"context,omitempty"`
 
-	TotalDuration      time.Duration `json:"total_duration,omitempty"`
-	LoadDuration       time.Duration `json:"load_duration,omitempty"`
-	PromptEvalCount    int           `json:"prompt_eval_count,omitempty"`
-	PromptEvalDuration time.Duration `json:"prompt_eval_duration,omitempty"`
-	EvalCount          int           `json:"eval_count,omitempty"`
-	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
+	EvalMetrics
 }
 
-func (r *GenerateResponse) Summary() {
-	if r.TotalDuration > 0 {
-		fmt.Fprintf(os.Stderr, "total duration: %v\n", r.TotalDuration)
+func (m *EvalMetrics) Summary() {
+	if m.TotalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
 	}
 
-	if r.LoadDuration > 0 {
-		fmt.Fprintf(os.Stderr, "load duration: %v\n", r.LoadDuration)
+	if m.LoadDuration > 0 {
+		fmt.Fprintf(os.Stderr, "load duration: %v\n", m.LoadDuration)
 	}
 
-	if r.PromptEvalCount > 0 {
-		fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", r.PromptEvalCount)
+	if m.PromptEvalCount > 0 {
+		fmt.Fprintf(os.Stderr, "prompt eval count: %d token(s)\n", m.PromptEvalCount)
 	}
 
-	if r.PromptEvalDuration > 0 {
-		fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", r.PromptEvalDuration)
-		fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(r.PromptEvalCount)/r.PromptEvalDuration.Seconds())
+	if m.PromptEvalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "prompt eval duration: %s\n", m.PromptEvalDuration)
+		fmt.Fprintf(os.Stderr, "prompt eval rate: %.2f tokens/s\n", float64(m.PromptEvalCount)/m.PromptEvalDuration.Seconds())
 	}
 
-	if r.EvalCount > 0 {
-		fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", r.EvalCount)
+	if m.EvalCount > 0 {
+		fmt.Fprintf(os.Stderr, "eval count: %d token(s)\n", m.EvalCount)
 	}
 
-	if r.EvalDuration > 0 {
-		fmt.Fprintf(os.Stderr, "eval duration: %s\n", r.EvalDuration)
-		fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(r.EvalCount)/r.EvalDuration.Seconds())
+	if m.EvalDuration > 0 {
+		fmt.Fprintf(os.Stderr, "eval duration: %s\n", m.EvalDuration)
+		fmt.Fprintf(os.Stderr, "eval rate: %.2f tokens/s\n", float64(m.EvalCount)/m.EvalDuration.Seconds())
 	}
 }
cmd/cmd.go (238)

@@ -159,7 +159,54 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	return RunGenerate(cmd, args)
+	interactive := true
+
+	opts := runOptions{
+		Model:    name,
+		WordWrap: os.Getenv("TERM") == "xterm-256color",
+		Options:  map[string]interface{}{},
+	}
+
+	format, err := cmd.Flags().GetString("format")
+	if err != nil {
+		return err
+	}
+	opts.Format = format
+
+	prompts := args[1:]
+
+	// prepend stdin to the prompt if provided
+	if !term.IsTerminal(int(os.Stdin.Fd())) {
+		in, err := io.ReadAll(os.Stdin)
+		if err != nil {
+			return err
+		}
+
+		prompts = append([]string{string(in)}, prompts...)
+		opts.WordWrap = false
+		interactive = false
+	}
+	msg := api.Message{
+		Role:    "user",
+		Content: strings.Join(prompts, " "),
+	}
+	opts.Messages = append(opts.Messages, msg)
+	if len(prompts) > 0 {
+		interactive = false
+	}
+
+	nowrap, err := cmd.Flags().GetBool("nowordwrap")
+	if err != nil {
+		return err
+	}
+	opts.WordWrap = !nowrap
+
+	if !interactive {
+		_, err := chat(cmd, opts)
+		return err
+	}
+
+	return chatInteractive(cmd, opts)
 }
 
 func PushHandler(cmd *cobra.Command, args []string) error {

@@ -411,83 +458,26 @@ func PullHandler(cmd *cobra.Command, args []string) error {
 	return nil
 }
 
-func RunGenerate(cmd *cobra.Command, args []string) error {
-	interactive := true
-
-	opts := generateOptions{
-		Model:    args[0],
-		WordWrap: os.Getenv("TERM") == "xterm-256color",
-		Options:  map[string]interface{}{},
-	}
-
-	format, err := cmd.Flags().GetString("format")
-	if err != nil {
-		return err
-	}
-	opts.Format = format
-
-	prompts := args[1:]
-
-	// prepend stdin to the prompt if provided
-	if !term.IsTerminal(int(os.Stdin.Fd())) {
-		in, err := io.ReadAll(os.Stdin)
-		if err != nil {
-			return err
-		}
-
-		prompts = append([]string{string(in)}, prompts...)
-		opts.WordWrap = false
-		interactive = false
-	}
-	opts.Prompt = strings.Join(prompts, " ")
-	if len(prompts) > 0 {
-		interactive = false
-	}
-
-	nowrap, err := cmd.Flags().GetBool("nowordwrap")
-	if err != nil {
-		return err
-	}
-	opts.WordWrap = !nowrap
-
-	if !interactive {
-		return generate(cmd, opts)
-	}
-
-	return generateInteractive(cmd, opts)
-}
-
-type generateContextKey string
-
-type generateOptions struct {
+type runOptions struct {
 	Model    string
-	Prompt   string
+	Messages []api.Message
 	WordWrap bool
 	Format   string
-	System   string
 	Template string
 	Options  map[string]interface{}
 }
 
-func generate(cmd *cobra.Command, opts generateOptions) error {
+func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	p := progress.NewProgress(os.Stderr)
 	defer p.StopAndClear()
 
 	spinner := progress.NewSpinner("")
 	p.Add("", spinner)
 
-	var latest api.GenerateResponse
-
-	generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
-	if !ok {
-		generateContext = []int{}
-	}
-
 	termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
 	if err != nil {
 		opts.WordWrap = false

@@ -506,24 +496,24 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 
 	var currentLineLength int
 	var wordBuffer string
+	var latest api.ChatResponse
+	var fullResponse strings.Builder
+	var role string
 
-	request := api.GenerateRequest{
-		Model:    opts.Model,
-		Prompt:   opts.Prompt,
-		Context:  generateContext,
-		Format:   opts.Format,
-		System:   opts.System,
-		Template: opts.Template,
-		Options:  opts.Options,
-	}
-
-	fn := func(response api.GenerateResponse) error {
+	fn := func(response api.ChatResponse) error {
 		p.StopAndClear()
 
 		latest = response
+		if response.Message == nil {
+			// warm-up response or done
+			return nil
+		}
+		role = response.Message.Role
+		content := response.Message.Content
+		fullResponse.WriteString(content)
 
 		termWidth, _, _ = term.GetSize(int(os.Stdout.Fd()))
 		if opts.WordWrap && termWidth >= 10 {
-			for _, ch := range response.Response {
+			for _, ch := range content {
 				if currentLineLength+1 > termWidth-5 {
 					if len(wordBuffer) > termWidth-10 {
 						fmt.Printf("%s%c", wordBuffer, ch)

@@ -551,7 +541,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 					}
 				}
 			} else {
-				fmt.Printf("%s%s", wordBuffer, response.Response)
+				fmt.Printf("%s%s", wordBuffer, content)
 				if len(wordBuffer) > 0 {
 					wordBuffer = ""
 				}

@@ -560,35 +550,35 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 		return nil
 	}
 
-	if err := client.Generate(cancelCtx, &request, fn); err != nil {
-		if errors.Is(err, context.Canceled) {
-			return nil
-		}
-		return err
+	req := &api.ChatRequest{
+		Model:    opts.Model,
+		Messages: opts.Messages,
+		Format:   opts.Format,
+		Template: opts.Template,
+		Options:  opts.Options,
 	}
-	if opts.Prompt != "" {
-		fmt.Println()
-		fmt.Println()
+	if err := client.Chat(cancelCtx, req, fn); err != nil {
+		if errors.Is(err, context.Canceled) {
+			return nil, nil
+		}
+		return nil, err
 	}
 
-	if !latest.Done {
-		return nil
+	if len(opts.Messages) > 0 {
+		fmt.Println()
+		fmt.Println()
 	}
 
 	verbose, err := cmd.Flags().GetBool("verbose")
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	if verbose {
 		latest.Summary()
 	}
 
-	ctx := cmd.Context()
-	ctx = context.WithValue(ctx, generateContextKey("context"), latest.Context)
-	cmd.SetContext(ctx)
-
-	return nil
+	return &api.Message{Role: role, Content: fullResponse.String()}, nil
 }

@@ -600,13 +590,10 @@ const (
 	MultilineTemplate
 )
 
-func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
+func chatInteractive(cmd *cobra.Command, opts runOptions) error {
 	// load the model
-	loadOpts := generateOptions{
-		Model:  opts.Model,
-		Prompt: "",
-	}
-	if err := generate(cmd, loadOpts); err != nil {
+	loadOpts := runOptions{Model: opts.Model}
+	if _, err := chat(cmd, loadOpts); err != nil {
 		return err
 	}

@@ -677,7 +664,9 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 	defer fmt.Printf(readline.EndBracketedPaste)
 
 	var multiline MultilineState
-	var prompt string
+	var content string
+	var systemContent string
+	opts.Messages = make([]api.Message, 0)
 
 	for {
 		line, err := scanner.Readline()

@@ -691,7 +680,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 			}
 
 			scanner.Prompt.UseAlt = false
-			prompt = ""
+			content = ""
 
 			continue
 		case err != nil:

@@ -699,37 +688,37 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 		}
 
 		switch {
-		case strings.HasPrefix(prompt, `"""`):
+		case strings.HasPrefix(content, `"""`):
 			// if the prompt so far starts with """ then we're in multiline mode
 			// and we need to keep reading until we find a line that ends with """
 			cut, found := strings.CutSuffix(line, `"""`)
-			prompt += cut + "\n"
+			content += cut + "\n"
 
 			if !found {
 				continue
 			}
 
-			prompt = strings.TrimPrefix(prompt, `"""`)
+			content = strings.TrimPrefix(content, `"""`)
 			scanner.Prompt.UseAlt = false
 
 			switch multiline {
 			case MultilineSystem:
-				opts.System = prompt
-				prompt = ""
+				systemContent = content
+				content = ""
 				fmt.Println("Set system template.\n")
 			case MultilineTemplate:
-				opts.Template = prompt
-				prompt = ""
+				opts.Template = content
+				content = ""
 				fmt.Println("Set model template.\n")
 			}
 			multiline = MultilineNone
-		case strings.HasPrefix(line, `"""`) && len(prompt) == 0:
+		case strings.HasPrefix(line, `"""`) && len(content) == 0:
 			scanner.Prompt.UseAlt = true
 			multiline = MultilinePrompt
-			prompt += line + "\n"
+			content += line + "\n"
 			continue
 		case scanner.Pasting:
-			prompt += line + "\n"
+			content += line + "\n"
 			continue
 		case strings.HasPrefix(line, "/list"):
 			args := strings.Fields(line)

@@ -791,17 +780,17 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 				line = strings.TrimPrefix(line, `"""`)
 				if strings.HasPrefix(args[2], `"""`) {
 					cut, found := strings.CutSuffix(line, `"""`)
-					prompt += cut + "\n"
+					content += cut + "\n"
 					if found {
-						opts.System = prompt
+						systemContent = content
 						if args[1] == "system" {
 							fmt.Println("Set system template.\n")
 						} else {
 							fmt.Println("Set prompt template.\n")
 						}
-						prompt = ""
+						content = ""
 					} else {
-						prompt = `"""` + prompt
+						content = `"""` + content
 						if args[1] == "system" {
 							multiline = MultilineSystem
 						} else {

@@ -810,7 +799,7 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 						scanner.Prompt.UseAlt = true
 					}
 				} else {
-					opts.System = line
+					systemContent = line
 					fmt.Println("Set system template.\n")
 				}
 			default:

@@ -858,8 +847,8 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 			}
 			case "system":
 				switch {
-				case opts.System != "":
-					fmt.Println(opts.System + "\n")
+				case systemContent != "":
+					fmt.Println(systemContent + "\n")
 				case resp.System != "":
 					fmt.Println(resp.System + "\n")
 				default:

@@ -899,16 +888,23 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 			fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0])
 			continue
 		default:
-			prompt += line
+			content += line
 		}
 
-		if len(prompt) > 0 && multiline == MultilineNone {
-			opts.Prompt = prompt
-			if err := generate(cmd, opts); err != nil {
+		if len(content) > 0 && multiline == MultilineNone {
+			if systemContent != "" {
+				opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: systemContent})
+			}
+			opts.Messages = append(opts.Messages, api.Message{Role: "user", Content: content})
+			assistant, err := chat(cmd, opts)
+			if err != nil {
 				return err
 			}
+			if assistant != nil {
+				opts.Messages = append(opts.Messages, *assistant)
+			}
 
-			prompt = ""
+			content = ""
 		}
 	}
 }
docs/api.md (152)

@@ -24,7 +24,7 @@ All durations are returned in nanoseconds.
 ### Streaming responses
 
-Certain endpoints stream responses as JSON objects delineated with the newline (`\n`) character.
+Certain endpoints stream responses as JSON objects.
 
 ## Generate a completion

@@ -32,10 +32,12 @@ Certain endpoints stream responses as JSON objects delineated with the newline (
 POST /api/generate
 ```
 
-Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses. The final response object will include statistics and additional data from the request.
+Generate a response for a given prompt with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.
 
 ### Parameters
 
+`model` is required.
+
 - `model`: (required) the [model name](#model-names)
 - `prompt`: the prompt to generate a response for

@@ -43,11 +45,10 @@ Advanced parameters (optional):
 - `format`: the format to return a response in. Currently the only accepted value is `json`
 - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
-- `system`: system prompt to (overrides what is defined in the `Modelfile`)
 - `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
-- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
+- `system`: system prompt to (overrides what is defined in the `Modelfile`)
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
-- `raw`: if `true` no formatting will be applied to the prompt and no context will be returned. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API, and are managing history yourself.
+- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API.
 
 ### JSON mode

@@ -57,7 +58,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur
 ### Examples
 
-#### Request
+#### Request (Prompt)
 
 ```shell
 curl http://localhost:11434/api/generate -d '{

@@ -89,7 +90,7 @@ The final response in the stream also includes additional data about the generat
 - `prompt_eval_duration`: time spent in nanoseconds evaluating the prompt
 - `eval_count`: number of tokens the response
 - `eval_duration`: time in nanoseconds spent generating the response
-- `context`: an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
+- `context`: deprecated, an encoding of the conversation used in this response, this can be sent in the next request to keep a conversational memory
 - `response`: empty if the response was streamed, if not streamed, this will contain the full response
 
 To calculate how fast the response is generated in tokens per second (token/s), divide `eval_count` / `eval_duration`.

@@ -114,6 +115,8 @@ To calculate how fast the response is generated in tokens per second (token/s),
 #### Request (No streaming)
 
+A response can be recieved in one reply when streaming is off.
+
 ```shell
 curl http://localhost:11434/api/generate -d '{
   "model": "llama2",

@@ -144,9 +147,9 @@ If `stream` is set to `false`, the response will be a single JSON object:
 }
 ```
 
-#### Request (Raw mode)
+#### Request (Raw Mode)
 
-In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting and context.
+In some cases you may wish to bypass the templating system and provide a full prompt. In this case, you can use the `raw` parameter to disable formatting.
 
 ```shell
 curl http://localhost:11434/api/generate -d '{

@@ -164,6 +167,7 @@ curl http://localhost:11434/api/generate -d '{
   "model": "mistral",
   "created_at": "2023-11-03T15:36:02.583064Z",
   "response": " The sky appears blue because of a phenomenon called Rayleigh scattering.",
+  "context": [1, 2, 3],
   "done": true,
   "total_duration": 14648695333,
   "load_duration": 3302671417,

@@ -275,7 +279,6 @@ curl http://localhost:11434/api/generate -d '{
   "model": "llama2",
   "created_at": "2023-08-04T19:22:45.499127Z",
   "response": "The sky is blue because it is the color of the sky.",
-  "context": [1, 2, 3],
   "done": true,
   "total_duration": 5589157167,
   "load_duration": 3013701500,

@@ -288,6 +291,135 @@ curl http://localhost:11434/api/generate -d '{
 }
 ```

The rest of this hunk is new documentation added after the generate examples:

## Send Chat Messages

```shell
POST /api/chat
```

Generate the next message in a chat with a provided model. This is a streaming endpoint, so there will be a series of responses. The final response object will include statistics and additional data from the request.

### Parameters

`model` is required.

- `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory

Advanced parameters (optional):

- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `template`: the full prompt or prompt template (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects

### Examples

#### Request

Send a chat message with a streaming response.

```shell
curl http://localhost:11434/api/generate -d '{
  "model": "llama2",
  "messages": [
    {
      "role": "user",
      "content": "why is the sky blue?"
    }
  ]
}'
```

#### Response

A stream of JSON objects is returned:

```json
{
  "model": "llama2",
  "created_at": "2023-08-04T08:52:19.385406455-07:00",
  "message": {
    "role": "assisant",
    "content": "The"
  },
  "done": false
}
```

Final response:

```json
{
  "model": "llama2",
  "created_at": "2023-08-04T19:22:45.499127Z",
  "done": true,
  "total_duration": 5589157167,
  "load_duration": 3013701500,
  "sample_count": 114,
  "sample_duration": 81442000,
  "prompt_eval_count": 46,
  "prompt_eval_duration": 1160282000,
  "eval_count": 113,
  "eval_duration": 1325948000
}
```

#### Request (With History)

Send a chat message with a conversation history.

```shell
curl http://localhost:11434/api/generate -d '{
  "model": "llama2",
  "messages": [
    {
      "role": "user",
      "content": "why is the sky blue?"
    },
    {
      "role": "assistant",
      "content": "due to rayleigh scattering."
    },
    {
      "role": "user",
      "content": "how is that different than mie scattering?"
    }
  ]
}'
```

#### Response

A stream of JSON objects is returned:

```json
{
  "model": "llama2",
  "created_at": "2023-08-04T08:52:19.385406455-07:00",
  "message": {
    "role": "assisant",
    "content": "The"
  },
  "done": false
}
```

Final response:

```json
{
  "model": "llama2",
  "created_at": "2023-08-04T19:22:45.499127Z",
  "done": true,
  "total_duration": 5589157167,
  "load_duration": 3013701500,
  "sample_count": 114,
  "sample_duration": 81442000,
  "prompt_eval_count": 46,
  "prompt_eval_duration": 1160282000,
  "eval_count": 113,
  "eval_duration": 1325948000
}
```
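To complement the curl examples above, here is a rough sketch of consuming the streamed chat responses directly over HTTP from Go. It assumes a local server on the default port, and the struct mirrors only the fields used here; it is an illustration, not part of the commit.

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

// chatChunk captures just the fields this example reads from each streamed object.
type chatChunk struct {
	Message *struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	} `json:"message"`
	Done bool `json:"done"`
}

func main() {
	body := []byte(`{"model": "llama2", "messages": [{"role": "user", "content": "why is the sky blue?"}]}`)
	resp, err := http.Post("http://localhost:11434/api/chat", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// Each line of the response body is one JSON object from the stream.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		var chunk chatChunk
		if err := json.Unmarshal(scanner.Bytes(), &chunk); err != nil {
			log.Fatal(err)
		}
		if chunk.Message != nil {
			fmt.Print(chunk.Message.Content)
		}
		if chunk.Done {
			fmt.Println()
		}
	}
	if err := scanner.Err(); err != nil {
		log.Fatal(err)
	}
}
```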
 ## Create a Model
 
 ```shell

llm/llama.go (54)

@@ -531,21 +531,31 @@ type prediction struct {
 
 const maxBufferSize = 512 * format.KiloByte
 
-func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
-	prevConvo, err := llm.Decode(ctx, prevContext)
-	if err != nil {
-		return err
-	}
+type PredictRequest struct {
+	Model            string
+	Prompt           string
+	Format           string
+	CheckpointStart  time.Time
+	CheckpointLoaded time.Time
+}
 
-	// Remove leading spaces from prevConvo if present
-	prevConvo = strings.TrimPrefix(prevConvo, " ")
-
-	var nextContext strings.Builder
-	nextContext.WriteString(prevConvo)
-	nextContext.WriteString(prompt)
+type PredictResponse struct {
+	Model              string
+	CreatedAt          time.Time
+	TotalDuration      time.Duration
+	LoadDuration       time.Duration
+	Content            string
+	Done               bool
+	PromptEvalCount    int
+	PromptEvalDuration time.Duration
+	EvalCount          int
+	EvalDuration       time.Duration
+	Context            []int
+}
 
+func (llm *llama) Predict(ctx context.Context, predict PredictRequest, fn func(PredictResponse)) error {
 	request := map[string]any{
-		"prompt":    nextContext.String(),
+		"prompt":    predict.Prompt,
 		"stream":    true,
 		"n_predict": llm.NumPredict,
 		"n_keep":    llm.NumKeep,

@@ -567,7 +577,7 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 		"stop": llm.Stop,
 	}
 
-	if format == "json" {
+	if predict.Format == "json" {
 		request["grammar"] = jsonGrammar
 	}

@@ -624,25 +634,25 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
 			}
 
 			if p.Content != "" {
-				fn(api.GenerateResponse{Response: p.Content})
-				nextContext.WriteString(p.Content)
+				fn(PredictResponse{
+					Model:     predict.Model,
+					CreatedAt: time.Now().UTC(),
+					Content:   p.Content,
+				})
 			}
 
 			if p.Stop {
-				embd, err := llm.Encode(ctx, nextContext.String())
-				if err != nil {
-					return fmt.Errorf("encoding context: %v", err)
-				}
-
-				fn(api.GenerateResponse{
+				fn(PredictResponse{
+					Model:         predict.Model,
+					CreatedAt:     time.Now().UTC(),
+					TotalDuration: time.Since(predict.CheckpointStart),
+
 					Done:               true,
-					Context:            embd,
 					PromptEvalCount:    p.Timings.PromptN,
 					PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
 					EvalCount:          p.Timings.PredictedN,
 					EvalDuration:       parseDurationMs(p.Timings.PredictedMS),
 				})
 
 				return nil
 			}
 		}

@@ -14,7 +14,7 @@ import (
 )
 
 type LLM interface {
-	Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
+	Predict(context.Context, PredictRequest, func(PredictResponse)) error
 	Embedding(context.Context, string) ([]float64, error)
 	Encode(context.Context, string) ([]int, error)
 	Decode(context.Context, []int) (string, error)

@@ -47,37 +47,82 @@ type Model struct {
 	Options map[string]interface{}
 }
 
-func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
-	t := m.Template
-	if request.Template != "" {
-		t = request.Template
-	}
+type PromptVars struct {
+	System   string
+	Prompt   string
+	Response string
+}
 
-	tmpl, err := template.New("").Parse(t)
+func (m *Model) Prompt(p PromptVars) (string, error) {
+	var prompt strings.Builder
+	tmpl, err := template.New("").Parse(m.Template)
 	if err != nil {
 		return "", err
 	}
 
-	var vars struct {
-		First  bool
-		System string
-		Prompt string
-	}
-
-	vars.First = len(request.Context) == 0
-	vars.System = m.System
-	vars.Prompt = request.Prompt
-
-	if request.System != "" {
-		vars.System = request.System
+	if p.System == "" {
+		// use the default system prompt for this model if one is not specified
+		p.System = m.System
 	}
 
 	var sb strings.Builder
-	if err := tmpl.Execute(&sb, vars); err != nil {
+	if err := tmpl.Execute(&sb, p); err != nil {
 		return "", err
 	}
+	prompt.WriteString(sb.String())
+	prompt.WriteString(p.Response)
+	return prompt.String(), nil
+}
 
-	return sb.String(), nil
+func (m *Model) ChatPrompt(msgs []api.Message) (string, error) {
+	// build the prompt from the list of messages
+	var prompt strings.Builder
+	currentVars := PromptVars{}
+
+	writePrompt := func() error {
+		p, err := m.Prompt(currentVars)
+		if err != nil {
+			return err
+		}
+		prompt.WriteString(p)
+		currentVars = PromptVars{}
+		return nil
+	}
+
+	for _, msg := range msgs {
+		switch msg.Role {
+		case "system":
+			if currentVars.Prompt != "" || currentVars.System != "" {
+				if err := writePrompt(); err != nil {
+					return "", err
+				}
+			}
+			currentVars.System = msg.Content
+		case "user":
+			if currentVars.Prompt != "" || currentVars.System != "" {
+				if err := writePrompt(); err != nil {
+					return "", err
+				}
+			}
+			currentVars.Prompt = msg.Content
+		case "assistant":
+			currentVars.Response = msg.Content
+			if err := writePrompt(); err != nil {
+				return "", err
+			}
+		default:
+			return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
+		}
+	}
+
+	// Append the last set of vars if they are non-empty
+	if currentVars.Prompt != "" || currentVars.System != "" {
+		if err := writePrompt(); err != nil {
+			return "", err
+		}
+	}
+
+	return prompt.String(), nil
 }
 
 type ManifestV2 struct {
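As a rough illustration of the ChatPrompt flow introduced above, the sketch below (not part of the commit) runs a short conversation through an illustrative template inside the same package; the template string and messages are hypothetical.

```go
package server

import (
	"testing"

	"github.com/jmorganca/ollama/api"
)

// Sketch only: each user turn is rendered through the template, the assistant
// reply is appended verbatim as the Response, and the vars reset for the next turn.
func TestChatPromptSketch(t *testing.T) {
	m := Model{
		// illustrative template; real Modelfile templates vary
		Template: "{{ if .System }}<<SYS>>{{ .System }}<</SYS>>{{ end }}[INST] {{ .Prompt }} [/INST]",
	}

	msgs := []api.Message{
		{Role: "user", Content: "why is the sky blue?"},
		{Role: "assistant", Content: "Because of Rayleigh scattering."},
		{Role: "user", Content: "how is that different than mie scattering?"},
	}

	got, err := m.ChatPrompt(msgs)
	if err != nil {
		t.Fatal(err)
	}
	// Expected shape: the first user turn rendered through the template with the
	// assistant reply appended, followed by the trailing user turn rendered alone.
	t.Log(got)
}
```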
@@ -2,17 +2,15 @@ package server
 
 import (
 	"testing"
-
-	"github.com/jmorganca/ollama/api"
 )
 
 func TestModelPrompt(t *testing.T) {
-	var m Model
-	req := api.GenerateRequest{
+	m := Model{
 		Template: "a{{ .Prompt }}b",
-		Prompt:   "<h1>",
 	}
-	s, err := m.Prompt(req)
+	s, err := m.Prompt(PromptVars{
+		Prompt: "<h1>",
+	})
 	if err != nil {
 		t.Fatal(err)
 	}

server/routes.go (295)

@@ -60,17 +60,26 @@ var loaded struct {
 var defaultSessionDuration = 5 * time.Minute
 
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
-func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
+func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
+	model, err := GetModel(modelName)
+	if err != nil {
+		return nil, err
+	}
+
+	workDir := c.GetString("workDir")
+
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
 		log.Printf("could not load model options: %v", err)
-		return err
+		return nil, err
 	}
 
 	if err := opts.FromMap(reqOpts); err != nil {
-		return err
+		return nil, err
 	}
 
+	ctx := c.Request.Context()
+
 	// check if the loaded model is still running in a subprocess, in case something unexpected happened
 	if loaded.runner != nil {
 		if err := loaded.runner.Ping(ctx); err != nil {

@@ -106,7 +115,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
 			err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
 		}
 
-		return err
+		return nil, err
 	}
 
 	loaded.Model = model

@@ -140,7 +149,7 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
 	}
 
 	loaded.expireTimer.Reset(sessionDuration)
-	return nil
+	return model, nil
 }
 
 func GenerateHandler(c *gin.Context) {

@@ -173,88 +182,262 @@ func GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	model, err := GetModel(req.Model)
+	sessionDuration := defaultSessionDuration
+	model, err := load(c, req.Model, req.Options, sessionDuration)
 	if err != nil {
 		var pErr *fs.PathError
-		if errors.As(err, &pErr) {
+		switch {
+		case errors.As(err, &pErr):
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
-			return
+		case errors.Is(err, api.ErrInvalidOpts):
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		default:
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		}
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	workDir := c.GetString("workDir")
-
-	// TODO: set this duration from the request if specified
-	sessionDuration := defaultSessionDuration
-	if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil {
-		if errors.Is(err, api.ErrInvalidOpts) {
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	// an empty request loads the model
+	if req.Prompt == "" && req.Template == "" && req.System == "" {
+		c.JSON(http.StatusOK, api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
 		return
 	}
 
 	checkpointLoaded := time.Now()
 
-	prompt := req.Prompt
-	if !req.Raw {
-		prompt, err = model.Prompt(req)
+	var prompt string
+	sendContext := false
+	switch {
+	case req.Raw:
+		prompt = req.Prompt
+	case req.Prompt != "":
+		if req.Template != "" {
+			// override the default model template
+			model.Template = req.Template
+		}
+
+		var rebuild strings.Builder
+		if req.Context != nil {
+			// TODO: context is deprecated, at some point the context logic within this conditional should be removed
+			prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
+			if err != nil {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+
+			// Remove leading spaces from prevCtx if present
+			prevCtx = strings.TrimPrefix(prevCtx, " ")
+			rebuild.WriteString(prevCtx)
+		}
+		p, err := model.Prompt(PromptVars{
+			System: req.System,
+			Prompt: req.Prompt,
+		})
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 		}
+		rebuild.WriteString(p)
+		prompt = rebuild.String()
+		sendContext = true
 	}
 
 	ch := make(chan any)
+	var generated strings.Builder
 	go func() {
 		defer close(ch)
-		// an empty request loads the model
-		if req.Prompt == "" && req.Template == "" && req.System == "" {
-			ch <- api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}
-			return
-		}
-
-		fn := func(r api.GenerateResponse) {
+		fn := func(r llm.PredictResponse) {
+			// Update model expiration
 			loaded.expireAt = time.Now().Add(sessionDuration)
 			loaded.expireTimer.Reset(sessionDuration)
 
-			r.Model = req.Model
-			r.CreatedAt = time.Now().UTC()
-			if r.Done {
-				r.TotalDuration = time.Since(checkpointStart)
-				r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+			// Build up the full response
+			if _, err := generated.WriteString(r.Content); err != nil {
+				ch <- gin.H{"error": err.Error()}
+				return
 			}
 
-			if req.Raw {
-				// in raw mode the client must manage history on their own
-				r.Context = nil
+			resp := api.GenerateResponse{
+				Model:     r.Model,
+				CreatedAt: r.CreatedAt,
+				Done:      r.Done,
+				Response:  r.Content,
+				EvalMetrics: api.EvalMetrics{
+					TotalDuration:      r.TotalDuration,
+					LoadDuration:       r.LoadDuration,
+					PromptEvalCount:    r.PromptEvalCount,
+					PromptEvalDuration: r.PromptEvalDuration,
+					EvalCount:          r.EvalCount,
+					EvalDuration:       r.EvalDuration,
+				},
 			}
 
-			ch <- r
+			if r.Done && sendContext {
+				embd, err := loaded.runner.Encode(c.Request.Context(), req.Prompt+generated.String())
+				if err != nil {
+					ch <- gin.H{"error": err.Error()}
+					return
+				}
+				r.Context = embd
+			}
+
+			ch <- resp
 		}
 
-		if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, req.Format, fn); err != nil {
+		// Start prediction
+		predictReq := llm.PredictRequest{
+			Model:            model.Name,
+			Prompt:           prompt,
+			Format:           req.Format,
+			CheckpointStart:  checkpointStart,
+			CheckpointLoaded: checkpointLoaded,
+		}
+		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
 
 	if req.Stream != nil && !*req.Stream {
-		var response api.GenerateResponse
-		generated := ""
+		// Wait for the channel to close
+		var r api.GenerateResponse
+		var sb strings.Builder
 		for resp := range ch {
-			if r, ok := resp.(api.GenerateResponse); ok {
-				generated += r.Response
-				response = r
-			} else {
+			var ok bool
+			if r, ok = resp.(api.GenerateResponse); !ok {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
 			}
+			sb.WriteString(r.Response)
 		}
-		response.Response = generated
-		c.JSON(http.StatusOK, response)
+		r.Response = sb.String()
+		c.JSON(http.StatusOK, r)
+		return
+	}
+
+	streamResponse(c, ch)
+}
+
+func ChatHandler(c *gin.Context) {
+	loaded.mu.Lock()
+	defer loaded.mu.Unlock()
+
+	checkpointStart := time.Now()
+
+	var req api.ChatRequest
+	err := c.ShouldBindJSON(&req)
+	switch {
+	case errors.Is(err, io.EOF):
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+		return
+	case err != nil:
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	// validate the request
+	switch {
+	case req.Model == "":
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
+		return
+	case len(req.Format) > 0 && req.Format != "json":
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "format must be json"})
+		return
+	}
+
+	sessionDuration := defaultSessionDuration
+	model, err := load(c, req.Model, req.Options, sessionDuration)
+	if err != nil {
+		var pErr *fs.PathError
+		switch {
+		case errors.As(err, &pErr):
+			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
+		case errors.Is(err, api.ErrInvalidOpts):
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		default:
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		}
+		return
+	}
+
+	// an empty request loads the model
+	if len(req.Messages) == 0 {
+		c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
+		return
+	}
+
+	checkpointLoaded := time.Now()
+
+	if req.Template != "" {
+		// override the default model template
+		model.Template = req.Template
+	}
+	prompt, err := model.ChatPrompt(req.Messages)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
+	ch := make(chan any)
+
+	go func() {
+		defer close(ch)
+
+		fn := func(r llm.PredictResponse) {
+			// Update model expiration
+			loaded.expireAt = time.Now().Add(sessionDuration)
+			loaded.expireTimer.Reset(sessionDuration)
+
+			resp := api.ChatResponse{
+				Model:     r.Model,
+				CreatedAt: r.CreatedAt,
+				Done:      r.Done,
+				EvalMetrics: api.EvalMetrics{
+					TotalDuration:      r.TotalDuration,
+					LoadDuration:       r.LoadDuration,
+					PromptEvalCount:    r.PromptEvalCount,
+					PromptEvalDuration: r.PromptEvalDuration,
+					EvalCount:          r.EvalCount,
+					EvalDuration:       r.EvalDuration,
+				},
+			}
+
+			if !r.Done {
+				resp.Message = &api.Message{Role: "assistant", Content: r.Content}
+			}
+
+			ch <- resp
+		}
+
+		// Start prediction
+		predictReq := llm.PredictRequest{
+			Model:            model.Name,
+			Prompt:           prompt,
+			Format:           req.Format,
+			CheckpointStart:  checkpointStart,
+			CheckpointLoaded: checkpointLoaded,
+		}
+		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
+			ch <- gin.H{"error": err.Error()}
+		}
+	}()
+
+	if req.Stream != nil && !*req.Stream {
+		// Wait for the channel to close
+		var r api.ChatResponse
+		var sb strings.Builder
+		for resp := range ch {
+			var ok bool
+			if r, ok = resp.(api.ChatResponse); !ok {
+				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+				return
+			}
+			if r.Message != nil {
+				sb.WriteString(r.Message.Content)
+			}
+		}
+		r.Message = &api.Message{Role: "assistant", Content: sb.String()}
+		c.JSON(http.StatusOK, r)
 		return
 	}

@@ -281,15 +464,18 @@ func EmbeddingHandler(c *gin.Context) {
 		return
 	}
 
-	model, err := GetModel(req.Model)
+	sessionDuration := defaultSessionDuration
+	_, err = load(c, req.Model, req.Options, sessionDuration)
 	if err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
-	}
-
-	workDir := c.GetString("workDir")
-	if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		var pErr *fs.PathError
+		switch {
+		case errors.As(err, &pErr):
+			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
+		case errors.Is(err, api.ErrInvalidOpts):
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		default:
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		}
 		return
 	}

@@ -767,6 +953,7 @@ func Serve(ln net.Listener, allowOrigins []string) error {
 
 	r.POST("/api/pull", PullModelHandler)
 	r.POST("/api/generate", GenerateHandler)
+	r.POST("/api/chat", ChatHandler)
 	r.POST("/api/embeddings", EmbeddingHandler)
 	r.POST("/api/create", CreateModelHandler)
 	r.POST("/api/push", PushModelHandler)