Merge pull request #290 from jmorganca/add-adapter-layers

implement loading ggml lora adapters through the modelfile
This commit is contained in:
Michael Yang 2023-08-10 17:23:01 -07:00 committed by GitHub
commit 6517bcc53c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 75 additions and 14 deletions

View file

@ -35,6 +35,7 @@ INSTRUCTION arguments
| [`PARAMETER`](#parameter) | Sets the parameters for how Ollama will run the model. | | [`PARAMETER`](#parameter) | Sets the parameters for how Ollama will run the model. |
| [`TEMPLATE`](#template) | The full prompt template to be sent to the model. | | [`TEMPLATE`](#template) | The full prompt template to be sent to the model. |
| [`SYSTEM`](#system) | Specifies the system prompt that will be set in the template. | | [`SYSTEM`](#system) | Specifies the system prompt that will be set in the template. |
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
| [`LICENSE`](#license) | Specifies the legal license. | | [`LICENSE`](#license) | Specifies the legal license. |
## Examples ## Examples
@ -150,6 +151,14 @@ The `SYSTEM` instruction specifies the system prompt to be used in the template,
SYSTEM """<system message>""" SYSTEM """<system message>"""
``` ```
### ADAPTER
The `ADAPTER` instruction specifies the LoRA adapter to apply to the base model. The value of this instruction should be an absolute path or a path relative to the Modelfile and the file must be in a GGML file format. The adapter should be tuned from the base model otherwise the behaviour is undefined.
```
ADAPTER ./ollama-lora.bin
```
### LICENSE ### LICENSE
The `LICENSE` instruction allows you to specify the legal license under which the model used with this Modelfile is shared or distributed. The `LICENSE` instruction allows you to specify the legal license under which the model used with this Modelfile is shared or distributed.

View file

@ -136,7 +136,7 @@ type llamaHyperparameters struct {
FileType FileType
} }
func newLlama(model string, opts api.Options) (*llama, error) { func newLlama(model string, adapters []string, opts api.Options) (*llama, error) {
if _, err := os.Stat(model); err != nil { if _, err := os.Stat(model); err != nil {
return nil, err return nil, err
} }
@ -161,6 +161,12 @@ func newLlama(model string, opts api.Options) (*llama, error) {
params.embedding = C.bool(llm.EmbeddingOnly) params.embedding = C.bool(llm.EmbeddingOnly)
params.rope_freq_base = C.float(llm.RopeFrequencyBase) params.rope_freq_base = C.float(llm.RopeFrequencyBase)
params.rope_freq_scale = C.float(llm.RopeFrequencyScale) params.rope_freq_scale = C.float(llm.RopeFrequencyScale)
if len(adapters) > 0 && llm.UseMMap {
log.Printf("must disable mmap to use lora adapters")
params.use_mmap = C.bool(false)
}
llm.params = &params llm.params = &params
cModel := C.CString(model) cModel := C.CString(model)
@ -176,6 +182,15 @@ func newLlama(model string, opts api.Options) (*llama, error) {
return nil, errors.New("failed to create context") return nil, errors.New("failed to create context")
} }
for _, adapter := range adapters {
cAdapter := C.CString(adapter)
defer C.free(unsafe.Pointer(cAdapter))
if retval := C.llama_model_apply_lora_from_file(llm.model, cAdapter, nil, C.int(llm.NumThread)); retval != 0 {
return nil, fmt.Errorf("failed to load adapter %s", adapter)
}
}
// warm up the model // warm up the model
bos := []C.llama_token{C.llama_token_bos()} bos := []C.llama_token{C.llama_token_bos()}
C.llama_eval(llm.ctx, unsafe.SliceData(bos), C.int(len(bos)), 0, C.int(opts.NumThread)) C.llama_eval(llm.ctx, unsafe.SliceData(bos), C.int(len(bos)), 0, C.int(opts.NumThread))

View file

@ -19,7 +19,7 @@ type LLM interface {
Close() Close()
} }
func New(model string, opts api.Options) (LLM, error) { func New(model string, adapters []string, opts api.Options) (LLM, error) {
if _, err := os.Stat(model); err != nil { if _, err := os.Stat(model); err != nil {
return nil, err return nil, err
} }
@ -66,7 +66,7 @@ func New(model string, opts api.Options) (LLM, error) {
switch ggml.ModelFamily { switch ggml.ModelFamily {
case ModelFamilyLlama: case ModelFamilyLlama:
return newLlama(model, opts) return newLlama(model, adapters, opts)
default: default:
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily) return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
} }

View file

@ -40,7 +40,7 @@ func Parse(reader io.Reader) ([]Command, error) {
command.Args = string(fields[1]) command.Args = string(fields[1])
// copy command for validation // copy command for validation
modelCommand = command modelCommand = command
case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED": case "LICENSE", "TEMPLATE", "SYSTEM", "PROMPT", "EMBED", "ADAPTER":
command.Name = string(bytes.ToLower(fields[0])) command.Name = string(bytes.ToLower(fields[0]))
command.Args = string(fields[1]) command.Args = string(fields[1])
case "PARAMETER": case "PARAMETER":

View file

@ -32,13 +32,14 @@ type RegistryOptions struct {
} }
type Model struct { type Model struct {
Name string `json:"name"` Name string `json:"name"`
ModelPath string ModelPath string
Template string AdapterPaths []string
System string Template string
Digest string System string
Options map[string]interface{} Digest string
Embeddings []vector.Embedding Options map[string]interface{}
Embeddings []vector.Embedding
} }
func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) { func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
@ -179,6 +180,8 @@ func GetModel(name string) (*Model, error) {
if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil { if err = json.NewDecoder(file).Decode(&model.Embeddings); err != nil {
return nil, err return nil, err
} }
case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.template": case "application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
@ -331,6 +334,40 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
return err return err
} }
embed.files = append(embed.files, embedFilePath) embed.files = append(embed.files, embedFilePath)
case "adapter":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
fp := c.Args
if strings.HasPrefix(fp, "~/") {
parts := strings.Split(fp, "/")
home, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
fp = filepath.Join(home, filepath.Join(parts[1:]...))
}
// If filePath is not an absolute path, make it relative to the modelfile path
if !filepath.IsAbs(fp) {
fp = filepath.Join(filepath.Dir(path), fp)
}
// create a model from this specified file
fn(api.ProgressResponse{Status: "creating model layer"})
file, err := os.Open(fp)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
l, err := CreateLayer(file)
if err != nil {
return fmt.Errorf("failed to create layer: %v", err)
}
l.MediaType = "application/vnd.ollama.image.adapter"
layers = append(layers, l)
case "license": case "license":
fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)}) fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
@ -453,7 +490,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
} }
e.opts.EmbeddingOnly = true e.opts.EmbeddingOnly = true
llmModel, err := llm.New(e.model, e.opts) llmModel, err := llm.New(e.model, []string{}, e.opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("load model to generate embeddings: %v", err) return nil, fmt.Errorf("load model to generate embeddings: %v", err)
} }

View file

@ -63,7 +63,7 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
loaded.Embeddings = model.Embeddings loaded.Embeddings = model.Embeddings
} }
llmModel, err := llm.New(model.ModelPath, opts) llmModel, err := llm.New(model.ModelPath, model.AdapterPaths, opts)
if err != nil { if err != nil {
return err return err
} }