219 lines
5.8 KiB
Go
219 lines
5.8 KiB
Go
package server
|
|
|
|
import (
|
|
"fmt"
|
|
"log/slog"
|
|
"strings"
|
|
"text/template"
|
|
"text/template/parse"
|
|
|
|
"github.com/jmorganca/ollama/api"
|
|
)
|
|
|
|
// isResponseNode reports whether any command in the action's pipeline has a
// field argument whose first identifier is Response, e.g. {{ .Response }},
// {{ .Response.X }} or {{ printf "%s" .Response }}. Variable forms such as
// {{ $.Response }} are not field nodes and are therefore not matched.
func isResponseNode(node *parse.ActionNode) bool {
	refersToResponse := func(arg parse.Node) bool {
		field, ok := arg.(*parse.FieldNode)
		return ok && len(field.Ident) > 0 && field.Ident[0] == "Response"
	}

	for _, cmd := range node.Pipe.Cmds {
		for _, arg := range cmd.Args {
			if refersToResponse(arg) {
				return true
			}
		}
	}
	return false
}
|
|
|
|
// formatTemplateForResponse formats the template AST to:
|
|
// 1. remove all nodes after the first .Response (if generate=true)
|
|
// 2. add a .Response node to the end if it doesn't exist
|
|
// TODO(jmorganca): this should recursively cut the template before the first .Response
|
|
func formatTemplateForResponse(tmpl *template.Template, generate bool) {
|
|
var found bool
|
|
for i, node := range tmpl.Tree.Root.Nodes {
|
|
if actionNode, ok := node.(*parse.ActionNode); ok {
|
|
if isResponseNode(actionNode) {
|
|
found = true
|
|
if generate {
|
|
tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if !found {
|
|
// add the response node if it doesn't exist
|
|
responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
|
|
responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
|
|
responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
|
|
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
|
|
}
|
|
}
|
|
|
|
// Prompt renders a prompt from a template. If generate is set to true,
|
|
// the response and parts of the template following it are not rendered
|
|
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
|
|
parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
formatTemplateForResponse(parsed, generate)
|
|
|
|
vars := map[string]any{
|
|
"System": system,
|
|
"Prompt": prompt,
|
|
"Response": response,
|
|
}
|
|
|
|
var sb strings.Builder
|
|
if err := parsed.Execute(&sb, vars); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return sb.String(), nil
|
|
}
|
|
|
|
func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
|
|
rendered, err := Prompt(tmpl, system, prompt, response, false)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
tokens, err := encode(rendered)
|
|
if err != nil {
|
|
slog.Error("failed to encode prompt", "err", err)
|
|
return 0, err
|
|
}
|
|
|
|
return len(tokens), err
|
|
}
|
|
|
|
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
|
|
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
|
type prompt struct {
|
|
System string
|
|
Prompt string
|
|
Response string
|
|
|
|
images []int
|
|
tokens int
|
|
}
|
|
|
|
var p prompt
|
|
|
|
// iterate through messages to build up {system,user,response} prompts
|
|
var imgId int
|
|
var prompts []prompt
|
|
for _, msg := range messages {
|
|
switch strings.ToLower(msg.Role) {
|
|
case "system":
|
|
if p.System != "" || p.Prompt != "" || p.Response != "" {
|
|
prompts = append(prompts, p)
|
|
p = prompt{}
|
|
}
|
|
|
|
p.System = msg.Content
|
|
case "user":
|
|
if p.Prompt != "" || p.Response != "" {
|
|
prompts = append(prompts, p)
|
|
p = prompt{}
|
|
}
|
|
|
|
p.Prompt = msg.Content
|
|
|
|
for range msg.Images {
|
|
p.Prompt += fmt.Sprintf(" [img-%d]", imgId)
|
|
p.images = append(p.images, imgId)
|
|
imgId += 1
|
|
}
|
|
case "assistant":
|
|
if p.Response != "" {
|
|
prompts = append(prompts, p)
|
|
p = prompt{}
|
|
}
|
|
|
|
p.Response = msg.Content
|
|
default:
|
|
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
|
|
}
|
|
}
|
|
|
|
// add final prompt
|
|
if p.System != "" || p.Prompt != "" || p.Response != "" {
|
|
prompts = append(prompts, p)
|
|
}
|
|
|
|
// calculate token lengths for each prompt, estimating 768 tokens per images
|
|
for i, p := range prompts {
|
|
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
prompts[i].tokens = tokens + len(prompts[i].images)*768
|
|
}
|
|
|
|
// truncate images and prompts starting from the beginning of the list
|
|
// until either one prompt remains or the total tokens fits the context window
|
|
// TODO (jmorganca): this doesn't account for the context window room required for the response
|
|
for {
|
|
var required int
|
|
for _, p := range prompts {
|
|
required += p.tokens
|
|
}
|
|
|
|
required += 1 // for bos token
|
|
|
|
if required <= window {
|
|
slog.Debug("prompt now fits in context window", "required", required, "window", window)
|
|
break
|
|
}
|
|
|
|
prompt := &prompts[0]
|
|
|
|
if len(prompt.images) > 1 {
|
|
img := prompt.images[0]
|
|
slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
|
|
prompt.images = prompt.images[1:]
|
|
prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
|
|
prompt.tokens -= 768
|
|
continue
|
|
}
|
|
|
|
if len(prompts) > 1 {
|
|
slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
|
|
system := prompt.System
|
|
prompts = prompts[1:]
|
|
|
|
if system != "" && prompts[0].System == "" {
|
|
prompts[0].System = system
|
|
|
|
tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
prompts[0].tokens = tokens + len(prompts[0].images)*768
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
// stop truncating if there's only one prompt left
|
|
break
|
|
}
|
|
|
|
var sb strings.Builder
|
|
for i, p := range prompts {
|
|
// last prompt should leave the response unrendered (for completion)
|
|
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
sb.WriteString(rendered)
|
|
}
|
|
|
|
return sb.String(), nil
|
|
}
|