if directory cannot be resolved, do not fail

This commit is contained in:
Bruce MacDonald 2023-07-07 23:18:25 -04:00 committed by GitHub
commit 0bee4a8c07
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 1 deletions

View file

@ -106,6 +106,11 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
return err
}
if resp.Error.Message != "" {
// couldn't pull the model from the directory, proceed anyway
return nil
}
return fn(resp)
}),
)

View file

@ -26,6 +26,7 @@ type PullProgress struct {
Total int64 `json:"total"`
Completed int64 `json:"completed"`
Percent float64 `json:"percent"`
Error Error `json:"error"`
}
type GenerateRequest struct {

View file

@ -3,12 +3,14 @@ package server
import (
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math"
"net"
"net/http"
"os"
"path"
"runtime"
"strings"
@ -25,6 +27,15 @@ import (
var templatesFS embed.FS
var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
func cacheDir() string {
home, err := os.UserHomeDir()
if err != nil {
panic(err)
}
return path.Join(home, ".ollama")
}
func generate(c *gin.Context) {
var req api.GenerateRequest
req.ModelOptions = api.DefaultModelOptions
@ -37,6 +48,13 @@ func generate(c *gin.Context) {
if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
req.Model = remoteModel.FullName()
}
if _, err := os.Stat(req.Model); err != nil {
if !errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return
}
req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
}
modelOpts := getModelOpts(req)
modelOpts.NGPULayers = 1 // hard-code this for now
@ -118,6 +136,17 @@ func Serve(ln net.Listener) error {
go func() {
defer close(progressCh)
if err := pull(req.Model, progressCh); err != nil {
var opError *net.OpError
if errors.As(err, &opError) {
result := api.PullProgress{
Error: api.Error{
Code: http.StatusBadGateway,
Message: "failed to get models from directory",
},
}
c.JSON(http.StatusBadGateway, result)
return
}
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return
}