if directory cannot be resolved, do not fail

This commit is contained in:
Bruce MacDonald 2023-07-07 23:18:25 -04:00 committed by GitHub
commit 0bee4a8c07
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 1 deletions

View file

@ -106,6 +106,11 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
return err return err
} }
if resp.Error.Message != "" {
// couldn't pull the model from the directory, proceed anyway
return nil
}
return fn(resp) return fn(resp)
}), }),
) )

View file

@ -26,6 +26,7 @@ type PullProgress struct {
Total int64 `json:"total"` Total int64 `json:"total"`
Completed int64 `json:"completed"` Completed int64 `json:"completed"`
Percent float64 `json:"percent"` Percent float64 `json:"percent"`
Error Error `json:"error"`
} }
type GenerateRequest struct { type GenerateRequest struct {

View file

@ -3,12 +3,14 @@ package server
import ( import (
"embed" "embed"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"math" "math"
"net" "net"
"net/http" "net/http"
"os"
"path" "path"
"runtime" "runtime"
"strings" "strings"
@ -25,6 +27,15 @@ import (
var templatesFS embed.FS var templatesFS embed.FS
var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt")) var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
func cacheDir() string {
home, err := os.UserHomeDir()
if err != nil {
panic(err)
}
return path.Join(home, ".ollama")
}
func generate(c *gin.Context) { func generate(c *gin.Context) {
var req api.GenerateRequest var req api.GenerateRequest
req.ModelOptions = api.DefaultModelOptions req.ModelOptions = api.DefaultModelOptions
@ -37,9 +48,16 @@ func generate(c *gin.Context) {
if remoteModel, _ := getRemote(req.Model); remoteModel != nil { if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
req.Model = remoteModel.FullName() req.Model = remoteModel.FullName()
} }
if _, err := os.Stat(req.Model); err != nil {
if !errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return
}
req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
}
modelOpts := getModelOpts(req) modelOpts := getModelOpts(req)
modelOpts.NGPULayers = 1 // hard-code this for now modelOpts.NGPULayers = 1 // hard-code this for now
model, err := llama.New(req.Model, modelOpts) model, err := llama.New(req.Model, modelOpts)
if err != nil { if err != nil {
@ -118,6 +136,17 @@ func Serve(ln net.Listener) error {
go func() { go func() {
defer close(progressCh) defer close(progressCh)
if err := pull(req.Model, progressCh); err != nil { if err := pull(req.Model, progressCh); err != nil {
var opError *net.OpError
if errors.As(err, &opError) {
result := api.PullProgress{
Error: api.Error{
Code: http.StatusBadGateway,
Message: "failed to get models from directory",
},
}
c.JSON(http.StatusBadGateway, result)
return
}
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return return
} }