implement loading ggml lora adapters through the modelfile
This commit is contained in:
parent
d791df75dd
commit
6de5d032e1
5 changed files with 65 additions and 13 deletions
17
llm/llama.go
17
llm/llama.go
|
@ -136,7 +136,7 @@ type llamaHyperparameters struct {
|
||||||
FileType
|
FileType
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLlama(model string, opts api.Options) (*llama, error) {
|
func newLlama(model string, adapters []string, opts api.Options) (*llama, error) {
|
||||||
if _, err := os.Stat(model); err != nil {
|
if _, err := os.Stat(model); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -161,6 +161,12 @@ func newLlama(model string, opts api.Options) (*llama, error) {
|
||||||
params.embedding = C.bool(llm.EmbeddingOnly)
|
params.embedding = C.bool(llm.EmbeddingOnly)
|
||||||
params.rope_freq_base = C.float(llm.RopeFrequencyBase)
|
params.rope_freq_base = C.float(llm.RopeFrequencyBase)
|
||||||
params.rope_freq_scale = C.float(llm.RopeFrequencyScale)
|
params.rope_freq_scale = C.float(llm.RopeFrequencyScale)
|
||||||
|
|
||||||
|
if len(adapters) > 0 && llm.UseMMap {
|
||||||
|
log.Printf("must disable mmap to use lora adapters")
|
||||||
|
params.use_mmap = C.bool(false)
|
||||||
|
}
|
||||||
|
|
||||||
llm.params = ¶ms
|
llm.params = ¶ms
|
||||||
|
|
||||||
cModel := C.CString(model)
|
cModel := C.CString(model)
|
||||||
|
@ -176,6 +182,15 @@ func newLlama(model string, opts api.Options) (*llama, error) {
|
||||||
return nil, errors.New("failed to create context")
|
return nil, errors.New("failed to create context")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, adapter := range adapters {
|
||||||
|
cAdapter := C.CString(adapter)
|
||||||
|
defer C.free(unsafe.Pointer(cAdapter))
|
||||||
|
|
||||||
|
if retval := C.llama_model_apply_lora_from_file(llm.model, cAdapter, nil, C.int(llm.NumThread)); retval != 0 {
|
||||||
|
return nil, fmt.Errorf("failed to load adapter %s", adapter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// warm up the model
|
// warm up the model
|
||||||
bos := []C.llama_token{C.llama_token_bos()}
|
bos := []C.llama_token{C.llama_token_bos()}
|
||||||
C.llama_eval(llm.ctx, unsafe.SliceData(bos), C.int(len(bos)), 0, C.int(opts.NumThread))
|
C.llama_eval(llm.ctx, unsafe.SliceData(bos), C.int(len(bos)), 0, C.int(opts.NumThread))
|
||||||
|
|
|
@ -19,7 +19,7 @@ type LLM interface {
|
||||||
Close()
|
Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(model string, opts api.Options) (LLM, error) {
|
func New(model string, adapters []string, opts api.Options) (LLM, error) {
|
||||||
if _, err := os.Stat(model); err != nil {
|
if _, err := os.Stat(model); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -66,7 +66,7 @@ func New(model string, opts api.Options) (LLM, error) {
|
||||||
|
|
||||||
switch ggml.ModelFamily {
|
switch ggml.ModelFamily {
|
||||||
case ModelFamilyLlama:
|
case ModelFamilyLlama:
|
||||||
return newLlama(model, opts)
|
return newLlama(model, adapters, opts)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
|
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
|
||||||
command.Args = string(fields[1])
|
command.Args = string(fields[1])
|
||||||
// copy command for validation
|
// copy command for validation
|
||||||
modelCommand = command
|
modelCommand = command
|
||||||
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED":
|
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED", "ADAPTER":
|
||||||
command.Name = string(bytes.ToLower(fields[0]))
|
command.Name = string(bytes.ToLower(fields[0]))
|
||||||
command.Args = string(fields[1])
|
command.Args = string(fields[1])
|
||||||
case "PARAMETER":
|
case "PARAMETER":
|
||||||
|
|
|
@ -33,6 +33,7 @@ type RegistryOptions struct {
|
||||||
type Model struct {
|
type Model struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
ModelPath string
|
ModelPath string
|
||||||
|
AdapterPaths []string
|
||||||
Template string
|
Template string
|
||||||
System string
|
System string
|
||||||
Digest string
|
Digest string
|
||||||
|
@ -178,6 +179,8 @@ func GetModel(name string) (*Model, error) {
|
||||||
if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
|
if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
case "application/vnd.ollama.image.adapter":
|
||||||
|
model.AdapterPaths = append(model.AdapterPaths, filename)
|
||||||
case "application/vnd.ollama.image.template":
|
case "application/vnd.ollama.image.template":
|
||||||
bts, err := os.ReadFile(filename)
|
bts, err := os.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -330,6 +333,40 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
embed.files = append(embed.files, embedFilePath)
|
embed.files = append(embed.files, embedFilePath)
|
||||||
|
case "adapter":
|
||||||
|
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
|
||||||
|
|
||||||
|
fp := c.Args
|
||||||
|
if strings.HasPrefix(fp, "~/") {
|
||||||
|
parts := strings.Split(fp, "/")
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fp = filepath.Join(home, filepath.Join(parts[1:]...))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If filePath is not an absolute path, make it relative to the modelfile path
|
||||||
|
if !filepath.IsAbs(fp) {
|
||||||
|
fp = filepath.Join(filepath.Dir(path), fp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a model from this specified file
|
||||||
|
fn(api.ProgressResponse{Status: "creating model layer"})
|
||||||
|
|
||||||
|
file, err := os.Open(fp)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open file: %v", err)
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
l, err := CreateLayer(file)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create layer: %v", err)
|
||||||
|
}
|
||||||
|
l.MediaType = "application/vnd.ollama.image.adapter"
|
||||||
|
layers = append(layers, l)
|
||||||
case "license":
|
case "license":
|
||||||
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
|
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
|
||||||
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
||||||
|
@ -452,7 +489,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
e.opts.EmbeddingOnly = true
|
e.opts.EmbeddingOnly = true
|
||||||
llmModel, err := llm.New(e.model, e.opts)
|
llmModel, err := llm.New(e.model, []string{}, e.opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
|
return nil, fmt.Errorf("load model to generate embeddings: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,7 +63,7 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
|
||||||
loaded.Embeddings = model.Embeddings
|
loaded.Embeddings = model.Embeddings
|
||||||
}
|
}
|
||||||
|
|
||||||
llmModel, err := llm.New(model.ModelPath, opts)
|
llmModel, err := llm.New(model.ModelPath, model.AdapterPaths, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue