tools
This commit is contained in:
parent
9e35d9bbee
commit
d02bbebb11
7 changed files with 263 additions and 53 deletions
39
api/types.go
39
api/types.go
|
@ -97,6 +97,9 @@ type ChatRequest struct {
|
||||||
// followin the request.
|
// followin the request.
|
||||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
|
// Tools is an optional list of tools the model has access to.
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
@ -105,9 +108,36 @@ type ChatRequest struct {
|
||||||
// role ("system", "user", or "assistant"), the content and an optional list
|
// role ("system", "user", or "assistant"), the content and an optional list
|
||||||
// of images.
|
// of images.
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content,omitempty"`
|
||||||
Images []ImageData `json:"images,omitempty"`
|
Images []ImageData `json:"images,omitempty"`
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments map[string]any `json:"arguments"`
|
||||||
|
} `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Parameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Required []string `json:"required"`
|
||||||
|
Properties map[string]struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
} `json:"properties"`
|
||||||
|
} `json:"parameters"`
|
||||||
|
} `json:"function"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||||
|
@ -374,6 +404,9 @@ type GenerateResponse struct {
|
||||||
// Response is the textual response itself.
|
// Response is the textual response itself.
|
||||||
Response string `json:"response"`
|
Response string `json:"response"`
|
||||||
|
|
||||||
|
// ToolCalls is the list of tools the model wants to call
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
|
||||||
// Done specifies if the response is complete.
|
// Done specifies if the response is complete.
|
||||||
Done bool `json:"done"`
|
Done bool `json:"done"`
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,10 @@ var errCapabilityCompletion = errors.New("completion")
|
||||||
|
|
||||||
type Capability string
|
type Capability string
|
||||||
|
|
||||||
const CapabilityCompletion = Capability("completion")
|
const (
|
||||||
|
CapabilityCompletion = Capability("completion")
|
||||||
|
CapabilityTools = Capability("tools")
|
||||||
|
)
|
||||||
|
|
||||||
type registryOptions struct {
|
type registryOptions struct {
|
||||||
Insecure bool
|
Insecure bool
|
||||||
|
@ -88,6 +91,10 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
|
||||||
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
|
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
|
||||||
errs = append(errs, errCapabilityCompletion)
|
errs = append(errs, errCapabilityCompletion)
|
||||||
}
|
}
|
||||||
|
case CapabilityTools:
|
||||||
|
if !slices.Contains(m.Template.Vars(), "tools") {
|
||||||
|
errs = append(errs, errors.New("tools"))
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
slog.Error("unknown capability", "capability", cap)
|
slog.Error("unknown capability", "capability", cap)
|
||||||
return fmt.Errorf("unknown capability: %s", cap)
|
return fmt.Errorf("unknown capability: %s", cap)
|
||||||
|
@ -95,7 +102,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := errors.Join(errs...); err != nil {
|
if err := errors.Join(errs...); err != nil {
|
||||||
return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
|
return fmt.Errorf("does not support %w", errors.Join(errs...))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
105
server/model.go
105
server/model.go
|
@ -4,6 +4,7 @@ import (
|
||||||
"archive/zip"
|
"archive/zip"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -11,7 +12,11 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"text/template/parse"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/convert"
|
"github.com/ollama/ollama/convert"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
|
@ -289,3 +294,103 @@ func detectContentType(r io.Reader) (string, error) {
|
||||||
|
|
||||||
return "unknown", nil
|
return "unknown", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||||
|
// mxyng: this only really works if the input contains tool calls in some JSON format
|
||||||
|
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||||
|
// create a subtree from the node that ranges over .ToolCalls
|
||||||
|
tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
||||||
|
if t, ok := n.(*parse.RangeNode); ok {
|
||||||
|
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
|
||||||
|
if tmpl == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&b, map[string][]map[string]any{
|
||||||
|
"ToolCalls": {
|
||||||
|
{
|
||||||
|
"Function": map[string]any{
|
||||||
|
"Name": "@@name@@",
|
||||||
|
"Arguments": "@@arguments@@",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var kv map[string]string
|
||||||
|
// execute the subtree with placeholders to identify the keys
|
||||||
|
if err := json.Unmarshal(b.Bytes(), &kv); err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the keys that correspond to the name and arguments fields
|
||||||
|
var name, arguments string
|
||||||
|
for k, v := range kv {
|
||||||
|
switch v {
|
||||||
|
case "@@name@@":
|
||||||
|
name = k
|
||||||
|
case "@@arguments@@":
|
||||||
|
arguments = k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var sm []map[string]any
|
||||||
|
decoder := json.NewDecoder(strings.NewReader(s))
|
||||||
|
for {
|
||||||
|
// incrementally decode the JSON into a list of JSON objects
|
||||||
|
// skipping over any invalid tokens
|
||||||
|
if err := decoder.Decode(&sm); err != nil {
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.As(err, new(*json.SyntaxError)) {
|
||||||
|
r := decoder.Buffered()
|
||||||
|
if _, err := r.Read(make([]byte, decoder.InputOffset()+1)); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder = json.NewDecoder(r)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// break as soon as a valid object is decoded
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolCalls []api.ToolCall
|
||||||
|
for _, kv := range sm {
|
||||||
|
call := api.ToolCall{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
Type: "function",
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range kv {
|
||||||
|
switch k {
|
||||||
|
case name:
|
||||||
|
call.Function.Name = v.(string)
|
||||||
|
case arguments:
|
||||||
|
call.Function.Arguments = v.(map[string]any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls = append(toolCalls, call)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
return toolCalls, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
|
||||||
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
||||||
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||||
// latest message and 2) system messages
|
// latest message and 2) system messages
|
||||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
|
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
|
||||||
var system []api.Message
|
var system []api.Message
|
||||||
// always include the last message
|
// always include the last message
|
||||||
n := len(msgs) - 1
|
n := len(msgs) - 1
|
||||||
|
@ -29,7 +29,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||||
}
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
|
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,7 +57,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||||
|
|
||||||
// truncate any messages that do not fit into the context window
|
// truncate any messages that do not fit into the context window
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
|
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -192,7 +192,7 @@ func TestChatPrompt(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
||||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||||
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
|
prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -265,6 +265,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Response = sb.String()
|
r.Response = sb.String()
|
||||||
|
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||||
|
r.ToolCalls = toolCalls
|
||||||
|
r.Response = ""
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, r)
|
c.JSON(http.StatusOK, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1279,6 +1284,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
caps := []Capability{CapabilityCompletion}
|
caps := []Capability{CapabilityCompletion}
|
||||||
|
if req.Tools != nil {
|
||||||
|
caps = append(caps, CapabilityTools)
|
||||||
|
}
|
||||||
|
|
||||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||||
if errors.Is(err, errCapabilityCompletion) {
|
if errors.Is(err, errCapabilityCompletion) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
||||||
|
@ -1305,7 +1314,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
|
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
|
||||||
}
|
}
|
||||||
|
|
||||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
|
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
@ -1348,13 +1357,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if req.Stream != nil && !*req.Stream {
|
if req.Stream != nil && !*req.Stream {
|
||||||
var r api.ChatResponse
|
var resp api.ChatResponse
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for rr := range ch {
|
for rr := range ch {
|
||||||
switch t := rr.(type) {
|
switch t := rr.(type) {
|
||||||
case api.ChatResponse:
|
case api.ChatResponse:
|
||||||
sb.WriteString(t.Message.Content)
|
sb.WriteString(t.Message.Content)
|
||||||
r = t
|
resp = t
|
||||||
case gin.H:
|
case gin.H:
|
||||||
msg, ok := t["error"].(string)
|
msg, ok := t["error"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -1369,8 +1378,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Message.Content = sb.String()
|
resp.Message.Content = sb.String()
|
||||||
c.JSON(http.StatusOK, r)
|
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||||
|
resp.Message.ToolCalls = toolCalls
|
||||||
|
resp.Message.Content = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, resp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"text/template"
|
"text/template"
|
||||||
"text/template/parse"
|
"text/template/parse"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/agnivade/levenshtein"
|
"github.com/agnivade/levenshtein"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
@ -102,8 +103,18 @@ var response = parse.ActionNode{
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var funcs = template.FuncMap{
|
||||||
|
"json": func(v any) string {
|
||||||
|
b, _ := json.Marshal(v)
|
||||||
|
return string(b)
|
||||||
|
},
|
||||||
|
"now": func() string {
|
||||||
|
return time.Now().Format("2006-01-02 15:04:05")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
func Parse(s string) (*Template, error) {
|
func Parse(s string) (*Template, error) {
|
||||||
tmpl := template.New("").Option("missingkey=zero")
|
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
|
||||||
|
|
||||||
tmpl, err := tmpl.Parse(s)
|
tmpl, err := tmpl.Parse(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -127,7 +138,7 @@ func (t *Template) Vars() []string {
|
||||||
var vars []string
|
var vars []string
|
||||||
for _, tt := range t.Templates() {
|
for _, tt := range t.Templates() {
|
||||||
for _, n := range tt.Root.Nodes {
|
for _, n := range tt.Root.Nodes {
|
||||||
vars = append(vars, parseNode(n)...)
|
vars = append(vars, Identifiers(n)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -143,17 +154,65 @@ func (t *Template) Vars() []string {
|
||||||
|
|
||||||
type Values struct {
|
type Values struct {
|
||||||
Messages []api.Message
|
Messages []api.Message
|
||||||
|
Tools []api.Tool
|
||||||
|
|
||||||
// forceLegacy is a flag used to test compatibility with legacy templates
|
// forceLegacy is a flag used to test compatibility with legacy templates
|
||||||
forceLegacy bool
|
forceLegacy bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
|
||||||
|
var walk func(parse.Node) parse.Node
|
||||||
|
walk = func(n parse.Node) parse.Node {
|
||||||
|
if fn(n) {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
switch t := n.(type) {
|
||||||
|
case *parse.ListNode:
|
||||||
|
for _, c := range t.Nodes {
|
||||||
|
if n := walk(c); n != nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case *parse.BranchNode:
|
||||||
|
for _, n := range []*parse.ListNode{t.List, t.ElseList} {
|
||||||
|
if n != nil {
|
||||||
|
if n := walk(n); n != nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case *parse.IfNode:
|
||||||
|
return walk(&t.BranchNode)
|
||||||
|
case *parse.WithNode:
|
||||||
|
return walk(&t.BranchNode)
|
||||||
|
case *parse.RangeNode:
|
||||||
|
return walk(&t.BranchNode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if n := walk(t.Tree.Root); n != nil {
|
||||||
|
return (&template.Template{
|
||||||
|
Tree: &parse.Tree{
|
||||||
|
Root: &parse.ListNode{
|
||||||
|
Nodes: []parse.Node{n},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}).Funcs(funcs)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Template) Execute(w io.Writer, v Values) error {
|
func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
system, messages := collate(v.Messages)
|
system, messages := collate(v.Messages)
|
||||||
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||||
return t.Template.Execute(w, map[string]any{
|
return t.Template.Execute(w, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Messages": messages,
|
"Messages": messages,
|
||||||
|
"Tools": v.Tools,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,7 +220,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
var prompt, response string
|
var prompt, response string
|
||||||
for _, m := range messages {
|
for _, m := range messages {
|
||||||
execute := func () error {
|
execute := func() error {
|
||||||
if err := t.Template.Execute(&b, map[string]any{
|
if err := t.Template.Execute(&b, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Prompt": prompt,
|
"Prompt": prompt,
|
||||||
|
@ -198,12 +257,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
|
|
||||||
var cut bool
|
var cut bool
|
||||||
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
|
||||||
switch t := n.(type) {
|
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
|
||||||
case *parse.ActionNode:
|
cut = true
|
||||||
case *parse.FieldNode:
|
|
||||||
if slices.Contains(t.Ident, "Response") {
|
|
||||||
cut = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return cut
|
return cut
|
||||||
|
@ -255,50 +310,46 @@ func collate(msgs []api.Message) (string, []*api.Message) {
|
||||||
return strings.Join(system, "\n\n"), collated
|
return strings.Join(system, "\n\n"), collated
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseNode(n parse.Node) []string {
|
// Identifiers walks the node tree returning any identifiers it finds along the way
|
||||||
|
func Identifiers(n parse.Node) []string {
|
||||||
switch n := n.(type) {
|
switch n := n.(type) {
|
||||||
|
case *parse.ListNode:
|
||||||
|
var names []string
|
||||||
|
for _, n := range n.Nodes {
|
||||||
|
names = append(names, Identifiers(n)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return names
|
||||||
|
case *parse.TemplateNode:
|
||||||
|
return Identifiers(n.Pipe)
|
||||||
case *parse.ActionNode:
|
case *parse.ActionNode:
|
||||||
return parseNode(n.Pipe)
|
return Identifiers(n.Pipe)
|
||||||
|
case *parse.BranchNode:
|
||||||
|
names := Identifiers(n.Pipe)
|
||||||
|
for _, n := range []*parse.ListNode{n.List, n.ElseList} {
|
||||||
|
if n != nil {
|
||||||
|
names = append(names, Identifiers(n)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return names
|
||||||
case *parse.IfNode:
|
case *parse.IfNode:
|
||||||
names := parseNode(n.Pipe)
|
return Identifiers(&n.BranchNode)
|
||||||
names = append(names, parseNode(n.List)...)
|
|
||||||
if n.ElseList != nil {
|
|
||||||
names = append(names, parseNode(n.ElseList)...)
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
case *parse.RangeNode:
|
case *parse.RangeNode:
|
||||||
names := parseNode(n.Pipe)
|
return Identifiers(&n.BranchNode)
|
||||||
names = append(names, parseNode(n.List)...)
|
|
||||||
if n.ElseList != nil {
|
|
||||||
names = append(names, parseNode(n.ElseList)...)
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
case *parse.WithNode:
|
case *parse.WithNode:
|
||||||
names := parseNode(n.Pipe)
|
return Identifiers(&n.BranchNode)
|
||||||
names = append(names, parseNode(n.List)...)
|
|
||||||
if n.ElseList != nil {
|
|
||||||
names = append(names, parseNode(n.ElseList)...)
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
case *parse.PipeNode:
|
case *parse.PipeNode:
|
||||||
var names []string
|
var names []string
|
||||||
for _, c := range n.Cmds {
|
for _, c := range n.Cmds {
|
||||||
for _, a := range c.Args {
|
for _, a := range c.Args {
|
||||||
names = append(names, parseNode(a)...)
|
names = append(names, Identifiers(a)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return names
|
|
||||||
case *parse.ListNode:
|
|
||||||
var names []string
|
|
||||||
for _, n := range n.Nodes {
|
|
||||||
names = append(names, parseNode(n)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return names
|
return names
|
||||||
case *parse.FieldNode:
|
case *parse.FieldNode:
|
||||||
return n.Ident
|
return n.Ident
|
||||||
case *parse.TemplateNode:
|
case *parse.VariableNode:
|
||||||
return parseNode(n.Pipe)
|
return n.Ident
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
Loading…
Reference in a new issue