Merge pull request #1334 from jmorganca/mxyng/load-projectors

load projectors
This commit is contained in:
Michael Yang 2023-12-05 14:40:53 -08:00 committed by GitHub
commit 32f62fbb8e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 62 additions and 27 deletions

View file

@ -203,12 +203,22 @@ type GenerateResponse struct {
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Response string `json:"response"` Response string `json:"response"`
ModelConfiguration ModelConfiguration `json:"model_configuration"`
Done bool `json:"done"` Done bool `json:"done"`
Context []int `json:"context,omitempty"` Context []int `json:"context,omitempty"`
Metrics Metrics
} }
type ModelConfiguration struct {
ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"`
ModelFamilies []string `json:"model_families"`
ModelType string `json:"model_type"`
FileType string `json:"file_type"`
}
func (m *Metrics) Summary() { func (m *Metrics) Summary() {
if m.TotalDuration > 0 { if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration) fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)

View file

@ -325,7 +325,7 @@ func (w *StatusWriter) Write(b []byte) (int, error) {
return os.Stderr.Write(b) return os.Stderr.Write(b)
} }
func newLlama(model string, adapters []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) { func newLlama(model string, adapters, projectors []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
fileInfo, err := os.Stat(model) fileInfo, err := os.Stat(model)
if err != nil { if err != nil {
return nil, err return nil, err
@ -365,6 +365,11 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
params = append(params, "--lora", adapters[0]) params = append(params, "--lora", adapters[0])
} }
if len(projectors) > 0 {
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
params = append(params, "--mmproj", projectors[0])
}
if opts.NumThread > 0 { if opts.NumThread > 0 {
params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread)) params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
} }

View file

@ -23,7 +23,7 @@ type LLM interface {
Ping(context.Context) error Ping(context.Context) error
} }
func New(workDir, model string, adapters []string, opts api.Options) (LLM, error) { func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
if _, err := os.Stat(model); err != nil { if _, err := os.Stat(model); err != nil {
return nil, err return nil, err
} }
@ -82,9 +82,9 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
opts.NumGQA = 0 opts.NumGQA = 0
opts.RopeFrequencyBase = 0.0 opts.RopeFrequencyBase = 0.0
opts.RopeFrequencyScale = 0.0 opts.RopeFrequencyScale = 0.0
return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts) return newLlama(model, adapters, projectors, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
case "ggml", "ggmf", "ggjt", "ggla": case "ggml", "ggmf", "ggjt", "ggla":
return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts) return newLlama(model, adapters, projectors, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
default: default:
return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily()) return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
} }

View file

@ -35,16 +35,18 @@ type RegistryOptions struct {
} }
type Model struct { type Model struct {
Name string `json:"name"` Name string `json:"name"`
ShortName string Config ConfigV2
ModelPath string ShortName string
OriginalModel string ModelPath string
AdapterPaths []string OriginalModel string
Template string AdapterPaths []string
System string ProjectorPaths []string
License []string Template string
Digest string System string
Options map[string]interface{} License []string
Digest string
Options map[string]interface{}
} }
type PromptVars struct { type PromptVars struct {
@ -136,16 +138,12 @@ type ManifestV2 struct {
} }
type ConfigV2 struct { type ConfigV2 struct {
ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"`
ModelFamilies []string `json:"model_families"`
ModelType string `json:"model_type"`
FileType string `json:"file_type"`
RootFS RootFS `json:"rootfs"`
// required by spec // required by spec
Architecture string `json:"architecture"` Architecture string `json:"architecture"`
OS string `json:"os"` OS string `json:"os"`
RootFS RootFS `json:"rootfs"`
api.ModelConfiguration
} }
func (c *ConfigV2) SetModelFormat(format string) { func (c *ConfigV2) SetModelFormat(format string) {
@ -234,6 +232,21 @@ func GetModel(name string) (*Model, error) {
License: []string{}, License: []string{},
} }
filename, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
return nil, err
}
configFile, err := os.Open(filename)
if err != nil {
return nil, err
}
defer configFile.Close()
if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
return nil, err
}
for _, layer := range manifest.Layers { for _, layer := range manifest.Layers {
filename, err := GetBlobsPath(layer.Digest) filename, err := GetBlobsPath(layer.Digest)
if err != nil { if err != nil {
@ -250,6 +263,8 @@ func GetModel(name string) (*Model, error) {
log.Print("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.") log.Print("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter": case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename) model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
model.ProjectorPaths = append(model.ProjectorPaths, filename)
case "application/vnd.ollama.image.template": case "application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {

View file

@ -105,7 +105,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
loaded.Options = nil loaded.Options = nil
} }
llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts) llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
if err != nil { if err != nil {
// some older models are not compatible with newer versions of llama.cpp // some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to // show a generalized compatibility error until there is a better way to
@ -198,7 +198,11 @@ func GenerateHandler(c *gin.Context) {
// an empty request loads the model // an empty request loads the model
if req.Prompt == "" && req.Template == "" && req.System == "" { if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true}) c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
ModelConfiguration: model.Config.ModelConfiguration,
Done: true})
return return
} }
@ -257,10 +261,11 @@ func GenerateHandler(c *gin.Context) {
} }
resp := api.GenerateResponse{ resp := api.GenerateResponse{
Model: r.Model, Model: r.Model,
CreatedAt: r.CreatedAt, ModelConfiguration: model.Config.ModelConfiguration,
Done: r.Done, CreatedAt: r.CreatedAt,
Response: r.Content, Done: r.Done,
Response: r.Content,
Metrics: api.Metrics{ Metrics: api.Metrics{
TotalDuration: r.TotalDuration, TotalDuration: r.TotalDuration,
LoadDuration: r.LoadDuration, LoadDuration: r.LoadDuration,