Merge pull request #5051 from ollama/mxyng/capabilities
add model capabilities
This commit is contained in:
commit
dddb58a38b
31 changed files with 354 additions and 192 deletions
|
@ -28,11 +28,16 @@ import (
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Capability string
|
||||||
|
|
||||||
|
const CapabilityCompletion = Capability("completion")
|
||||||
|
|
||||||
type registryOptions struct {
|
type registryOptions struct {
|
||||||
Insecure bool
|
Insecure bool
|
||||||
Username string
|
Username string
|
||||||
|
@ -48,16 +53,43 @@ type Model struct {
|
||||||
ParentModel string
|
ParentModel string
|
||||||
AdapterPaths []string
|
AdapterPaths []string
|
||||||
ProjectorPaths []string
|
ProjectorPaths []string
|
||||||
Template string
|
|
||||||
System string
|
System string
|
||||||
License []string
|
License []string
|
||||||
Digest string
|
Digest string
|
||||||
Options map[string]interface{}
|
Options map[string]interface{}
|
||||||
Messages []Message
|
Messages []Message
|
||||||
|
|
||||||
|
Template *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) IsEmbedding() bool {
|
func (m *Model) Has(caps ...Capability) bool {
|
||||||
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
|
for _, cap := range caps {
|
||||||
|
switch cap {
|
||||||
|
case CapabilityCompletion:
|
||||||
|
f, err := os.Open(m.ModelPath)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("couldn't open model file", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// TODO(mxyng): decode the GGML into model to avoid doing this multiple times
|
||||||
|
ggml, _, err := llm.DecodeGGML(f, 0)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("couldn't decode ggml", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
slog.Error("unknown capability", "capability", cap)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) String() string {
|
func (m *Model) String() string {
|
||||||
|
@ -82,10 +114,10 @@ func (m *Model) String() string {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.Template != "" {
|
if m.Template != nil {
|
||||||
modelfile.Commands = append(modelfile.Commands, parser.Command{
|
modelfile.Commands = append(modelfile.Commands, parser.Command{
|
||||||
Name: "template",
|
Name: "template",
|
||||||
Args: m.Template,
|
Args: m.Template.String(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,13 +167,6 @@ type Message struct {
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ManifestV2 struct {
|
|
||||||
SchemaVersion int `json:"schemaVersion"`
|
|
||||||
MediaType string `json:"mediaType"`
|
|
||||||
Config *Layer `json:"config"`
|
|
||||||
Layers []*Layer `json:"layers"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ConfigV2 struct {
|
type ConfigV2 struct {
|
||||||
ModelFormat string `json:"model_format"`
|
ModelFormat string `json:"model_format"`
|
||||||
ModelFamily string `json:"model_family"`
|
ModelFamily string `json:"model_family"`
|
||||||
|
@ -160,7 +185,7 @@ type RootFS struct {
|
||||||
DiffIDs []string `json:"diff_ids"`
|
DiffIDs []string `json:"diff_ids"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
|
func GetManifest(mp ModelPath) (*Manifest, string, error) {
|
||||||
fp, err := mp.GetManifestPath()
|
fp, err := mp.GetManifestPath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
|
@ -170,7 +195,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
var manifest *ManifestV2
|
var manifest *Manifest
|
||||||
|
|
||||||
bts, err := os.ReadFile(fp)
|
bts, err := os.ReadFile(fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -198,8 +223,7 @@ func GetModel(name string) (*Model, error) {
|
||||||
Name: mp.GetFullTagname(),
|
Name: mp.GetFullTagname(),
|
||||||
ShortName: mp.GetShortTagname(),
|
ShortName: mp.GetShortTagname(),
|
||||||
Digest: digest,
|
Digest: digest,
|
||||||
Template: "{{ .Prompt }}",
|
Template: template.DefaultTemplate,
|
||||||
License: []string{},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
filename, err := GetBlobsPath(manifest.Config.Digest)
|
filename, err := GetBlobsPath(manifest.Config.Digest)
|
||||||
|
@ -235,13 +259,17 @@ func GetModel(name string) (*Model, error) {
|
||||||
model.AdapterPaths = append(model.AdapterPaths, filename)
|
model.AdapterPaths = append(model.AdapterPaths, filename)
|
||||||
case "application/vnd.ollama.image.projector":
|
case "application/vnd.ollama.image.projector":
|
||||||
model.ProjectorPaths = append(model.ProjectorPaths, filename)
|
model.ProjectorPaths = append(model.ProjectorPaths, filename)
|
||||||
case "application/vnd.ollama.image.template":
|
case "application/vnd.ollama.image.prompt",
|
||||||
|
"application/vnd.ollama.image.template":
|
||||||
bts, err := os.ReadFile(filename)
|
bts, err := os.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
model.Template = string(bts)
|
model.Template, err = template.Parse(string(bts))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
case "application/vnd.ollama.image.system":
|
case "application/vnd.ollama.image.system":
|
||||||
bts, err := os.ReadFile(filename)
|
bts, err := os.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -249,13 +277,6 @@ func GetModel(name string) (*Model, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
model.System = string(bts)
|
model.System = string(bts)
|
||||||
case "application/vnd.ollama.image.prompt":
|
|
||||||
bts, err := os.ReadFile(filename)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
model.Template = string(bts)
|
|
||||||
case "application/vnd.ollama.image.params":
|
case "application/vnd.ollama.image.params":
|
||||||
params, err := os.Open(filename)
|
params, err := os.Open(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -822,7 +843,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||||
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||||
mp := ParseModelPath(name)
|
mp := ParseModelPath(name)
|
||||||
|
|
||||||
var manifest *ManifestV2
|
var manifest *Manifest
|
||||||
var err error
|
var err error
|
||||||
var noprune string
|
var noprune string
|
||||||
|
|
||||||
|
@ -929,7 +950,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
|
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
|
||||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||||
|
|
||||||
headers := make(http.Header)
|
headers := make(http.Header)
|
||||||
|
@ -940,7 +961,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
var m *ManifestV2
|
var m *Manifest
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,10 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manifest struct {
|
type Manifest struct {
|
||||||
ManifestV2
|
SchemaVersion int `json:"schemaVersion"`
|
||||||
|
MediaType string `json:"mediaType"`
|
||||||
|
Config *Layer `json:"config"`
|
||||||
|
Layers []*Layer `json:"layers"`
|
||||||
|
|
||||||
filepath string
|
filepath string
|
||||||
fi os.FileInfo
|
fi os.FileInfo
|
||||||
|
@ -66,7 +69,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||||
|
|
||||||
p := filepath.Join(manifests, n.Filepath())
|
p := filepath.Join(manifests, n.Filepath())
|
||||||
|
|
||||||
var m ManifestV2
|
var m Manifest
|
||||||
f, err := os.Open(p)
|
f, err := os.Open(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -83,12 +86,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Manifest{
|
m.filepath = p
|
||||||
ManifestV2: m,
|
m.fi = fi
|
||||||
filepath: p,
|
m.digest = fmt.Sprintf("%x", sha256sum.Sum(nil))
|
||||||
fi: fi,
|
|
||||||
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
|
return &m, nil
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
|
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
|
||||||
|
@ -108,7 +110,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
m := ManifestV2{
|
m := Manifest{
|
||||||
SchemaVersion: 2,
|
SchemaVersion: 2,
|
||||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||||
Config: config,
|
Config: config,
|
||||||
|
|
|
@ -25,7 +25,7 @@ func createManifest(t *testing.T, path, name string) {
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
if err := json.NewEncoder(f).Encode(ManifestV2{}); err != nil {
|
if err := json.NewEncoder(f).Encode(Manifest{}); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,7 @@ import (
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/convert"
|
"github.com/ollama/ollama/convert"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/templates"
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -256,7 +256,7 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
|
||||||
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
||||||
for _, layer := range layers {
|
for _, layer := range layers {
|
||||||
if s := layer.GGML.KV().ChatTemplate(); s != "" {
|
if s := layer.GGML.KV().ChatTemplate(); s != "" {
|
||||||
if t, err := templates.NamedTemplate(s); err != nil {
|
if t, err := template.Named(s); err != nil {
|
||||||
slog.Debug("template detection", "error", err)
|
slog.Debug("template detection", "error", err)
|
||||||
} else {
|
} else {
|
||||||
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||||
|
|
|
@ -4,10 +4,11 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
|
||||||
"text/template/parse"
|
"text/template/parse"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
// isResponseNode checks if the node contains .Response
|
// isResponseNode checks if the node contains .Response
|
||||||
|
@ -53,13 +54,8 @@ func formatTemplateForResponse(tmpl *template.Template, generate bool) {
|
||||||
|
|
||||||
// Prompt renders a prompt from a template. If generate is set to true,
|
// Prompt renders a prompt from a template. If generate is set to true,
|
||||||
// the response and parts of the template following it are not rendered
|
// the response and parts of the template following it are not rendered
|
||||||
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
|
func Prompt(tmpl *template.Template, system, prompt, response string, generate bool) (string, error) {
|
||||||
parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
|
formatTemplateForResponse(tmpl, generate)
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
formatTemplateForResponse(parsed, generate)
|
|
||||||
|
|
||||||
vars := map[string]any{
|
vars := map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
|
@ -68,14 +64,14 @@ func Prompt(tmpl, system, prompt, response string, generate bool) (string, error
|
||||||
}
|
}
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
if err := parsed.Execute(&sb, vars); err != nil {
|
if err := tmpl.Execute(&sb, vars); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
|
func countTokens(tmpl *template.Template, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
|
||||||
rendered, err := Prompt(tmpl, system, prompt, response, false)
|
rendered, err := Prompt(tmpl, system, prompt, response, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
|
@ -91,7 +87,7 @@ func countTokens(tmpl string, system string, prompt string, response string, enc
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
|
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
|
||||||
func ChatPrompt(tmpl string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
func ChatPrompt(tmpl *template.Template, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
|
||||||
type prompt struct {
|
type prompt struct {
|
||||||
System string
|
System string
|
||||||
Prompt string
|
Prompt string
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPrompt(t *testing.T) {
|
func TestPrompt(t *testing.T) {
|
||||||
|
@ -61,7 +62,12 @@ func TestPrompt(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
|
tmpl, err := template.Parse(tc.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := Prompt(tmpl, tc.system, tc.prompt, tc.response, tc.generate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error = %v", err)
|
t.Errorf("error = %v", err)
|
||||||
}
|
}
|
||||||
|
@ -192,7 +198,12 @@ func TestChatPrompt(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
got, err := ChatPrompt(tc.template, tc.messages, tc.window, encode)
|
tmpl, err := template.Parse(tc.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := ChatPrompt(tmpl, tc.messages, tc.window, encode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("error = %v", err)
|
t.Errorf("error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,7 @@ import (
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/types/errtypes"
|
"github.com/ollama/ollama/types/errtypes"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
|
@ -121,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if model.IsEmbedding() {
|
if !model.Has(CapabilityCompletion) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,6 +162,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tmpl, err := template.Parse(req.Template)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
checkpointLoaded := time.Now()
|
checkpointLoaded := time.Now()
|
||||||
|
|
||||||
var prompt string
|
var prompt string
|
||||||
|
@ -169,7 +176,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
prompt = req.Prompt
|
prompt = req.Prompt
|
||||||
case req.Prompt != "":
|
case req.Prompt != "":
|
||||||
if req.Template == "" {
|
if req.Template == "" {
|
||||||
req.Template = model.Template
|
model.Template, err = template.Parse(req.Template)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.System == "" {
|
if req.System == "" {
|
||||||
|
@ -187,7 +198,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
|
|
||||||
sb.WriteString(req.Prompt)
|
sb.WriteString(req.Prompt)
|
||||||
|
|
||||||
p, err := Prompt(req.Template, req.System, sb.String(), "", true)
|
p, err := Prompt(tmpl, req.System, sb.String(), "", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
@ -242,7 +253,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||||
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||||
|
|
||||||
if !req.Raw {
|
if !req.Raw {
|
||||||
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
|
p, err := Prompt(tmpl, req.System, req.Prompt, generated.String(), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
@ -680,7 +691,10 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Template != "" {
|
if req.Template != "" {
|
||||||
m.Template = req.Template
|
m.Template, err = template.Parse(req.Template)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs := make([]api.Message, 0)
|
msgs := make([]api.Message, 0)
|
||||||
|
@ -701,7 +715,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||||
resp := &api.ShowResponse{
|
resp := &api.ShowResponse{
|
||||||
License: strings.Join(m.License, "\n"),
|
License: strings.Join(m.License, "\n"),
|
||||||
System: m.System,
|
System: m.System,
|
||||||
Template: m.Template,
|
Template: m.Template.String(),
|
||||||
Details: modelDetails,
|
Details: modelDetails,
|
||||||
Messages: msgs,
|
Messages: msgs,
|
||||||
ModifiedAt: manifest.fi.ModTime(),
|
ModifiedAt: manifest.fi.ModTime(),
|
||||||
|
@ -1248,7 +1262,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
||||||
func chatPrompt(ctx context.Context, runner *runnerRef, template string, messages []api.Message, numCtx int) (string, error) {
|
func chatPrompt(ctx context.Context, runner *runnerRef, template *template.Template, messages []api.Message, numCtx int) (string, error) {
|
||||||
encode := func(s string) ([]int, error) {
|
encode := func(s string) ([]int, error) {
|
||||||
return runner.llama.Tokenize(ctx, s)
|
return runner.llama.Tokenize(ctx, s)
|
||||||
}
|
}
|
||||||
|
@ -1296,8 +1310,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if model.IsEmbedding() {
|
if !model.Has(CapabilityCompletion) {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
158
template/template.go
Normal file
158
template/template.go
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
package template
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"embed"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"text/template"
|
||||||
|
"text/template/parse"
|
||||||
|
|
||||||
|
"github.com/agnivade/levenshtein"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed index.json
|
||||||
|
var indexBytes []byte
|
||||||
|
|
||||||
|
//go:embed *.gotmpl
|
||||||
|
var templatesFS embed.FS
|
||||||
|
|
||||||
|
var templatesOnce = sync.OnceValues(func() ([]*named, error) {
|
||||||
|
var templates []*named
|
||||||
|
if err := json.Unmarshal(indexBytes, &templates); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range templates {
|
||||||
|
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalize line endings
|
||||||
|
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return templates, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
type named struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Template string `json:"template"`
|
||||||
|
Bytes []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t named) Reader() io.Reader {
|
||||||
|
return bytes.NewReader(t.Bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Named(s string) (*named, error) {
|
||||||
|
templates, err := templatesOnce()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var template *named
|
||||||
|
score := math.MaxInt
|
||||||
|
for _, t := range templates {
|
||||||
|
if s := levenshtein.ComputeDistance(s, t.Template); s < score {
|
||||||
|
score = s
|
||||||
|
template = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if score < 100 {
|
||||||
|
return template, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("no matching template found")
|
||||||
|
}
|
||||||
|
|
||||||
|
type Template struct {
|
||||||
|
*template.Template
|
||||||
|
raw string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Template) String() string {
|
||||||
|
return t.raw
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
|
||||||
|
|
||||||
|
func Parse(s string) (*Template, error) {
|
||||||
|
t, err := template.New("").Option("missingkey=zero").Parse(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Template{Template: t, raw: s}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Template) Vars() []string {
|
||||||
|
var vars []string
|
||||||
|
for _, n := range t.Tree.Root.Nodes {
|
||||||
|
vars = append(vars, parseNode(n)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
set := make(map[string]struct{})
|
||||||
|
for _, n := range vars {
|
||||||
|
set[strings.ToLower(n)] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
vars = maps.Keys(set)
|
||||||
|
slices.Sort(vars)
|
||||||
|
return vars
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseNode(n parse.Node) []string {
|
||||||
|
switch n := n.(type) {
|
||||||
|
case *parse.ActionNode:
|
||||||
|
return parseNode(n.Pipe)
|
||||||
|
case *parse.IfNode:
|
||||||
|
names := parseNode(n.Pipe)
|
||||||
|
names = append(names, parseNode(n.List)...)
|
||||||
|
if n.ElseList != nil {
|
||||||
|
names = append(names, parseNode(n.ElseList)...)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
case *parse.RangeNode:
|
||||||
|
names := parseNode(n.Pipe)
|
||||||
|
names = append(names, parseNode(n.List)...)
|
||||||
|
if n.ElseList != nil {
|
||||||
|
names = append(names, parseNode(n.ElseList)...)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
case *parse.WithNode:
|
||||||
|
names := parseNode(n.Pipe)
|
||||||
|
names = append(names, parseNode(n.List)...)
|
||||||
|
if n.ElseList != nil {
|
||||||
|
names = append(names, parseNode(n.ElseList)...)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
case *parse.PipeNode:
|
||||||
|
var names []string
|
||||||
|
for _, c := range n.Cmds {
|
||||||
|
for _, a := range c.Args {
|
||||||
|
names = append(names, parseNode(a)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
case *parse.ListNode:
|
||||||
|
var names []string
|
||||||
|
for _, n := range n.Nodes {
|
||||||
|
names = append(names, parseNode(n)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return names
|
||||||
|
case *parse.FieldNode:
|
||||||
|
return n.Ident
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
89
template/template_test.go
Normal file
89
template/template_test.go
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
package template
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNamed(t *testing.T) {
|
||||||
|
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
var ss map[string]string
|
||||||
|
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range ss {
|
||||||
|
t.Run(k, func(t *testing.T) {
|
||||||
|
kv := llm.KV{"tokenizer.chat_template": v}
|
||||||
|
s := kv.ChatTemplate()
|
||||||
|
r, err := Named(s)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Name != k {
|
||||||
|
t.Errorf("expected %q, got %q", k, r.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := io.Copy(&b, r.Reader()); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl, err := template.New(s).Parse(b.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tmpl.Tree.Root.String() == "" {
|
||||||
|
t.Errorf("empty %s template", k)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParse(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
template string
|
||||||
|
vars []string
|
||||||
|
}{
|
||||||
|
{"{{ .Prompt }}", []string{"prompt"}},
|
||||||
|
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
|
||||||
|
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
|
||||||
|
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "system", "tools"}},
|
||||||
|
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
|
||||||
|
{"{{ range .Messages }}{{ if eq .Role \"system\" }}SYSTEM: {{ .Content }}{{ else if eq .Role \"user\" }}USER: {{ .Content }}{{ else if eq .Role \"assistant\" }}ASSISTANT: {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role"}},
|
||||||
|
{"{{ .Prompt }} {{ .Suffix }}", []string{"prompt", "suffix"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run("", func(t *testing.T) {
|
||||||
|
tmpl, err := Parse(tt.template)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
vars := tmpl.Vars()
|
||||||
|
if !slices.Equal(tt.vars, vars) {
|
||||||
|
t.Errorf("expected %v, got %v", tt.vars, vars)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,70 +0,0 @@
|
||||||
package templates
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"embed"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/agnivade/levenshtein"
|
|
||||||
)
|
|
||||||
|
|
||||||
//go:embed index.json
|
|
||||||
var indexBytes []byte
|
|
||||||
|
|
||||||
//go:embed *.gotmpl
|
|
||||||
var templatesFS embed.FS
|
|
||||||
|
|
||||||
var templatesOnce = sync.OnceValues(func() ([]*Template, error) {
|
|
||||||
var templates []*Template
|
|
||||||
if err := json.Unmarshal(indexBytes, &templates); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range templates {
|
|
||||||
bts, err := templatesFS.ReadFile(t.Name + ".gotmpl")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// normalize line endings
|
|
||||||
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
|
|
||||||
}
|
|
||||||
|
|
||||||
return templates, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
type Template struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Template string `json:"template"`
|
|
||||||
Bytes []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Template) Reader() io.Reader {
|
|
||||||
return bytes.NewReader(t.Bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NamedTemplate(s string) (*Template, error) {
|
|
||||||
templates, err := templatesOnce()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var template *Template
|
|
||||||
score := math.MaxInt
|
|
||||||
for _, t := range templates {
|
|
||||||
if s := levenshtein.ComputeDistance(s, t.Template); s < score {
|
|
||||||
score = s
|
|
||||||
template = t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if score < 100 {
|
|
||||||
return template, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errors.New("no matching template found")
|
|
||||||
}
|
|
|
@ -1,59 +0,0 @@
|
||||||
package templates
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
"text/template"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/llm"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestKVChatTemplate(t *testing.T) {
|
|
||||||
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(f)
|
|
||||||
for scanner.Scan() {
|
|
||||||
var ss map[string]string
|
|
||||||
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range ss {
|
|
||||||
t.Run(k, func(t *testing.T) {
|
|
||||||
kv := llm.KV{"tokenizer.chat_template": v}
|
|
||||||
s := kv.ChatTemplate()
|
|
||||||
r, err := NamedTemplate(s)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Name != k {
|
|
||||||
t.Errorf("expected %q, got %q", k, r.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if _, err := io.Copy(&b, r.Reader()); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tmpl, err := template.New(s).Parse(b.String())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tmpl.Tree.Root.String() == "" {
|
|
||||||
t.Errorf("empty %s template", k)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in a new issue