Merge pull request #4910 from ollama/mxyng/detect-chat-template
fix create model when template detection errors
This commit is contained in:
commit
385a32ecb5
7 changed files with 76 additions and 54 deletions
18
llm/gguf.go
18
llm/gguf.go
|
@ -618,22 +618,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
offset, err := ws.Seek(0, io.SeekCurrent)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var alignment int64 = 32
|
var alignment int64 = 32
|
||||||
padding := llm.padding(offset, alignment)
|
|
||||||
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tensor := range tensors {
|
for _, tensor := range tensors {
|
||||||
if _, err := tensor.WriteTo(ws); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
offset, err := ws.Seek(0, io.SeekCurrent)
|
offset, err := ws.Seek(0, io.SeekCurrent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -643,6 +629,10 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||||
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
|
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _, err := tensor.WriteTo(ws); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -437,11 +437,9 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||||
config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
|
config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
|
||||||
|
|
||||||
if s := baseLayer.GGML.KV().ChatTemplate(); s != "" {
|
if s := baseLayer.GGML.KV().ChatTemplate(); s != "" {
|
||||||
t, err := templates.NamedTemplate(s)
|
if t, err := templates.NamedTemplate(s); err != nil {
|
||||||
if err != nil {
|
slog.Debug("template detection", "error", err)
|
||||||
return err
|
} else {
|
||||||
}
|
|
||||||
|
|
||||||
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -451,6 +449,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||||
layers = append(layers, layer)
|
layers = append(layers, layer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
layers = append(layers, baseLayer.Layer)
|
layers = append(layers, baseLayer.Layer)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,11 +15,12 @@ import (
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
)
|
)
|
||||||
|
|
||||||
var stream bool = false
|
var stream bool = false
|
||||||
|
|
||||||
func createBinFile(t *testing.T) string {
|
func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
f, err := os.CreateTemp(t.TempDir(), "")
|
f, err := os.CreateTemp(t.TempDir(), "")
|
||||||
|
@ -28,19 +29,7 @@ func createBinFile(t *testing.T) string {
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
|
if err := llm.NewGGUFV3(binary.LittleEndian).Encode(f, kv, ti); err != nil {
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
|
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,7 +90,7 @@ func TestCreateFromBin(t *testing.T) {
|
||||||
var s Server
|
var s Server
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -126,7 +115,7 @@ func TestCreateFromModel(t *testing.T) {
|
||||||
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -166,7 +155,7 @@ func TestCreateRemovesLayers(t *testing.T) {
|
||||||
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -186,7 +175,7 @@ func TestCreateRemovesLayers(t *testing.T) {
|
||||||
|
|
||||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -212,7 +201,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
|
||||||
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nSYSTEM Say hi!", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s\nSYSTEM Say hi!", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -232,7 +221,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
|
||||||
|
|
||||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nSYSTEM \"\"", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s\nSYSTEM \"\"", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -267,7 +256,7 @@ func TestCreateMergeParameters(t *testing.T) {
|
||||||
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nPARAMETER temperature 1\nPARAMETER top_k 10\nPARAMETER stop USER:\nPARAMETER stop ASSISTANT:", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s\nPARAMETER temperature 1\nPARAMETER top_k 10\nPARAMETER stop USER:\nPARAMETER stop ASSISTANT:", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -369,7 +358,7 @@ func TestCreateReplacesMessages(t *testing.T) {
|
||||||
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nMESSAGE assistant \"What is my purpose?\"\nMESSAGE user \"You run tests.\"\nMESSAGE assistant \"Oh, my god.\"", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s\nMESSAGE assistant \"What is my purpose?\"\nMESSAGE user \"You run tests.\"\nMESSAGE assistant \"Oh, my god.\"", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -444,7 +433,7 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||||
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}\nSYSTEM Say hello!\nTEMPLATE {{ .System }} {{ .Prompt }}\nSYSTEM Say bye!", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}\nSYSTEM Say hello!\nTEMPLATE {{ .System }} {{ .Prompt }}\nSYSTEM Say bye!", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -489,7 +478,7 @@ func TestCreateLicenses(t *testing.T) {
|
||||||
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -526,3 +515,46 @@ func TestCreateLicenses(t *testing.T) {
|
||||||
t.Errorf("expected Apache-2.0, actual %s", apache)
|
t.Errorf("expected Apache-2.0, actual %s", apache)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCreateDetectTemplate(t *testing.T) {
|
||||||
|
p := t.TempDir()
|
||||||
|
t.Setenv("OLLAMA_MODELS", p)
|
||||||
|
var s Server
|
||||||
|
|
||||||
|
t.Run("matched", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
||||||
|
"tokenizer.chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
|
||||||
|
}, nil)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
|
filepath.Join(p, "blobs", "sha256-06cd2687a518d624073f125f1db1c5c727f77c75e84a138fe745186dbbbb4cd7"),
|
||||||
|
filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
|
||||||
|
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unmatched", func(t *testing.T) {
|
||||||
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
|
Name: "test",
|
||||||
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||||
|
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||||
|
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -16,7 +16,7 @@ func TestDelete(t *testing.T) {
|
||||||
|
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||||
})
|
})
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if w.Code != http.StatusOK {
|
||||||
|
@ -25,7 +25,7 @@ func TestDelete(t *testing.T) {
|
||||||
|
|
||||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: "test2",
|
Name: "test2",
|
||||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
|
||||||
})
|
})
|
||||||
|
|
||||||
if w.Code != http.StatusOK {
|
if w.Code != http.StatusOK {
|
||||||
|
|
|
@ -29,7 +29,7 @@ func TestList(t *testing.T) {
|
||||||
for _, n := range expectNames {
|
for _, n := range expectNames {
|
||||||
createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: n,
|
Name: n,
|
||||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -261,7 +261,7 @@ func TestCase(t *testing.T) {
|
||||||
t.Run(tt, func(t *testing.T) {
|
t.Run(tt, func(t *testing.T) {
|
||||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: tt,
|
Name: tt,
|
||||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -277,7 +277,7 @@ func TestCase(t *testing.T) {
|
||||||
t.Run("create", func(t *testing.T) {
|
t.Run("create", func(t *testing.T) {
|
||||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||||
Name: strings.ToUpper(tt),
|
Name: strings.ToUpper(tt),
|
||||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,8 @@ var templatesOnce = sync.OnceValues(func() ([]*Template, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Bytes = bts
|
// normalize line endings
|
||||||
|
t.Bytes = bytes.ReplaceAll(bts, []byte("\r\n"), []byte("\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return templates, nil
|
return templates, nil
|
||||||
|
|
Loading…
Reference in a new issue