package server

import (
	"fmt"
	"log/slog"
	"strings"
	"text/template"
	"text/template/parse"

	"github.com/jmorganca/ollama/api"
)

// isResponseNode reports whether any command in the action's pipeline has a
// field argument whose first identifier is "Response", e.g. {{ .Response }},
// {{ .Response.X }} or {{ printf "%s" .Response }}.
func isResponseNode(node *parse.ActionNode) bool {
	for _, cmd := range node.Pipe.Cmds {
		for _, arg := range cmd.Args {
			field, isField := arg.(*parse.FieldNode)
			if !isField || len(field.Ident) == 0 {
				continue
			}
			if field.Ident[0] == "Response" {
				return true
			}
		}
	}
	return false
}

// formatTemplateForResponse rewrites the top-level node list of tmpl's parse tree:
//  1. if generate is true, every node after the first top-level .Response
//     action is removed (the .Response action itself is kept);
//  2. if no top-level .Response action exists, a {{ .Response }} action is
//     appended to the end.
//
// Only the root node list is inspected: a .Response nested inside
// {{ if }}, {{ range }}, {{ with }} etc. is not detected, so in that case an
// extra {{ .Response }} is appended.
//
// TODO(jmorganca): this should recursively cut the template before the first .Response
func formatTemplateForResponse(tmpl *template.Template, generate bool) {
	nodes := tmpl.Tree.Root.Nodes

	found := false
	for idx, n := range nodes {
		action, isAction := n.(*parse.ActionNode)
		if !isAction || !isResponseNode(action) {
			continue
		}
		found = true
		if generate {
			// keep everything up to and including the .Response action
			tmpl.Tree.Root.Nodes = nodes[:idx+1]
		}
		// only the first .Response matters; nothing else changes after it
		break
	}

	if found {
		return
	}

	// no top-level .Response: append {{ .Response }} so it is always rendered
	field := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
	cmd := &parse.CommandNode{NodeType: parse.NodeCommand, Args: []parse.Node{field}}
	pipe := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{cmd}}
	action := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: pipe}
	tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, action)
}

// Prompt renders a prompt from a template.
If generate is set to true, // the response and parts of the template following it are not rendered func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) { parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl) if err != nil { return "", err } formatTemplateForResponse(parsed, generate) vars := map[string]any{ "System": system, "Prompt": prompt, "Response": response, } var sb strings.Builder if err := parsed.Execute(&sb, vars); err != nil { return "", err } return sb.String(), nil } func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) { rendered, err := Prompt(tmpl, system, prompt, response, false) if err != nil { return 0, err } tokens, err := encode(rendered) if err != nil { slog.Error("failed to encode prompt", "err", err) return 0, err } return len(tokens), err } // ChatPrompt builds up a prompt from a series of messages, truncating based on context window size func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) { type prompt struct { System string Prompt string Response string images []int tokens int } var p prompt // Set the first system prompt to the model's system prompt if system != "" { p.System = system } // iterate through messages to build up {system,user,response} prompts var imgId int var prompts []prompt for _, msg := range messages { switch strings.ToLower(msg.Role) { case "system": if p.System != "" || p.Prompt != "" || p.Response != "" { prompts = append(prompts, p) p = prompt{} } p.System = msg.Content case "user": if p.Prompt != "" || p.Response != "" { prompts = append(prompts, p) p = prompt{} } p.Prompt = msg.Content for range msg.Images { p.Prompt += fmt.Sprintf(" [img-%d]", imgId) p.images = append(p.images, imgId) imgId += 1 } case "assistant": if p.Response != "" { prompts = append(prompts, p) p = prompt{} } p.Response = msg.Content default: 
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role) } } // add final prompt if p.System != "" || p.Prompt != "" || p.Response != "" { prompts = append(prompts, p) } // calculate token lengths for each prompt, estimating 768 tokens per images for i, p := range prompts { tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode) if err != nil { return "", err } prompts[i].tokens = tokens + len(prompts[i].images)*768 } // truncate images and prompts starting from the beginning of the list // until either one prompt remains or the total tokens fits the context window // TODO (jmorganca): this doesn't account for the context window room required for the response for { var required int for _, p := range prompts { required += p.tokens } required += 1 // for bos token if required <= window { slog.Debug("prompt now fits in context window", "required", required, "window", window) break } prompt := &prompts[0] if len(prompt.images) > 1 { img := prompt.images[0] slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window) prompt.images = prompt.images[1:] prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1) prompt.tokens -= 768 continue } if len(prompts) > 1 { slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window) system := prompt.System prompts = prompts[1:] if system != "" && prompts[0].System == "" { prompts[0].System = system tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode) if err != nil { return "", err } prompts[0].tokens = tokens + len(prompts[0].images)*768 } continue } // stop truncating if there's only one prompt left break } var sb strings.Builder for i, p := range prompts { // last prompt should leave the response unrendered (for completion) rendered, err := Prompt(tmpl, 
p.System, p.Prompt, p.Response, i == len(prompts)-1) if err != nil { return "", err } sb.WriteString(rendered) } return sb.String(), nil }