Merge pull request #167 from jmorganca/decode-ggml
partial decode ggml bin for more info
This commit is contained in:
commit
6a6828bddf
26 changed files with 336 additions and 69 deletions
180
llm/ggml.go
Normal file
180
llm/ggml.go
Normal file
|
@ -0,0 +1,180 @@
|
||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelFamily names the family of models a file belongs to. DecodeGGML uses
// it to pick the hyperparameter layout to read.
type ModelFamily string

const (
	// ModelFamilyLlama is the family name for llama models.
	ModelFamilyLlama ModelFamily = "llama"
)
|
||||||
|
|
||||||
|
// ModelType classifies a model by size. DecodeGGML derives it directly from
// the layer count stored in the hyperparameters, so each constant's value is
// the number of layers associated with that size.
type ModelType uint32

const (
	// ModelType3B is reported for a 26-layer model.
	ModelType3B ModelType = 26
	// ModelType7B is reported for a 32-layer model.
	ModelType7B ModelType = 32
	// ModelType13B is reported for a 40-layer model.
	ModelType13B ModelType = 40
	// ModelType30B is reported for a 60-layer model.
	ModelType30B ModelType = 60
	// ModelType65B is reported for an 80-layer model.
	ModelType65B ModelType = 80
)
|
||||||
|
|
||||||
|
// FileType identifies the quantization format recorded in a model file's
// hyperparameters. The numeric values mirror the file type enumeration used
// by llama.cpp.
type FileType uint32

const (
	FileTypeF32 FileType = iota
	FileTypeF16
	FileTypeQ4_0
	FileTypeQ4_1
	FileTypeQ4_1_F16

	// Values 5 and 6 belonged to formats that llama.cpp has retired, so the
	// sequence resumes at 7. The explicit FileType keeps the remaining
	// constants typed; without it they silently become untyped integers.
	FileTypeQ8_0 FileType = iota + 2
	FileTypeQ5_0
	FileTypeQ5_1
	FileTypeQ2_K
	FileTypeQ3_K
	// NOTE(review): llama.cpp splits the Q3_K/Q4_K/Q5_K formats into
	// small/medium/large variants, so the values from FileTypeQ4_K onward
	// probably do not line up with upstream — confirm against llama.h.
	FileTypeQ4_K
	FileTypeQ5_K
	FileTypeQ6_K

	// FileTypeUnknown is an untyped sentinel; -1 is not representable as a
	// FileType (uint32), so it cannot be compared with a decoded value as is.
	FileTypeUnknown = -1
)
|
||||||
|
|
||||||
|
type GGML struct {
|
||||||
|
ModelFamily
|
||||||
|
ModelType
|
||||||
|
|
||||||
|
magic uint32
|
||||||
|
container
|
||||||
|
|
||||||
|
llamaHyperparameters
|
||||||
|
}
|
||||||
|
|
||||||
|
// container abstracts the file container formats that DecodeGGML selects
// between based on the file magic.
type container interface {
	// Name returns the short name of the container format.
	Name() string

	// Decode consumes the container-specific data that follows the file
	// magic from the reader and reports whether it is acceptable.
	Decode(io.Reader) error
}
|
||||||
|
|
||||||
|
// containerGGML is the unversioned "ggml" container. It carries no state.
type containerGGML struct{}

// Name returns the container name, "ggml".
func (*containerGGML) Name() string {
	return "ggml"
}

// Decode always succeeds and reads nothing from the reader.
func (*containerGGML) Decode(_ io.Reader) error {
	return nil
}
|
||||||
|
|
||||||
|
// containerGGMF is the versioned "ggmf" container.
type containerGGMF struct {
	// version is the container version accepted by Decode.
	version uint32
}

// Name returns the container name, "ggmf".
func (c *containerGGMF) Name() string {
	return "ggmf"
}

// Decode reads a little-endian uint32 version from r and stores it on the
// container. Only version 1 is accepted.
//
// A failed or short read is returned as a wrapped error instead of being
// misreported as an invalid version. c.version is only updated on success.
func (c *containerGGMF) Decode(r io.Reader) error {
	var version uint32
	if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
		return fmt.Errorf("read ggmf version: %w", err)
	}

	if version != 1 {
		return errors.New("invalid version")
	}

	c.version = version
	return nil
}
|
||||||
|
|
||||||
|
// containerGGJT is the versioned "ggjt" container.
type containerGGJT struct {
	// version is the container version accepted by Decode.
	version uint32
}

