From a30915bde166b2f392a0ff72c61c9ac53189a962 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 11 Jun 2024 14:03:42 -0700 Subject: [PATCH] add capabilities --- server/images.go | 20 ++++++++++++++++++-- server/routes.go | 8 ++++---- template/template_test.go | 8 ++++---- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/server/images.go b/server/images.go index 65ed51c7..5cd0a7a5 100644 --- a/server/images.go +++ b/server/images.go @@ -34,6 +34,10 @@ import ( "github.com/ollama/ollama/version" ) +type Capability string + +const CapabilityCompletion = Capability("completion") + type registryOptions struct { Insecure bool Username string @@ -58,8 +62,20 @@ type Model struct { Template *template.Template } -func (m *Model) IsEmbedding() bool { - return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") +func (m *Model) Has(caps ...Capability) bool { + for _, cap := range caps { + switch cap { + case CapabilityCompletion: + if slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert") { + return false + } + default: + slog.Error("unknown capability", "capability", cap) + return false + } + } + + return true } func (m *Model) String() string { diff --git a/server/routes.go b/server/routes.go index d8a4a67e..8ca6dcc8 100644 --- a/server/routes.go +++ b/server/routes.go @@ -122,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - if model.IsEmbedding() { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"}) + if !model.Has(CapabilityCompletion) { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)}) return } @@ -1308,8 +1308,8 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - if model.IsEmbedding() { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"}) + if !model.Has(CapabilityCompletion) { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)}) return } diff --git a/template/template_test.go b/template/template_test.go index e5405bdb..eda4634f 100644 --- a/template/template_test.go +++ b/template/template_test.go @@ -61,8 +61,8 @@ func TestNamed(t *testing.T) { func TestParse(t *testing.T) { cases := []struct { - template string - capabilities []string + template string + vars []string }{ {"{{ .Prompt }}", []string{"prompt"}}, {"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}}, @@ -81,8 +81,8 @@ func TestParse(t *testing.T) { } vars := tmpl.Vars() - if !slices.Equal(tt.capabilities, vars) { - t.Errorf("expected %v, got %v", tt.capabilities, vars) + if !slices.Equal(tt.vars, vars) { + t.Errorf("expected %v, got %v", tt.vars, vars) } }) }