cmd: speed up gguf creates (#6324)

This commit is contained in:
Josh 2024-08-12 11:46:09 -07:00 committed by GitHub
parent 01d544d373
commit 980dd15f81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 96 additions and 3 deletions

View file

@ -176,9 +176,20 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
mediatype = "application/vnd.ollama.image.projector" mediatype = "application/vnd.ollama.image.projector"
} }
layer, err := NewLayer(io.NewSectionReader(file, offset, n), mediatype) var layer Layer
if err != nil { if digest != "" && n == stat.Size() && offset == 0 {
return nil, err layer, err = NewLayerFromLayer(digest, mediatype, file.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
}
}
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
if layer.Digest == "" {
layer, err = NewLayer(io.NewSectionReader(file, offset, n), mediatype)
if err != nil {
return nil, err
}
} }
layers = append(layers, &layerGGML{layer, ggml}) layers = append(layers, &layerGGML{layer, ggml})

View file

@ -2,8 +2,10 @@ package server
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -11,6 +13,7 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
) )
@ -133,3 +136,82 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
}) })
} }
} }
func TestParseFromFileFromLayer(t *testing.T) {
tempModels := t.TempDir()
file, err := os.CreateTemp(tempModels, "")
if err != nil {
t.Fatalf("failed to open file: %v", err)
}
defer file.Close()
if err := llm.WriteGGUF(file, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
t.Fatalf("failed to write gguf: %v", err)
}
if _, err := file.Seek(0, io.SeekStart); err != nil {
t.Fatalf("failed to seek to start: %v", err)
}
layers, err := parseFromFile(context.Background(), file, "", func(api.ProgressResponse) {})
if err != nil {
t.Fatalf("failed to parse from file: %v", err)
}
if len(layers) != 1 {
t.Fatalf("got %d != want 1", len(layers))
}
if _, err := file.Seek(0, io.SeekStart); err != nil {
t.Fatalf("failed to seek to start: %v", err)
}
layers2, err := parseFromFile(context.Background(), file, layers[0].Digest, func(api.ProgressResponse) {})
if err != nil {
t.Fatalf("failed to parse from file: %v", err)
}
if len(layers2) != 1 {
t.Fatalf("got %d != want 1", len(layers2))
}
if layers[0].Digest != layers2[0].Digest {
t.Fatalf("got %s != want %s", layers[0].Digest, layers2[0].Digest)
}
if layers[0].Size != layers2[0].Size {
t.Fatalf("got %d != want %d", layers[0].Size, layers2[0].Size)
}
if layers[0].MediaType != layers2[0].MediaType {
t.Fatalf("got %v != want %v", layers[0].MediaType, layers2[0].MediaType)
}
}
func TestParseLayerFromCopy(t *testing.T) {
tempModels := t.TempDir()
file2, err := os.CreateTemp(tempModels, "")
if err != nil {
t.Fatalf("failed to open file: %v", err)
}
defer file2.Close()
for range 5 {
if err := llm.WriteGGUF(file2, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
t.Fatalf("failed to write gguf: %v", err)
}
}
if _, err := file2.Seek(0, io.SeekStart); err != nil {
t.Fatalf("failed to seek to start: %v", err)
}
layers, err := parseFromFile(context.Background(), file2, "", func(api.ProgressResponse) {})
if err != nil {
t.Fatalf("failed to parse from file: %v", err)
}
if len(layers) != 5 {
t.Fatalf("got %d != want 5", len(layers))
}
}