From fccf8d179f50d0221bc3445460555520121b6913 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Fri, 21 Jul 2023 13:33:56 -0700 Subject: [PATCH] partial decode ggml bin for more info --- {llama => llm}/ggml-alloc.c | 0 {llama => llm}/ggml-alloc.h | 0 {llama => llm}/ggml-cuda.cu | 0 {llama => llm}/ggml-cuda.h | 0 {llama => llm}/ggml-metal.h | 0 {llama => llm}/ggml-metal.m | 0 {llama => llm}/ggml-metal.metal | 0 {llama => llm}/ggml-mpi.c | 0 {llama => llm}/ggml-mpi.h | 0 {llama => llm}/ggml-opencl.cpp | 0 {llama => llm}/ggml-opencl.h | 0 {llama => llm}/ggml.c | 0 llm/ggml.go | 180 +++++++++++++++++++++++++++++ {llama => llm}/ggml.h | 0 {llama => llm}/k_quants.c | 0 {llama => llm}/k_quants.h | 0 {llama => llm}/llama-util.h | 0 {llama => llm}/llama.cpp | 0 {llama => llm}/llama.go | 100 ++++++++++------ {llama => llm}/llama.h | 0 {llama => llm}/llama_darwin.go | 2 +- llm/llm.go | 40 +++++++ {llama => llm}/update-llama-cpp.sh | 0 {llama => llm}/utils.go | 2 +- server/images.go | 58 ++++++---- server/routes.go | 23 ++-- 26 files changed, 336 insertions(+), 69 deletions(-) rename {llama => llm}/ggml-alloc.c (100%) rename {llama => llm}/ggml-alloc.h (100%) rename {llama => llm}/ggml-cuda.cu (100%) rename {llama => llm}/ggml-cuda.h (100%) rename {llama => llm}/ggml-metal.h (100%) rename {llama => llm}/ggml-metal.m (100%) rename {llama => llm}/ggml-metal.metal (100%) rename {llama => llm}/ggml-mpi.c (100%) rename {llama => llm}/ggml-mpi.h (100%) rename {llama => llm}/ggml-opencl.cpp (100%) rename {llama => llm}/ggml-opencl.h (100%) rename {llama => llm}/ggml.c (100%) create mode 100644 llm/ggml.go rename {llama => llm}/ggml.h (100%) rename {llama => llm}/k_quants.c (100%) rename {llama => llm}/k_quants.h (100%) rename {llama => llm}/llama-util.h (100%) rename {llama => llm}/llama.cpp (100%) rename {llama => llm}/llama.go (80%) rename {llama => llm}/llama.h (100%) rename {llama => llm}/llama_darwin.go (98%) create mode 100644 llm/llm.go rename {llama => llm}/update-llama-cpp.sh (100%) rename {llama => llm}/utils.go (92%) diff --git a/llama/ggml-alloc.c b/llm/ggml-alloc.c similarity index 100% rename from llama/ggml-alloc.c rename to llm/ggml-alloc.c diff --git a/llama/ggml-alloc.h b/llm/ggml-alloc.h similarity index 100% rename from llama/ggml-alloc.h rename to llm/ggml-alloc.h diff --git a/llama/ggml-cuda.cu b/llm/ggml-cuda.cu similarity index 100% rename from llama/ggml-cuda.cu rename to llm/ggml-cuda.cu diff --git a/llama/ggml-cuda.h b/llm/ggml-cuda.h similarity index 100% rename from llama/ggml-cuda.h rename to llm/ggml-cuda.h diff --git a/llama/ggml-metal.h b/llm/ggml-metal.h similarity index 100% rename from llama/ggml-metal.h rename to llm/ggml-metal.h diff --git a/llama/ggml-metal.m b/llm/ggml-metal.m similarity index 100% rename from llama/ggml-metal.m rename to llm/ggml-metal.m diff --git a/llama/ggml-metal.metal b/llm/ggml-metal.metal similarity index 100% rename from llama/ggml-metal.metal rename to llm/ggml-metal.metal diff --git a/llama/ggml-mpi.c b/llm/ggml-mpi.c similarity index 100% rename from llama/ggml-mpi.c rename to llm/ggml-mpi.c diff --git a/llama/ggml-mpi.h b/llm/ggml-mpi.h similarity index 100% rename from llama/ggml-mpi.h rename to llm/ggml-mpi.h diff --git a/llama/ggml-opencl.cpp b/llm/ggml-opencl.cpp similarity index 100% rename from llama/ggml-opencl.cpp rename to llm/ggml-opencl.cpp diff --git a/llama/ggml-opencl.h b/llm/ggml-opencl.h similarity index 100% rename from llama/ggml-opencl.h rename to llm/ggml-opencl.h diff --git a/llama/ggml.c b/llm/ggml.c similarity index 100% rename from llama/ggml.c rename to llm/ggml.c diff --git a/llm/ggml.go b/llm/ggml.go new file mode 100644 index 00000000..a5f013c3 --- /dev/null +++ b/llm/ggml.go @@ -0,0 +1,180 @@ +package llm + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +type ModelFamily string + +const ModelFamilyLlama ModelFamily = "llama" + +type ModelType uint32 + +const ( + ModelType3B ModelType = 26 + ModelType7B ModelType = 32 + ModelType13B ModelType = 40 + ModelType30B ModelType = 60 + ModelType65B ModelType = 80 +) + +type FileType uint32 + +const ( + FileTypeF32 FileType = iota + FileTypeF16 + FileTypeQ4_0 + FileTypeQ4_1 + FileTypeQ4_1_F16 + FileTypeQ8_0 = iota + 3 + FileTypeQ5_0 + FileTypeQ5_1 + FileTypeQ2_K + FileTypeQ3_K + FileTypeQ4_K + FileTypeQ5_K + FileTypeQ6_K + FileTypeUnknown = -1 +) + +type GGML struct { + ModelFamily + ModelType + + magic uint32 + container + + llamaHyperparameters +} + +type container interface { + Name() string + Decode(io.Reader) error +} + +type containerGGML struct { +} + +func (c *containerGGML) Name() string { + return "ggml" +} + +func (c *containerGGML) Decode(r io.Reader) error { + return nil +} + +type containerGGMF struct { + version uint32 +} + +func (c *containerGGMF) Name() string { + return "ggmf" +} + +func (c *containerGGMF) Decode(r io.Reader) error { + var version uint32 + binary.Read(r, binary.LittleEndian, &version) + + switch version { + case 1: + default: + return errors.New("invalid version") + } + + c.version = version + return nil +} + +type containerGGJT struct { + version uint32 +} + +func (c *containerGGJT) Name() string { + return "ggjt" +} + +func (c *containerGGJT) Decode(r io.Reader) error { + var version uint32 + binary.Read(r, binary.LittleEndian, &version) + + switch version { + case 1, 2, 3: + default: + return errors.New("invalid version") + } + + c.version = version + return nil +} + +type containerLORA struct { + version uint32 +} + +func (c *containerLORA) Name() string { + return "ggla" +} + +func (c *containerLORA) Decode(r io.Reader) error { + var version uint32 + binary.Read(r, binary.LittleEndian, &version) + + switch version { + case 1: + default: + return errors.New("invalid version") + } + + c.version = version + return nil +} + +const ( + // / Magic constant for `ggml` files (unversioned). + FILE_MAGIC_GGML = 0x67676d6c + // / Magic constant for `ggml` files (versioned, ggmf). + FILE_MAGIC_GGMF = 0x67676d66 + // / Magic constant for `ggml` files (versioned, ggjt). + FILE_MAGIC_GGJT = 0x67676a74 + // / Magic constant for `ggla` files (LoRA adapter). + FILE_MAGIC_GGLA = 0x67676C61 +) + +func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) { + var ggml GGML + binary.Read(r, binary.LittleEndian, &ggml.magic) + + switch ggml.magic { + case FILE_MAGIC_GGML: + ggml.container = &containerGGML{} + case FILE_MAGIC_GGMF: + ggml.container = &containerGGMF{} + case FILE_MAGIC_GGJT: + ggml.container = &containerGGJT{} + case FILE_MAGIC_GGLA: + ggml.container = &containerLORA{} + default: + return nil, errors.New("invalid file magic") + } + + if err := ggml.Decode(r); err != nil { + return nil, err + } + + // different model types may have different layouts for hyperparameters + switch hint { + case ModelFamilyLlama: + binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters) + // TODO: sanity check hyperparameters + default: + return nil, fmt.Errorf("unsupported model type: %s", hint) + } + + // final model type + ggml.ModelFamily = hint + ggml.ModelType = ModelType(ggml.NumLayer) + return &ggml, nil +} diff --git a/llama/ggml.h b/llm/ggml.h similarity index 100% rename from llama/ggml.h rename to llm/ggml.h diff --git a/llama/k_quants.c b/llm/k_quants.c similarity index 100% rename from llama/k_quants.c rename to llm/k_quants.c diff --git a/llama/k_quants.h b/llm/k_quants.h similarity index 100% rename from llama/k_quants.h rename to llm/k_quants.h diff --git a/llama/llama-util.h b/llm/llama-util.h similarity index 100% rename from llama/llama-util.h rename to llm/llama-util.h diff --git a/llama/llama.cpp b/llm/llama.cpp similarity index 100% rename from llama/llama.cpp rename to llm/llama.cpp diff --git a/llama/llama.go b/llm/llama.go similarity index 80% rename from llama/llama.go rename to llm/llama.go index aba6c513..dead6b70 100644 --- a/llama/llama.go +++ b/llm/llama.go @@ -1,4 +1,4 @@ -package llama +package llm /* #cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS @@ -105,7 +105,7 @@ import ( //go:embed ggml-metal.metal var fs embed.FS -type LLM struct { +type llama struct { params *C.struct_llama_context_params model *C.struct_llama_model ctx *C.struct_llama_context @@ -120,12 +120,28 @@ type LLM struct { api.Options } -func New(model string, opts api.Options) (*LLM, error) { +type llamaHyperparameters struct { + // NumVocab is the size of the model's vocabulary. + NumVocab uint32 + + // NumEmbd is the size of the model's embedding layer. + NumEmbd uint32 + NumMult uint32 + NumHead uint32 + + // NumLayer is the number of layers in the model. + NumLayer uint32 + NumRot uint32 + // FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc. + FileType +} + +func newLlama(model string, opts api.Options) (*llama, error) { if _, err := os.Stat(model); err != nil { return nil, err } - llm := LLM{Options: opts} + llm := llama{Options: opts} C.llama_backend_init(C.bool(llm.UseNUMA)) @@ -168,7 +184,7 @@ func New(model string, opts api.Options) (*LLM, error) { return &llm, nil } -func (llm *LLM) Close() { +func (llm *llama) Close() { llm.gc = true llm.mu.Lock() @@ -180,17 +196,16 @@ func (llm *LLM) Close() { C.llama_print_timings(llm.ctx) } +func (llm *llama) SetOptions(opts api.Options) { + llm.Options = opts +} + var errNeedMoreData = errors.New("need more data") -func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { +func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error { C.llama_reset_timings(llm.ctx) - tokens := make([]C.llama_token, len(ctx)) - for i := range tokens { - tokens[i] = C.llama_token(ctx[i]) - } - - llm.marshalPrompt(tokens, prompt) + llm.marshalPrompt(ctx, prompt) C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed)) @@ -205,7 +220,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) return err } - b.WriteString(llm.Decode(token)) + b.WriteString(llm.Decode(int(token))) if err := llm.checkStopConditions(b); err != nil { if errors.Is(err, io.EOF) { @@ -243,7 +258,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) return nil } -func (llm *LLM) checkStopConditions(b bytes.Buffer) error { +func (llm *llama) checkStopConditions(b bytes.Buffer) error { for _, stopCondition := range llm.Stop { if stopCondition == strings.TrimSpace(b.String()) { return io.EOF @@ -255,12 +270,17 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error { return nil } -func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token { +func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token { tokens := append(ctx, llm.Encode(prompt)...) if llm.NumKeep < 0 { llm.NumKeep = len(tokens) } + cTokens := make([]C.llama_token, len(tokens)) + for i := range tokens { + cTokens[i] = C.llama_token(tokens[i]) + } + // min(llm.NumCtx - 4, llm.NumKeep) if llm.NumCtx-4 < llm.NumKeep { llm.NumKeep = llm.NumCtx - 4 @@ -269,25 +289,25 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke if len(tokens) >= llm.NumCtx { // truncate input numLeft := (llm.NumCtx - llm.NumKeep) / 2 - truncated := tokens[:llm.NumKeep] - erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft - truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...) - copy(llm.last, tokens[len(tokens)-llm.NumCtx:]) + truncated := cTokens[:llm.NumKeep] + erasedBlocks := (len(cTokens) - llm.NumKeep - numLeft - 1) / numLeft + truncated = append(truncated, cTokens[llm.NumKeep+erasedBlocks*numLeft:]...) + copy(llm.last, cTokens[len(cTokens)-llm.NumCtx:]) - tokens = truncated + cTokens = truncated log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated)) } else { - llm.last = make([]C.llama_token, llm.NumCtx-len(tokens)) - llm.last = append(llm.last, tokens...) + llm.last = make([]C.llama_token, llm.NumCtx-len(cTokens)) + llm.last = append(llm.last, cTokens...) } var i int - for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ { + for i = 0; i < len(llm.embd) && i < len(cTokens) && llm.embd[i] == cTokens[i]; i++ { // noop } - llm.embd = tokens - if i == len(tokens) { + llm.embd = cTokens + if i == len(cTokens) { // evaluate at least one token to generate logits i-- } @@ -295,31 +315,36 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke llm.cursor = i log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:])) - return tokens + return cTokens } -func (llm *LLM) Encode(prompt string) []C.llama_token { +func (llm *llama) Encode(prompt string) []int { cPrompt := C.CString(prompt) defer C.free(unsafe.Pointer(cPrompt)) - tokens := make([]C.llama_token, len(prompt)+1) - if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(tokens), C.int(len(tokens)), true); n > 0 { - return tokens[:n] + cTokens := make([]C.llama_token, len(prompt)+1) + if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(cTokens), C.int(len(cTokens)), true); n > 0 { + tokens := make([]int, n) + for i := range cTokens[:n] { + tokens[i] = int(cTokens[i]) + } + + return tokens } return nil } -func (llm *LLM) Decode(tokens ...C.llama_token) string { +func (llm *llama) Decode(tokens ...int) string { var sb strings.Builder for _, token := range tokens { - sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token))) + sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, C.llama_token(token)))) } return sb.String() } -func (llm *LLM) next() (C.llama_token, error) { +func (llm *llama) next() (C.llama_token, error) { llm.mu.Lock() defer llm.mu.Unlock() @@ -410,7 +435,7 @@ func (llm *LLM) next() (C.llama_token, error) { return token, nil } -func (llm *LLM) Embedding(input string) ([]float64, error) { +func (llm *llama) Embedding(input string) ([]float64, error) { if !llm.EmbeddingOnly { return nil, errors.New("llama: embedding not enabled") } @@ -420,7 +445,12 @@ func (llm *LLM) Embedding(input string) ([]float64, error) { return nil, errors.New("llama: tokenize embedding") } - retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread)) + cTokens := make([]C.llama_token, len(tokens)) + for i := range tokens { + cTokens[i] = C.llama_token(tokens[i]) + } + + retval := C.llama_eval(llm.ctx, unsafe.SliceData(cTokens), C.int(len(tokens)), 0, C.int(llm.NumThread)) if retval != 0 { return nil, errors.New("llama: eval") } diff --git a/llama/llama.h b/llm/llama.h similarity index 100% rename from llama/llama.h rename to llm/llama.h diff --git a/llama/llama_darwin.go b/llm/llama_darwin.go similarity index 98% rename from llama/llama_darwin.go rename to llm/llama_darwin.go index 7a3ed43b..19e7b9e1 100644 --- a/llama/llama_darwin.go +++ b/llm/llama_darwin.go @@ -1,4 +1,4 @@ -package llama +package llm import ( "bytes" diff --git a/llm/llm.go b/llm/llm.go new file mode 100644 index 00000000..b537865e --- /dev/null +++ b/llm/llm.go @@ -0,0 +1,40 @@ +package llm + +import ( + "fmt" + "os" + + "github.com/jmorganca/ollama/api" +) + +type LLM interface { + Predict([]int, string, func(api.GenerateResponse)) error + Embedding(string) ([]float64, error) + Encode(string) []int + Decode(...int) string + SetOptions(api.Options) + Close() +} + +func New(model string, opts api.Options) (LLM, error) { + if _, err := os.Stat(model); err != nil { + return nil, err + } + + f, err := os.Open(model) + if err != nil { + return nil, err + } + + ggml, err := DecodeGGML(f, ModelFamilyLlama) + if err != nil { + return nil, err + } + + switch ggml.ModelFamily { + case ModelFamilyLlama: + return newLlama(model, opts) + default: + return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily) + } +} diff --git a/llama/update-llama-cpp.sh b/llm/update-llama-cpp.sh similarity index 100% rename from llama/update-llama-cpp.sh rename to llm/update-llama-cpp.sh diff --git a/llama/utils.go b/llm/utils.go similarity index 92% rename from llama/utils.go rename to llm/utils.go index 8b52ad5c..4dc03c80 100644 --- a/llama/utils.go +++ b/llm/utils.go @@ -1,4 +1,4 @@ -package llama +package llm import ( "fmt" diff --git a/server/images.go b/server/images.go index 2ec24854..f179b233 100644 --- a/server/images.go +++ b/server/images.go @@ -19,7 +19,7 @@ import ( "strings" "github.com/jmorganca/ollama/api" - "github.com/jmorganca/ollama/llama" + "github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/vector" ) @@ -98,9 +98,14 @@ type LayerReader struct { } type ConfigV2 struct { + ModelFamily llm.ModelFamily `json:"model_family"` + ModelType llm.ModelType `json:"model_type"` + FileType llm.FileType `json:"file_type"` + RootFS RootFS `json:"rootfs"` + + // required by spec Architecture string `json:"architecture"` OS string `json:"os"` - RootFS RootFS `json:"rootfs"` } type RootFS struct { @@ -245,6 +250,11 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api return err } + config := ConfigV2{ + Architecture: "amd64", + OS: "linux", + } + var layers []*LayerReader params := make(map[string][]string) embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()} @@ -283,6 +293,18 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api } defer file.Close() + ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama) + if err != nil { + return err + } + + config.ModelFamily = ggml.ModelFamily + config.ModelType = ggml.ModelType + config.FileType = ggml.FileType + + // reset the file + file.Seek(0, io.SeekStart) + l, err := CreateLayer(file) if err != nil { return fmt.Errorf("failed to create layer: %v", err) @@ -291,6 +313,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api layers = append(layers, l) } } + if mf != nil { log.Printf("manifest = %#v", mf) for _, l := range mf.Layers { @@ -320,7 +343,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api layers = append(layers, layer) case "template", "system", "prompt": fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) - // remove the prompt layer if one exists + // remove the layer if one exists mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) layers = removeLayerFromLayers(layers, mediaType) @@ -382,7 +405,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api // Create a layer for the config object fn(api.ProgressResponse{Status: "creating config layer"}) - cfg, err := createConfigLayer(digests) + cfg, err := createConfigLayer(config, digests) if err != nil { return err } @@ -429,13 +452,13 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { } e.opts.EmbeddingOnly = true - llm, err := llama.New(e.model, e.opts) + llmModel, err := llm.New(e.model, e.opts) if err != nil { return nil, fmt.Errorf("load model to generate embeddings: %v", err) } defer func() { - if llm != nil { - llm.Close() + if llmModel != nil { + llmModel.Close() } }() @@ -479,7 +502,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { Total: len(data) - 1, Completed: i, }) - embed, err := llm.Embedding(d) + embed, err := llmModel.Embedding(d) if err != nil { log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err) continue @@ -675,7 +698,7 @@ func getLayerDigests(layers []*LayerReader) ([]string, error) { // CreateLayer creates a Layer object from a given file func CreateLayer(f io.ReadSeeker) (*LayerReader, error) { digest, size := GetSHA256Digest(f) - f.Seek(0, 0) + f.Seek(0, io.SeekStart) layer := &LayerReader{ Layer: Layer{ @@ -767,10 +790,6 @@ func DeleteModel(name string) error { return err } - if err != nil { - return err - } - // only delete the files which are still in the deleteMap for k, v := range deleteMap { if v { @@ -969,15 +988,10 @@ func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, err return m, err } -func createConfigLayer(layers []string) (*LayerReader, error) { - // TODO change architecture and OS - config := ConfigV2{ - Architecture: "arm64", - OS: "linux", - RootFS: RootFS{ - Type: "layers", - DiffIDs: layers, - }, +func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) { + config.RootFS = RootFS{ + Type: "layers", + DiffIDs: layers, } configJSON, err := json.Marshal(config) diff --git a/server/routes.go b/server/routes.go index 14040332..0f058a26 100644 --- a/server/routes.go +++ b/server/routes.go @@ -21,14 +21,14 @@ import ( "gonum.org/v1/gonum/mat" "github.com/jmorganca/ollama/api" - "github.com/jmorganca/ollama/llama" + "github.com/jmorganca/ollama/llm" "github.com/jmorganca/ollama/vector" ) var loaded struct { mu sync.Mutex - llm *llama.LLM + llm llm.LLM Embeddings []vector.Embedding expireAt time.Time @@ -63,11 +63,16 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur loaded.Embeddings = model.Embeddings } - llm, err := llama.New(model.ModelPath, opts) + llmModel, err := llm.New(model.ModelPath, opts) if err != nil { return err } + // set cache values before modifying opts + loaded.llm = llmModel + loaded.digest = model.Digest + loaded.options = opts + if opts.NumKeep < 0 { promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "") if err != nil { @@ -79,15 +84,13 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur return err } - tokensWithSystem := llm.Encode(promptWithSystem) - tokensNoSystem := llm.Encode(promptNoSystem) + tokensWithSystem := llmModel.Encode(promptWithSystem) + tokensNoSystem := llmModel.Encode(promptNoSystem) - llm.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1 + opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1 + + llmModel.SetOptions(opts) } - - loaded.llm = llm - loaded.digest = model.Digest - loaded.options = opts } loaded.expireAt = time.Now().Add(sessionDuration)