// Name returns the container name, "ggjt".
func (c *containerGGJT) Name() string {
	return "ggjt"
}

// Decode reads a little-endian uint32 version from r and stores it on the
// container. Versions 1, 2 and 3 are accepted.
//
// A failed or short read is returned as a wrapped error instead of being
// misreported as an invalid version. c.version is only updated on success.
func (c *containerGGJT) Decode(r io.Reader) error {
	var version uint32
	if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
		return fmt.Errorf("read ggjt version: %w", err)
	}

	switch version {
	case 1, 2, 3:
	default:
		return errors.New("invalid version")
	}

	c.version = version
	return nil
}
|
||||||
|
|
||||||
|
// containerLORA is the versioned "ggla" container used for LoRA adapters.
type containerLORA struct {
	// version is the container version accepted by Decode.
	version uint32
}

// Name returns the container name, "ggla".
func (c *containerLORA) Name() string {
	return "ggla"
}

// Decode reads a little-endian uint32 version from r and stores it on the
// container. Only version 1 is accepted.
//
// A failed or short read is returned as a wrapped error instead of being
// misreported as an invalid version. c.version is only updated on success.
func (c *containerLORA) Decode(r io.Reader) error {
	var version uint32
	if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
		return fmt.Errorf("read ggla version: %w", err)
	}

	if version != 1 {
		return errors.New("invalid version")
	}

	c.version = version
	return nil
}
|
||||||
|
|
||||||
|
// File magic values. The hex digits of each constant spell the container
// name in ASCII (for example 0x67 0x67 0x6d 0x6c is "ggml"). DecodeGGML
// compares them against the first uint32 of the file, read little-endian.
const (
	// FILE_MAGIC_GGML is the magic for unversioned `ggml` files.
	FILE_MAGIC_GGML = 0x67676d6c

	// FILE_MAGIC_GGMF is the magic for versioned `ggml` files (ggmf).
	FILE_MAGIC_GGMF = 0x67676d66

	// FILE_MAGIC_GGJT is the magic for versioned `ggml` files (ggjt).
	FILE_MAGIC_GGJT = 0x67676a74

	// FILE_MAGIC_GGLA is the magic for `ggla` files (LoRA adapter).
	FILE_MAGIC_GGLA = 0x67676c61
)
|
||||||
|
|
||||||
|
func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
|
||||||
|
var ggml GGML
|
||||||
|
binary.Read(r, binary.LittleEndian, &ggml.magic)
|
||||||
|
|
||||||
|
switch ggml.magic {
|
||||||
|
case FILE_MAGIC_GGML:
|
||||||
|
ggml.container = &containerGGML{}
|
||||||
|
case FILE_MAGIC_GGMF:
|
||||||
|
ggml.container = &containerGGMF{}
|
||||||
|
case FILE_MAGIC_GGJT:
|
||||||
|
ggml.container = &containerGGJT{}
|
||||||
|
case FILE_MAGIC_GGLA:
|
||||||
|
ggml.container = &containerLORA{}
|
||||||
|
default:
|
||||||
|
return nil, errors.New("invalid file magic")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ggml.Decode(r); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// different model types may have different layouts for hyperparameters
|
||||||
|
switch hint {
|
||||||
|
case ModelFamilyLlama:
|
||||||
|
binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters)
|
||||||
|
// TODO: sanity check hyperparameters
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported model type: %s", hint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// final model type
|
||||||
|
ggml.ModelFamily = hint
|
||||||
|
ggml.ModelType = ModelType(ggml.NumLayer)
|
||||||
|
return &ggml, nil
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package llama
|
package llm
|
||||||
|
|
||||||
/*
|
/*
|
||||||
#cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
|
#cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
|
||||||
|
@ -105,7 +105,7 @@ import (
|
||||||
//go:embed ggml-metal.metal
|
//go:embed ggml-metal.metal
|
||||||
var fs embed.FS
|
var fs embed.FS
|
||||||
|
|
||||||
type LLM struct {
|
type llama struct {
|
||||||
params *C.struct_llama_context_params
|
params *C.struct_llama_context_params
|
||||||
model *C.struct_llama_model
|
model *C.struct_llama_model
|
||||||
ctx *C.struct_llama_context
|
ctx *C.struct_llama_context
|
||||||
|
@ -120,12 +120,28 @@ type LLM struct {
|
||||||
api.Options
|
api.Options
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(model string, opts api.Options) (*LLM, error) {
|
type llamaHyperparameters struct {
|
||||||
|
// NumVocab is the size of the model's vocabulary.
|
||||||
|
NumVocab uint32
|
||||||
|
|
||||||
|
// NumEmbd is the size of the model's embedding layer.
|
||||||
|
NumEmbd uint32
|
||||||
|
NumMult uint32
|
||||||
|
NumHead uint32
|
||||||
|
|
||||||
|
// NumLayer is the number of layers in the model.
|
||||||
|
NumLayer uint32
|
||||||
|
NumRot uint32
|
||||||
|
// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
|
||||||
|
FileType
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLlama(model string, opts api.Options) (*llama, error) {
|
||||||
if _, err := os.Stat(model); err != nil {
|
if _, err := os.Stat(model); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
llm := LLM{Options: opts}
|
llm := llama{Options: opts}
|
||||||
|
|
||||||
C.llama_backend_init(C.bool(llm.UseNUMA))
|
C.llama_backend_init(C.bool(llm.UseNUMA))
|
||||||
|
|
||||||
|
@ -168,7 +184,7 @@ func New(model string, opts api.Options) (*LLM, error) {
|
||||||
return &llm, nil
|
return &llm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *LLM) Close() {
|
func (llm *llama) Close() {
|
||||||
llm.gc = true
|
llm.gc = true
|
||||||
|
|
||||||
llm.mu.Lock()
|
llm.mu.Lock()
|
||||||
|
@ -180,17 +196,16 @@ func (llm *LLM) Close() {
|
||||||
C.llama_print_timings(llm.ctx)
|
C.llama_print_timings(llm.ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (llm *llama) SetOptions(opts api.Options) {
|
||||||
|
llm.Options = opts
|
||||||
|
}
|
||||||
|
|
||||||
var errNeedMoreData = errors.New("need more data")
|
var errNeedMoreData = errors.New("need more data")
|
||||||
|
|
||||||
func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
|
func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
|
||||||
C.llama_reset_timings(llm.ctx)
|
C.llama_reset_timings(llm.ctx)
|
||||||
|
|
||||||
tokens := make([]C.llama_token, len(ctx))
|
llm.marshalPrompt(ctx, prompt)
|
||||||
for i := range tokens {
|
|
||||||
tokens[i] = C.llama_token(ctx[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
llm.marshalPrompt(tokens, prompt)
|
|
||||||
|
|
||||||
C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
|
C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
|
||||||
|
|
||||||
|
@ -205,7 +220,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
b.WriteString(llm.Decode(token))
|
b.WriteString(llm.Decode(int(token)))
|
||||||
|
|
||||||
if err := llm.checkStopConditions(b); err != nil {
|
if err := llm.checkStopConditions(b); err != nil {
|
||||||
if errors.Is(err, io.EOF) {
|
if errors.Is(err, io.EOF) {
|
||||||
|
@ -243,7 +258,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
|
func (llm *llama) checkStopConditions(b bytes.Buffer) error {
|
||||||
for _, stopCondition := range llm.Stop {
|
for _, stopCondition := range llm.Stop {
|
||||||
if stopCondition == strings.TrimSpace(b.String()) {
|
if stopCondition == strings.TrimSpace(b.String()) {
|
||||||
return io.EOF
|
return io.EOF
|
||||||
|
@ -255,12 +270,17 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
|
func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {
|
||||||
tokens := append(ctx, llm.Encode(prompt)...)
|
tokens := append(ctx, llm.Encode(prompt)...)
|
||||||
if llm.NumKeep < 0 {
|
if llm.NumKeep < 0 {
|
||||||
llm.NumKeep = len(tokens)
|
llm.NumKeep = len(tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cTokens := make([]C.llama_token, len(tokens))
|
||||||
|
for i := range tokens {
|
||||||
|
cTokens[i] = C.llama_token(tokens[i])
|
||||||
|
}
|
||||||
|
|
||||||
// min(llm.NumCtx - 4, llm.NumKeep)
|
// min(llm.NumCtx - 4, llm.NumKeep)
|
||||||
if llm.NumCtx-4 < llm.NumKeep {
|
if llm.NumCtx-4 < llm.NumKeep {
|
||||||
llm.NumKeep = llm.NumCtx - 4
|
llm.NumKeep = llm.NumCtx - 4
|
||||||
|
@ -269,25 +289,25 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
|
||||||
if len(tokens) >= llm.NumCtx {
|
if len(tokens) >= llm.NumCtx {
|
||||||
// truncate input
|
// truncate input
|
||||||
numLeft := (llm.NumCtx - llm.NumKeep) / 2
|
numLeft := (llm.NumCtx - llm.NumKeep) / 2
|
||||||
truncated := tokens[:llm.NumKeep]
|
truncated := cTokens[:llm.NumKeep]
|
||||||
erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
|
erasedBlocks := (len(cTokens) - llm.NumKeep - numLeft - 1) / numLeft
|
||||||
truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
|
truncated = append(truncated, cTokens[llm.NumKeep+erasedBlocks*numLeft:]...)
|
||||||
copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
|
copy(llm.last, cTokens[len(cTokens)-llm.NumCtx:])
|
||||||
|
|
||||||
tokens = truncated
|
cTokens = truncated
|
||||||
log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
|
log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
|
||||||
} else {
|
} else {
|
||||||
llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
|
llm.last = make([]C.llama_token, llm.NumCtx-len(cTokens))
|
||||||
llm.last = append(llm.last, tokens...)
|
llm.last = append(llm.last, cTokens...)
|
||||||
}
|
}
|
||||||
|
|
||||||
var i int
|
var i int
|
||||||
for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
|
for i = 0; i < len(llm.embd) && i < len(cTokens) && llm.embd[i] == cTokens[i]; i++ {
|
||||||
// noop
|
// noop
|
||||||
}
|
}
|
||||||
|
|
||||||
llm.embd = tokens
|
llm.embd = cTokens
|
||||||
if i == len(tokens) {
|
if i == len(cTokens) {
|
||||||
// evaluate at least one token to generate logits
|
// evaluate at least one token to generate logits
|
||||||
i--
|
i--
|
||||||
}
|
}
|
||||||
|
@ -295,31 +315,36 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
|
||||||
llm.cursor = i
|
llm.cursor = i
|
||||||
|
|
||||||
log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
|
log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
|
||||||
return tokens
|
return cTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *LLM) Encode(prompt string) []C.llama_token {
|
func (llm *llama) Encode(prompt string) []int {
|
||||||
cPrompt := C.CString(prompt)
|
cPrompt := C.CString(prompt)
|
||||||
defer C.free(unsafe.Pointer(cPrompt))
|
defer C.free(unsafe.Pointer(cPrompt))
|
||||||
|
|
||||||
tokens := make([]C.llama_token, len(prompt)+1)
|
cTokens := make([]C.llama_token, len(prompt)+1)
|
||||||
if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(tokens), C.int(len(tokens)), true); n > 0 {
|
if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(cTokens), C.int(len(cTokens)), true); n > 0 {
|
||||||
return tokens[:n]
|
tokens := make([]int, n)
|
||||||
|
for i := range cTokens[:n] {
|
||||||
|
tokens[i] = int(cTokens[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *LLM) Decode(tokens ...C.llama_token) string {
|
func (llm *llama) Decode(tokens ...int) string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, token := range tokens {
|
for _, token := range tokens {
|
||||||
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
|
sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, C.llama_token(token))))
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *LLM) next() (C.llama_token, error) {
|
func (llm *llama) next() (C.llama_token, error) {
|
||||||
llm.mu.Lock()
|
llm.mu.Lock()
|
||||||
defer llm.mu.Unlock()
|
defer llm.mu.Unlock()
|
||||||
|
|
||||||
|
@ -410,7 +435,7 @@ func (llm *LLM) next() (C.llama_token, error) {
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm *LLM) Embedding(input string) ([]float64, error) {
|
func (llm *llama) Embedding(input string) ([]float64, error) {
|
||||||
if !llm.EmbeddingOnly {
|
if !llm.EmbeddingOnly {
|
||||||
return nil, errors.New("llama: embedding not enabled")
|
return nil, errors.New("llama: embedding not enabled")
|
||||||
}
|
}
|
||||||
|
@ -420,7 +445,12 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
|
||||||
return nil, errors.New("llama: tokenize embedding")
|
return nil, errors.New("llama: tokenize embedding")
|
||||||
}
|
}
|
||||||
|
|
||||||
retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
|
cTokens := make([]C.llama_token, len(tokens))
|
||||||
|
for i := range tokens {
|
||||||
|
cTokens[i] = C.llama_token(tokens[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
retval := C.llama_eval(llm.ctx, unsafe.SliceData(cTokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
|
||||||
if retval != 0 {
|
if retval != 0 {
|
||||||
return nil, errors.New("llama: eval")
|
return nil, errors.New("llama: eval")
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package llama
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
40
llm/llm.go
Normal file
40
llm/llm.go
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/jmorganca/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LLM interface {
|
||||||
|
Predict([]int, string, func(api.GenerateResponse)) error
|
||||||
|
Embedding(string) ([]float64, error)
|
||||||
|
Encode(string) []int
|
||||||
|
Decode(...int) string
|
||||||
|
SetOptions(api.Options)
|
||||||
|
Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(model string, opts api.Options) (LLM, error) {
|
||||||
|
if _, err := os.Stat(model); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Open(model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml, err := DecodeGGML(f, ModelFamilyLlama)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ggml.ModelFamily {
|
||||||
|
case ModelFamilyLlama:
|
||||||
|
return newLlama(model, opts)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package llama
|
package llm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
|
@ -19,7 +19,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
"github.com/jmorganca/ollama/llama"
|
"github.com/jmorganca/ollama/llm"
|
||||||
"github.com/jmorganca/ollama/parser"
|
"github.com/jmorganca/ollama/parser"
|
||||||
"github.com/jmorganca/ollama/vector"
|
"github.com/jmorganca/ollama/vector"
|
||||||
)
|
)
|
||||||
|
@ -99,9 +99,14 @@ type LayerReader struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ConfigV2 struct {
|
type ConfigV2 struct {
|
||||||
|
ModelFamily llm.ModelFamily `json:"model_family"`
|
||||||
|
ModelType llm.ModelType `json:"model_type"`
|
||||||
|
FileType llm.FileType `json:"file_type"`
|
||||||
|
RootFS RootFS `json:"rootfs"`
|
||||||
|
|
||||||
|
// required by spec
|
||||||
Architecture string `json:"architecture"`
|
Architecture string `json:"architecture"`
|
||||||
OS string `json:"os"`
|
OS string `json:"os"`
|
||||||
RootFS RootFS `json:"rootfs"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type RootFS struct {
|
type RootFS struct {
|
||||||
|
@ -246,6 +251,11 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config := ConfigV2{
|
||||||
|
Architecture: "amd64",
|
||||||
|
OS: "linux",
|
||||||
|
}
|
||||||
|
|
||||||
var layers []*LayerReader
|
var layers []*LayerReader
|
||||||
params := make(map[string][]string)
|
params := make(map[string][]string)
|
||||||
embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
|
embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
|
||||||
|
@ -284,6 +294,18 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
|
ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config.ModelFamily = ggml.ModelFamily
|
||||||
|
config.ModelType = ggml.ModelType
|
||||||
|
config.FileType = ggml.FileType
|
||||||
|
|
||||||
|
// reset the file
|
||||||
|
file.Seek(0, io.SeekStart)
|
||||||
|
|
||||||
l, err := CreateLayer(file)
|
l, err := CreateLayer(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create layer: %v", err)
|
return fmt.Errorf("failed to create layer: %v", err)
|
||||||
|
@ -292,6 +314,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||||
layers = append(layers, l)
|
layers = append(layers, l)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if mf != nil {
|
if mf != nil {
|
||||||
log.Printf("manifest = %#v", mf)
|
log.Printf("manifest = %#v", mf)
|
||||||
for _, l := range mf.Layers {
|
for _, l := range mf.Layers {
|
||||||
|
@ -321,7 +344,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||||
layers = append(layers, layer)
|
layers = append(layers, layer)
|
||||||
case "template", "system", "prompt":
|
case "template", "system", "prompt":
|
||||||
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
|
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
|
||||||
// remove the prompt layer if one exists
|
// remove the layer if one exists
|
||||||
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
||||||
layers = removeLayerFromLayers(layers, mediaType)
|
layers = removeLayerFromLayers(layers, mediaType)
|
||||||
|
|
||||||
|
@ -383,7 +406,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||||
|
|
||||||
// Create a layer for the config object
|
// Create a layer for the config object
|
||||||
fn(api.ProgressResponse{Status: "creating config layer"})
|
fn(api.ProgressResponse{Status: "creating config layer"})
|
||||||
cfg, err := createConfigLayer(digests)
|
cfg, err := createConfigLayer(config, digests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -430,13 +453,13 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
e.opts.EmbeddingOnly = true
|
e.opts.EmbeddingOnly = true
|
||||||
llm, err := llama.New(e.model, e.opts)
|
llmModel, err := llm.New(e.model, e.opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
|
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if llm != nil {
|
if llmModel != nil {
|
||||||
llm.Close()
|
llmModel.Close()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -480,7 +503,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
|
||||||
Total: len(data) - 1,
|
Total: len(data) - 1,
|
||||||
Completed: i,
|
Completed: i,
|
||||||
})
|
})
|
||||||
embed, err := llm.Embedding(d)
|
embed, err := llmModel.Embedding(d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
|
log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
|
||||||
continue
|
continue
|
||||||
|
@ -676,7 +699,7 @@ func getLayerDigests(layers []*LayerReader) ([]string, error) {
|
||||||
// CreateLayer creates a Layer object from a given file
|
// CreateLayer creates a Layer object from a given file
|
||||||
func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
|
func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
|
||||||
digest, size := GetSHA256Digest(f)
|
digest, size := GetSHA256Digest(f)
|
||||||
f.Seek(0, 0)
|
f.Seek(0, io.SeekStart)
|
||||||
|
|
||||||
layer := &LayerReader{
|
layer := &LayerReader{
|
||||||
Layer: Layer{
|
Layer: Layer{
|
||||||
|
@ -768,10 +791,6 @@ func DeleteModel(name string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// only delete the files which are still in the deleteMap
|
// only delete the files which are still in the deleteMap
|
||||||
for k, v := range deleteMap {
|
for k, v := range deleteMap {
|
||||||
if v {
|
if v {
|
||||||
|
@ -970,15 +989,10 @@ func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, err
|
||||||
return m, err
|
return m, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func createConfigLayer(layers []string) (*LayerReader, error) {
|
func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
|
||||||
// TODO change architecture and OS
|
config.RootFS = RootFS{
|
||||||
config := ConfigV2{
|
Type: "layers",
|
||||||
Architecture: "arm64",
|
DiffIDs: layers,
|
||||||
OS: "linux",
|
|
||||||
RootFS: RootFS{
|
|
||||||
Type: "layers",
|
|
||||||
DiffIDs: layers,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
configJSON, err := json.Marshal(config)
|
configJSON, err := json.Marshal(config)
|
||||||
|
|
|
@ -21,14 +21,14 @@ import (
|
||||||
"gonum.org/v1/gonum/mat"
|
"gonum.org/v1/gonum/mat"
|
||||||
|
|
||||||
"github.com/jmorganca/ollama/api"
|
"github.com/jmorganca/ollama/api"
|
||||||
"github.com/jmorganca/ollama/llama"
|
"github.com/jmorganca/ollama/llm"
|
||||||
"github.com/jmorganca/ollama/vector"
|
"github.com/jmorganca/ollama/vector"
|
||||||
)
|
)
|
||||||
|
|
||||||
var loaded struct {
|
var loaded struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
llm *llama.LLM
|
llm llm.LLM
|
||||||
Embeddings []vector.Embedding
|
Embeddings []vector.Embedding
|
||||||
|
|
||||||
expireAt time.Time
|
expireAt time.Time
|
||||||
|
@ -63,11 +63,16 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
|
||||||
loaded.Embeddings = model.Embeddings
|
loaded.Embeddings = model.Embeddings
|
||||||
}
|
}
|
||||||
|
|
||||||
llm, err := llama.New(model.ModelPath, opts)
|
llmModel, err := llm.New(model.ModelPath, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set cache values before modifying opts
|
||||||
|
loaded.llm = llmModel
|
||||||
|
loaded.digest = model.Digest
|
||||||
|
loaded.options = opts
|
||||||
|
|
||||||
if opts.NumKeep < 0 {
|
if opts.NumKeep < 0 {
|
||||||
promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "")
|
promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -79,15 +84,13 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tokensWithSystem := llm.Encode(promptWithSystem)
|
tokensWithSystem := llmModel.Encode(promptWithSystem)
|
||||||
tokensNoSystem := llm.Encode(promptNoSystem)
|
tokensNoSystem := llmModel.Encode(promptNoSystem)
|
||||||
|
|
||||||
llm.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
|
opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
|
||||||
|
|
||||||
|
llmModel.SetOptions(opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
loaded.llm = llm
|
|
||||||
loaded.digest = model.Digest
|
|
||||||
loaded.options = opts
|
|
||||||
}
|
}
|
||||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue