partial decode ggml bin for more info
This commit is contained in:
parent 5b5cc9c9f1 · commit fccf8d179f
26 changed files with 336 additions and 69 deletions
llm/ggml.go (new file, 180 additions)

@@ -0,0 +1,180 @@
+package llm
+
+import (
+    "encoding/binary"
+    "errors"
+    "fmt"
+    "io"
+)
+
+type ModelFamily string
+
+const ModelFamilyLlama ModelFamily = "llama"
+
+type ModelType uint32
+
+const (
+    ModelType3B  ModelType = 26
+    ModelType7B  ModelType = 32
+    ModelType13B ModelType = 40
+    ModelType30B ModelType = 60
+    ModelType65B ModelType = 80
+)
+
+type FileType uint32
+
+const (
+    FileTypeF32 FileType = iota
+    FileTypeF16
+    FileTypeQ4_0
+    FileTypeQ4_1
+    FileTypeQ4_1_F16
+    FileTypeQ8_0 = iota + 3
+    FileTypeQ5_0
+    FileTypeQ5_1
+    FileTypeQ2_K
+    FileTypeQ3_K
+    FileTypeQ4_K
+    FileTypeQ5_K
+    FileTypeQ6_K
+    FileTypeUnknown = -1
+)
+
+type GGML struct {
+    ModelFamily
+    ModelType
+
+    magic uint32
+    container
+
+    llamaHyperparameters
+}
+
+type container interface {
+    Name() string
+    Decode(io.Reader) error
+}
+
+type containerGGML struct {
+}
+
+func (c *containerGGML) Name() string {
+    return "ggml"
+}
+
+func (c *containerGGML) Decode(r io.Reader) error {
+    return nil
+}
+
+type containerGGMF struct {
+    version uint32
+}
+
+func (c *containerGGMF) Name() string {
+    return "ggmf"
+}
+
+func (c *containerGGMF) Decode(r io.Reader) error {
+    var version uint32
+    binary.Read(r, binary.LittleEndian, &version)
+
+    switch version {
+    case 1:
+    default:
+        return errors.New("invalid version")
+    }
+
+    c.version = version
+    return nil
+}
+
+type containerGGJT struct {
+    version uint32
+}
+
+func (c *containerGGJT) Name() string {
+    return "ggjt"
+}
+
+func (c *containerGGJT) Decode(r io.Reader) error {
+    var version uint32
+    binary.Read(r, binary.LittleEndian, &version)
+
+    switch version {
+    case 1, 2, 3:
+    default:
+        return errors.New("invalid version")
+    }
+
+    c.version = version
+    return nil
+}
+
+type containerLORA struct {
+    version uint32
+}
+
+func (c *containerLORA) Name() string {
+    return "ggla"
+}
+
+func (c *containerLORA) Decode(r io.Reader) error {
+    var version uint32
+    binary.Read(r, binary.LittleEndian, &version)
+
+    switch version {
+    case 1:
+    default:
+        return errors.New("invalid version")
+    }
+
+    c.version = version
+    return nil
+}
+
+const (
+    // Magic constant for `ggml` files (unversioned).
+    FILE_MAGIC_GGML = 0x67676d6c
+    // Magic constant for `ggml` files (versioned, ggmf).
+    FILE_MAGIC_GGMF = 0x67676d66
+    // Magic constant for `ggml` files (versioned, ggjt).
+    FILE_MAGIC_GGJT = 0x67676a74
+    // Magic constant for `ggla` files (LoRA adapter).
+    FILE_MAGIC_GGLA = 0x67676C61
+)
+
+func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
+    var ggml GGML
+    binary.Read(r, binary.LittleEndian, &ggml.magic)
+
+    switch ggml.magic {
+    case FILE_MAGIC_GGML:
+        ggml.container = &containerGGML{}
+    case FILE_MAGIC_GGMF:
+        ggml.container = &containerGGMF{}
+    case FILE_MAGIC_GGJT:
+        ggml.container = &containerGGJT{}
+    case FILE_MAGIC_GGLA:
+        ggml.container = &containerLORA{}
+    default:
+        return nil, errors.New("invalid file magic")
+    }
+
+    if err := ggml.Decode(r); err != nil {
+        return nil, err
+    }
+
+    // different model types may have different layouts for hyperparameters
+    switch hint {
+    case ModelFamilyLlama:
+        binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters)
+        // TODO: sanity check hyperparameters
+    default:
+        return nil, fmt.Errorf("unsupported model type: %s", hint)
+    }
+
+    // final model type
+    ggml.ModelFamily = hint
+    ggml.ModelType = ModelType(ggml.NumLayer)
+    return &ggml, nil
+}
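The magic numbers above are just the container tags' ASCII bytes packed into a uint32 (0x67, 0x67, 0x6d, 0x6c spells "ggml"), so DecodeGGML can identify a file from its first four bytes and a few header words, without touching the tensor data. A minimal sketch of calling it from outside the package; the model path is hypothetical:

```go
package main

import (
    "fmt"
    "log"
    "os"

    "github.com/jmorganca/ollama/llm"
)

func main() {
    // Hypothetical path to a pre-GGUF llama model file.
    f, err := os.Open("model.bin")
    if err != nil {
        log.Fatal(err)
    }
    defer f.Close()

    // DecodeGGML reads the magic, the container version (if any), and the
    // llama hyperparameters, but none of the tensor data.
    ggml, err := llm.DecodeGGML(f, llm.ModelFamilyLlama)
    if err != nil {
        log.Fatal(err)
    }

    // ModelType is derived from NumLayer, e.g. 32 layers decodes as 7B.
    fmt.Println(ggml.Name(), ggml.ModelFamily, ggml.ModelType, ggml.FileType)
}
```

Because the hyperparameters sit in the first few dozen bytes, this partial decode stays cheap even for multi-gigabyte model files, which is the point of the commit.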
@@ -1,4 +1,4 @@
-package llama
+package llm
 
 /*
 #cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
@@ -105,7 +105,7 @@ import (
 //go:embed ggml-metal.metal
 var fs embed.FS
 
-type LLM struct {
+type llama struct {
     params *C.struct_llama_context_params
     model  *C.struct_llama_model
     ctx    *C.struct_llama_context
@@ -120,12 +120,28 @@ type LLM struct {
     api.Options
 }
 
-func New(model string, opts api.Options) (*LLM, error) {
+type llamaHyperparameters struct {
+    // NumVocab is the size of the model's vocabulary.
+    NumVocab uint32
+
+    // NumEmbd is the size of the model's embedding layer.
+    NumEmbd uint32
+    NumMult uint32
+    NumHead uint32
+
+    // NumLayer is the number of layers in the model.
+    NumLayer uint32
+    NumRot   uint32
+    // FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
+    FileType
+}
+
+func newLlama(model string, opts api.Options) (*llama, error) {
     if _, err := os.Stat(model); err != nil {
         return nil, err
     }
 
-    llm := LLM{Options: opts}
+    llm := llama{Options: opts}
 
     C.llama_backend_init(C.bool(llm.UseNUMA))
 
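binary.Read fills llamaHyperparameters in field-declaration order, so the struct doubles as the on-disk layout: seven little-endian uint32 words immediately after the container header. A cgo-free sketch of that round trip against synthetic bytes (the values are shaped like a 7B llama model, but illustrative only):

```go
package main

import (
    "bytes"
    "encoding/binary"
    "fmt"
    "log"
)

// Mirrors llamaHyperparameters: field order must match the on-disk layout.
type hparams struct {
    NumVocab uint32
    NumEmbd  uint32
    NumMult  uint32
    NumHead  uint32
    NumLayer uint32
    NumRot   uint32
    FileType uint32
}

func main() {
    // Synthetic header words for a 7B-shaped model (illustrative values only).
    raw := []uint32{32000, 4096, 256, 32, 32, 128, 2}

    var buf bytes.Buffer
    if err := binary.Write(&buf, binary.LittleEndian, raw); err != nil {
        log.Fatal(err)
    }

    var hp hparams
    if err := binary.Read(&buf, binary.LittleEndian, &hp); err != nil {
        log.Fatal(err)
    }

    fmt.Printf("%+v\n", hp) // {NumVocab:32000 NumEmbd:4096 ...}
}
```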
@@ -168,7 +184,7 @@ func New(model string, opts api.Options) (*LLM, error) {
     return &llm, nil
 }
 
-func (llm *LLM) Close() {
+func (llm *llama) Close() {
     llm.gc = true
 
     llm.mu.Lock()
@@ -180,17 +196,16 @@ func (llm *LLM) Close() {
     C.llama_print_timings(llm.ctx)
 }
 
+func (llm *llama) SetOptions(opts api.Options) {
+    llm.Options = opts
+}
+
 var errNeedMoreData = errors.New("need more data")
 
-func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
     C.llama_reset_timings(llm.ctx)
 
-    tokens := make([]C.llama_token, len(ctx))
-    for i := range tokens {
-        tokens[i] = C.llama_token(ctx[i])
-    }
-
-    llm.marshalPrompt(tokens, prompt)
+    llm.marshalPrompt(ctx, prompt)
 
     C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
 
@@ -205,7 +220,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
         return err
     }
 
-    b.WriteString(llm.Decode(token))
+    b.WriteString(llm.Decode(int(token)))
 
     if err := llm.checkStopConditions(b); err != nil {
         if errors.Is(err, io.EOF) {
@@ -243,7 +258,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
     return nil
 }
 
-func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
+func (llm *llama) checkStopConditions(b bytes.Buffer) error {
     for _, stopCondition := range llm.Stop {
         if stopCondition == strings.TrimSpace(b.String()) {
             return io.EOF
@@ -255,12 +270,17 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
     return nil
 }
 
-func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
+func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {
     tokens := append(ctx, llm.Encode(prompt)...)
     if llm.NumKeep < 0 {
         llm.NumKeep = len(tokens)
     }
 
+    cTokens := make([]C.llama_token, len(tokens))
+    for i := range tokens {
+        cTokens[i] = C.llama_token(tokens[i])
+    }
+
     // min(llm.NumCtx - 4, llm.NumKeep)
     if llm.NumCtx-4 < llm.NumKeep {
         llm.NumKeep = llm.NumCtx - 4
@@ -269,25 +289,25 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
     if len(tokens) >= llm.NumCtx {
         // truncate input
         numLeft := (llm.NumCtx - llm.NumKeep) / 2
-        truncated := tokens[:llm.NumKeep]
-        erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
-        truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
-        copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
+        truncated := cTokens[:llm.NumKeep]
+        erasedBlocks := (len(cTokens) - llm.NumKeep - numLeft - 1) / numLeft
+        truncated = append(truncated, cTokens[llm.NumKeep+erasedBlocks*numLeft:]...)
+        copy(llm.last, cTokens[len(cTokens)-llm.NumCtx:])
 
-        tokens = truncated
+        cTokens = truncated
         log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
     } else {
-        llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
-        llm.last = append(llm.last, tokens...)
+        llm.last = make([]C.llama_token, llm.NumCtx-len(cTokens))
+        llm.last = append(llm.last, cTokens...)
     }
 
     var i int
-    for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
+    for i = 0; i < len(llm.embd) && i < len(cTokens) && llm.embd[i] == cTokens[i]; i++ {
         // noop
     }
 
-    llm.embd = tokens
-    if i == len(tokens) {
+    llm.embd = cTokens
+    if i == len(cTokens) {
         // evaluate at least one token to generate logits
         i--
     }
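The truncation arithmetic in this hunk is easier to follow with concrete numbers: numLeft is half of the context not pinned by NumKeep, and whole numLeft-sized blocks are dropped from the middle of the prompt until the remainder fits. A small sketch with illustrative values:

```go
package main

import "fmt"

func main() {
    // Illustrative values: a 2048-token context, 48 tokens pinned by NumKeep,
    // and a 2200-token prompt that needs truncating.
    numCtx, numKeep, numTokens := 2048, 48, 2200

    numLeft := (numCtx - numKeep) / 2                             // 1000
    erasedBlocks := (numTokens - numKeep - numLeft - 1) / numLeft // 1

    // Keep the first numKeep tokens plus everything after the erased blocks.
    kept := numKeep + numTokens - (numKeep + erasedBlocks*numLeft)
    fmt.Println(numLeft, erasedBlocks, kept) // 1000 1 1200
}
```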
@@ -295,31 +315,36 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
     llm.cursor = i
 
     log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
-    return tokens
+    return cTokens
 }
 
-func (llm *LLM) Encode(prompt string) []C.llama_token {
+func (llm *llama) Encode(prompt string) []int {
     cPrompt := C.CString(prompt)
     defer C.free(unsafe.Pointer(cPrompt))
 
-    tokens := make([]C.llama_token, len(prompt)+1)
-    if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(tokens), C.int(len(tokens)), true); n > 0 {
-        return tokens[:n]
+    cTokens := make([]C.llama_token, len(prompt)+1)
+    if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(cTokens), C.int(len(cTokens)), true); n > 0 {
+        tokens := make([]int, n)
+        for i := range cTokens[:n] {
+            tokens[i] = int(cTokens[i])
+        }
+
+        return tokens
     }
 
     return nil
 }
 
-func (llm *LLM) Decode(tokens ...C.llama_token) string {
+func (llm *llama) Decode(tokens ...int) string {
     var sb strings.Builder
     for _, token := range tokens {
-        sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
+        sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, C.llama_token(token))))
     }
 
     return sb.String()
 }
 
-func (llm *LLM) next() (C.llama_token, error) {
+func (llm *llama) next() (C.llama_token, error) {
     llm.mu.Lock()
     defer llm.mu.Unlock()
 
@@ -410,7 +435,7 @@ func (llm *LLM) next() (C.llama_token, error) {
     return token, nil
 }
 
-func (llm *LLM) Embedding(input string) ([]float64, error) {
+func (llm *llama) Embedding(input string) ([]float64, error) {
     if !llm.EmbeddingOnly {
         return nil, errors.New("llama: embedding not enabled")
     }
@@ -420,7 +445,12 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
         return nil, errors.New("llama: tokenize embedding")
     }
 
-    retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
+    cTokens := make([]C.llama_token, len(tokens))
+    for i := range tokens {
+        cTokens[i] = C.llama_token(tokens[i])
+    }
+
+    retval := C.llama_eval(llm.ctx, unsafe.SliceData(cTokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
     if retval != 0 {
         return nil, errors.New("llama: eval")
     }
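The same element-by-element copy between []int and []C.llama_token now appears in Predict, marshalPrompt, Encode, and Embedding, since Go cannot convert a slice to a different element type wholesale. A cgo-free sketch of the helper these call sites could share (tokensToC is hypothetical; int32 stands in for C.llama_token):

```go
package llm

// tokensToC widens API-level ints into the C token type one element at a
// time; int32 stands in for C.llama_token in this cgo-free sketch.
func tokensToC(tokens []int) []int32 {
    cTokens := make([]int32, len(tokens))
    for i, t := range tokens {
        cTokens[i] = int32(t)
    }
    return cTokens
}
```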
@@ -1,4 +1,4 @@
-package llama
+package llm
 
 import (
     "bytes"
llm/llm.go (new file, 40 additions)

@@ -0,0 +1,40 @@
+package llm
+
+import (
+    "fmt"
+    "os"
+
+    "github.com/jmorganca/ollama/api"
+)
+
+type LLM interface {
+    Predict([]int, string, func(api.GenerateResponse)) error
+    Embedding(string) ([]float64, error)
+    Encode(string) []int
+    Decode(...int) string
+    SetOptions(api.Options)
+    Close()
+}
+
+func New(model string, opts api.Options) (LLM, error) {
+    if _, err := os.Stat(model); err != nil {
+        return nil, err
+    }
+
+    f, err := os.Open(model)
+    if err != nil {
+        return nil, err
+    }
+
+    ggml, err := DecodeGGML(f, ModelFamilyLlama)
+    if err != nil {
+        return nil, err
+    }
+
+    switch ggml.ModelFamily {
+    case ModelFamilyLlama:
+        return newLlama(model, opts)
+    default:
+        return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
+    }
+}
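Callers now construct models through this family-agnostic entry point rather than llama.New directly. A sketch of the intended usage; the model path is illustrative:

```go
package main

import (
    "fmt"
    "log"

    "github.com/jmorganca/ollama/api"
    "github.com/jmorganca/ollama/llm"
)

func main() {
    // llm.New sniffs the GGML header and dispatches on model family; only
    // ModelFamilyLlama is handled so far. The path is illustrative.
    model, err := llm.New("/path/to/model.bin", api.DefaultOptions())
    if err != nil {
        log.Fatal(err)
    }
    defer model.Close()

    // The interface traffics in plain ints, so callers never touch cgo types.
    tokens := model.Encode("hello")
    fmt.Println(tokens, model.Decode(tokens...))
}
```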
@@ -1,4 +1,4 @@
-package llama
+package llm
 
 import (
     "fmt"
@@ -19,7 +19,7 @@ import (
     "strings"
 
     "github.com/jmorganca/ollama/api"
-    "github.com/jmorganca/ollama/llama"
+    "github.com/jmorganca/ollama/llm"
     "github.com/jmorganca/ollama/parser"
     "github.com/jmorganca/ollama/vector"
 )
@@ -98,9 +98,14 @@ type LayerReader struct {
 }
 
 type ConfigV2 struct {
+    ModelFamily llm.ModelFamily `json:"model_family"`
+    ModelType   llm.ModelType   `json:"model_type"`
+    FileType    llm.FileType    `json:"file_type"`
+    RootFS      RootFS          `json:"rootfs"`
+
     // required by spec
     Architecture string `json:"architecture"`
     OS           string `json:"os"`
-    RootFS       RootFS `json:"rootfs"`
 }
 
 type RootFS struct {
@@ -245,6 +250,11 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
         return err
     }
 
+    config := ConfigV2{
+        Architecture: "amd64",
+        OS:           "linux",
+    }
+
     var layers []*LayerReader
     params := make(map[string][]string)
     embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
@@ -283,6 +293,18 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
         }
         defer file.Close()
 
+        ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
+        if err != nil {
+            return err
+        }
+
+        config.ModelFamily = ggml.ModelFamily
+        config.ModelType = ggml.ModelType
+        config.FileType = ggml.FileType
+
+        // reset the file
+        file.Seek(0, io.SeekStart)
+
         l, err := CreateLayer(file)
         if err != nil {
             return fmt.Errorf("failed to create layer: %v", err)
@@ -291,6 +313,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
             layers = append(layers, l)
         }
     }
+
     if mf != nil {
         log.Printf("manifest = %#v", mf)
         for _, l := range mf.Layers {
@@ -320,7 +343,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
             layers = append(layers, layer)
         case "template", "system", "prompt":
             fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
-            // remove the prompt layer if one exists
+            // remove the layer if one exists
             mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
             layers = removeLayerFromLayers(layers, mediaType)
 
@@ -382,7 +405,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error {
 
     // Create a layer for the config object
     fn(api.ProgressResponse{Status: "creating config layer"})
-    cfg, err := createConfigLayer(digests)
+    cfg, err := createConfigLayer(config, digests)
     if err != nil {
         return err
     }
@@ -429,13 +452,13 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
     }
 
     e.opts.EmbeddingOnly = true
-    llm, err := llama.New(e.model, e.opts)
+    llmModel, err := llm.New(e.model, e.opts)
     if err != nil {
         return nil, fmt.Errorf("load model to generate embeddings: %v", err)
     }
     defer func() {
-        if llm != nil {
-            llm.Close()
+        if llmModel != nil {
+            llmModel.Close()
         }
     }()
 
@@ -479,7 +502,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
             Total:     len(data) - 1,
             Completed: i,
         })
-        embed, err := llm.Embedding(d)
+        embed, err := llmModel.Embedding(d)
         if err != nil {
             log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
             continue
@@ -675,7 +698,7 @@ func getLayerDigests(layers []*LayerReader) ([]string, error) {
 // CreateLayer creates a Layer object from a given file
 func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
     digest, size := GetSHA256Digest(f)
-    f.Seek(0, 0)
+    f.Seek(0, io.SeekStart)
 
     layer := &LayerReader{
         Layer: Layer{
@@ -767,10 +790,6 @@ func DeleteModel(name string) error {
         return err
     }
 
-    if err != nil {
-        return err
-    }
-
     // only delete the files which are still in the deleteMap
     for k, v := range deleteMap {
         if v {
@@ -969,15 +988,10 @@ func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
     return m, err
 }
 
-func createConfigLayer(layers []string) (*LayerReader, error) {
-    // TODO change architecture and OS
-    config := ConfigV2{
-        Architecture: "arm64",
-        OS:           "linux",
-        RootFS: RootFS{
+func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
+    config.RootFS = RootFS{
         Type:    "layers",
         DiffIDs: layers,
-        },
     }
 
     configJSON, err := json.Marshal(config)
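Taken together with the ConfigV2 changes earlier in this file, the config blob that createConfigLayer marshals now records the decoded model metadata next to the spec-required fields. A standalone sketch using mirror types (the rootfs tag spellings are assumptions; only the model_* tags appear in this diff):

```go
package main

import (
    "encoding/json"
    "fmt"
    "log"
)

// Mirrors of ConfigV2/RootFS for illustration; the rootfs tag spellings
// below are assumptions, only the model_* tags appear in the diff.
type rootFS struct {
    Type    string   `json:"type"`
    DiffIDs []string `json:"diff_ids"`
}

type configV2 struct {
    ModelFamily  string `json:"model_family"`
    ModelType    uint32 `json:"model_type"`
    FileType     uint32 `json:"file_type"`
    RootFS       rootFS `json:"rootfs"`
    Architecture string `json:"architecture"`
    OS           string `json:"os"`
}

func main() {
    cfg := configV2{
        ModelFamily:  "llama",
        ModelType:    32, // NumLayer for a 7B model
        FileType:     2,  // Q4_0
        Architecture: "amd64",
        OS:           "linux",
        RootFS:       rootFS{Type: "layers", DiffIDs: []string{"sha256:..."}},
    }

    b, err := json.Marshal(cfg)
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(string(b))
}
```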
@@ -21,14 +21,14 @@ import (
     "gonum.org/v1/gonum/mat"
 
     "github.com/jmorganca/ollama/api"
-    "github.com/jmorganca/ollama/llama"
+    "github.com/jmorganca/ollama/llm"
     "github.com/jmorganca/ollama/vector"
 )
 
 var loaded struct {
     mu sync.Mutex
 
-    llm *llama.LLM
+    llm llm.LLM
     Embeddings []vector.Embedding
 
     expireAt time.Time
@@ -63,11 +63,16 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
         loaded.Embeddings = model.Embeddings
     }
 
-    llm, err := llama.New(model.ModelPath, opts)
+    llmModel, err := llm.New(model.ModelPath, opts)
     if err != nil {
         return err
     }
 
+    // set cache values before modifying opts
+    loaded.llm = llmModel
+    loaded.digest = model.Digest
+    loaded.options = opts
+
     if opts.NumKeep < 0 {
         promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "")
         if err != nil {
@@ -79,15 +84,13 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
            return err
        }
 
-       tokensWithSystem := llm.Encode(promptWithSystem)
-       tokensNoSystem := llm.Encode(promptNoSystem)
+       tokensWithSystem := llmModel.Encode(promptWithSystem)
+       tokensNoSystem := llmModel.Encode(promptNoSystem)
 
-       llm.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
+       opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
+
+       llmModel.SetOptions(opts)
     }
 
-    loaded.llm = llm
-    loaded.digest = model.Digest
-    loaded.options = opts
 }
 loaded.expireAt = time.Now().Add(sessionDuration)