add model IDs (#439)

This commit is contained in:
Patrick Devine 2023-08-28 20:50:24 -07:00 committed by GitHub
parent d3b838ce60
commit 8bbff2df98
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 19 deletions

View file

@ -96,6 +96,7 @@ type ListResponseModel struct {
Name string `json:"name"` Name string `json:"name"`
ModifiedAt time.Time `json:"modified_at"` ModifiedAt time.Time `json:"modified_at"`
Size int `json:"size"` Size int `json:"size"`
Digest string `json:"digest"`
} }
type TokenResponse struct { type TokenResponse struct {

View file

@ -196,12 +196,12 @@ func ListHandler(cmd *cobra.Command, args []string) error {
for _, m := range models.Models { for _, m := range models.Models {
if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) { if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
data = append(data, []string{m.Name, humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")}) data = append(data, []string{m.Name, m.Digest[:12], humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")})
} }
} }
table := tablewriter.NewWriter(os.Stdout) table := tablewriter.NewWriter(os.Stdout)
table.SetHeader([]string{"NAME", "SIZE", "MODIFIED"}) table.SetHeader([]string{"NAME", "ID", "SIZE", "MODIFIED"})
table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
table.SetAlignment(tablewriter.ALIGN_LEFT) table.SetAlignment(tablewriter.ALIGN_LEFT)
table.SetHeaderLine(false) table.SetHeaderLine(false)
@ -527,7 +527,7 @@ func generateInteractive(cmd *cobra.Command, model string) error {
return err return err
} }
manifest, err := server.GetManifest(mp) manifest, _, err := server.GetManifest(mp)
if err != nil { if err != nil {
fmt.Println("error: couldn't get a manifest for this model") fmt.Println("error: couldn't get a manifest for this model")
continue continue

View file

@ -5,6 +5,7 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -44,6 +45,7 @@ type Model struct {
Template string Template string
System string System string
Digest string Digest string
ConfigDigest string
Options map[string]interface{} Options map[string]interface{}
Embeddings []vector.Embedding Embeddings []vector.Embedding
} }
@ -131,40 +133,44 @@ func (m *ManifestV2) GetTotalSize() int {
return total return total
} }
func GetManifest(mp ModelPath) (*ManifestV2, error) { func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
fp, err := mp.GetManifestPath(false) fp, err := mp.GetManifestPath(false)
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
if _, err = os.Stat(fp); err != nil { if _, err = os.Stat(fp); err != nil {
return nil, err return nil, "", err
} }
var manifest *ManifestV2 var manifest *ManifestV2
bts, err := os.ReadFile(fp) bts, err := os.ReadFile(fp)
if err != nil { if err != nil {
return nil, fmt.Errorf("couldn't open file '%s'", fp) return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
} }
shaSum := sha256.Sum256(bts)
shaStr := hex.EncodeToString(shaSum[:])
if err := json.Unmarshal(bts, &manifest); err != nil { if err := json.Unmarshal(bts, &manifest); err != nil {
return nil, err return nil, "", err
} }
return manifest, nil return manifest, shaStr, nil
} }
func GetModel(name string) (*Model, error) { func GetModel(name string) (*Model, error) {
mp := ParseModelPath(name) mp := ParseModelPath(name)
manifest, err := GetManifest(mp) manifest, digest, err := GetManifest(mp)
if err != nil { if err != nil {
return nil, err return nil, err
} }
model := &Model{ model := &Model{
Name: mp.GetFullTagname(), Name: mp.GetFullTagname(),
Digest: manifest.Config.Digest, Digest: digest,
ConfigDigest: manifest.Config.Digest,
Template: "{{ .Prompt }}", Template: "{{ .Prompt }}",
} }
@ -277,7 +283,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
embed.model = c.Args embed.model = c.Args
mp := ParseModelPath(c.Args) mp := ParseModelPath(c.Args)
mf, err := GetManifest(mp) mf, _, err := GetManifest(mp)
if err != nil { if err != nil {
modelFile, err := filenameWithPath(path, c.Args) modelFile, err := filenameWithPath(path, c.Args)
if err != nil { if err != nil {
@ -290,7 +296,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil { if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
return err return err
} }
mf, err = GetManifest(mp) mf, _, err = GetManifest(mp)
if err != nil { if err != nil {
return fmt.Errorf("failed to open file after pull: %v", err) return fmt.Errorf("failed to open file after pull: %v", err)
} }
@ -839,7 +845,7 @@ func CopyModel(src, dest string) error {
func DeleteModel(name string) error { func DeleteModel(name string) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
manifest, err := GetManifest(mp) manifest, _, err := GetManifest(mp)
if err != nil { if err != nil {
return err return err
} }
@ -872,7 +878,7 @@ func DeleteModel(name string) error {
} }
// save (i.e. delete from the deleteMap) any files used in other manifests // save (i.e. delete from the deleteMap) any files used in other manifests
manifest, err := GetManifest(fmp) manifest, _, err := GetManifest(fmp)
if err != nil { if err != nil {
log.Printf("skipping file: %s", fp) log.Printf("skipping file: %s", fp)
return nil return nil
@ -924,7 +930,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return fmt.Errorf("insecure protocol http") return fmt.Errorf("insecure protocol http")
} }
manifest, err := GetManifest(mp) manifest, _, err := GetManifest(mp)
if err != nil { if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"}) fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err return err

View file

@ -373,7 +373,7 @@ func ListModelsHandler(c *gin.Context) {
tag := path[:slashIndex] + ":" + path[slashIndex+1:] tag := path[:slashIndex] + ":" + path[slashIndex+1:]
mp := ParseModelPath(tag) mp := ParseModelPath(tag)
manifest, err := GetManifest(mp) manifest, digest, err := GetManifest(mp)
if err != nil { if err != nil {
log.Printf("skipping file: %s", fp) log.Printf("skipping file: %s", fp)
return nil return nil
@ -381,6 +381,7 @@ func ListModelsHandler(c *gin.Context) {
model := api.ListResponseModel{ model := api.ListResponseModel{
Name: mp.GetShortTagname(), Name: mp.GetShortTagname(),
Size: manifest.GetTotalSize(), Size: manifest.GetTotalSize(),
Digest: digest,
ModifiedAt: fi.ModTime(), ModifiedAt: fi.ModTime(),
} }
models = append(models, model) models = append(models, model)