add capabilities

This commit is contained in:
Michael Yang 2024-06-11 14:03:42 -07:00
parent 58e3fff311
commit a30915bde1
3 changed files with 26 additions and 10 deletions

View file

@ -34,6 +34,10 @@ import (
"github.com/ollama/ollama/version"
)
// Capability is a string identifier for a feature that a Model can be
// queried for through Model.Has.
type Capability string
// CapabilityCompletion identifies the completion capability. Model.Has
// reports it as missing for models whose families include "bert" or
// "nomic-bert".
const CapabilityCompletion = Capability("completion")
type registryOptions struct {
Insecure bool
Username string
@ -58,8 +62,20 @@ type Model struct {
Template *template.Template
}
func (m *Model) IsEmbedding() bool {
return slices.Contains(m.Config.ModelFamilies, "bert") || slices.Contains(m.Config.ModelFamilies, "nomic-bert")
// Has reports whether the model supports every capability listed in caps.
//
// It returns true when caps is empty. CapabilityCompletion is considered
// unsupported when the model's families include "bert" or "nomic-bert".
// Any capability other than CapabilityCompletion is logged as an error and
// makes Has return false.
func (m *Model) Has(caps ...Capability) bool {
	for _, want := range caps {
		// Only the completion capability is recognized; reject anything else.
		if want != CapabilityCompletion {
			slog.Error("unknown capability", "capability", want)
			return false
		}

		// Models in the bert families cannot serve completions.
		families := m.Config.ModelFamilies
		if slices.Contains(families, "bert") || slices.Contains(families, "nomic-bert") {
			return false
		}
	}

	return true
}
func (m *Model) String() string {

View file

@ -122,8 +122,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
if model.IsEmbedding() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support generate"})
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support generate", req.Model)})
return
}
@ -1308,8 +1308,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
if model.IsEmbedding() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "embedding models do not support chat"})
if !model.Has(CapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%s does not support chat", req.Model)})
return
}

View file

@ -61,8 +61,8 @@ func TestNamed(t *testing.T) {
func TestParse(t *testing.T) {
cases := []struct {
template string
capabilities []string
template string
vars []string
}{
{"{{ .Prompt }}", []string{"prompt"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "system"}},
@ -81,8 +81,8 @@ func TestParse(t *testing.T) {
}
vars := tmpl.Vars()
if !slices.Equal(tt.capabilities, vars) {
t.Errorf("expected %v, got %v", tt.capabilities, vars)
if !slices.Equal(tt.vars, vars) {
t.Errorf("expected %v, got %v", tt.vars, vars)
}
})
}