if directory cannot be resolved, do not fail

This commit is contained in:
Bruce MacDonald 2023-07-07 15:27:43 -04:00
parent b24be8c6b3
commit 61dd87bd90
3 changed files with 43 additions and 2 deletions

View file

@ -26,6 +26,7 @@ type PullProgress struct {
Total int64 `json:"total"` Total int64 `json:"total"`
Completed int64 `json:"completed"` Completed int64 `json:"completed"`
Percent float64 `json:"percent"` Percent float64 `json:"percent"`
Error Error `json:"error"`
} }
type GenerateRequest struct { type GenerateRequest struct {

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"log" "log"
"net" "net"
"net/http"
"os" "os"
"path" "path"
"strings" "strings"
@ -50,6 +51,10 @@ func pull(model string) error {
context.Background(), context.Background(),
&api.PullRequest{Model: model}, &api.PullRequest{Model: model},
func(progress api.PullProgress) error { func(progress api.PullProgress) error {
if progress.Error.Code == http.StatusBadGateway {
// couldn't pull the model from the directory, proceed in offline mode
return nil
}
if bar == nil && progress.Percent == 100 { if bar == nil && progress.Percent == 100 {
// already downloaded // already downloaded
return nil return nil

View file

@ -3,12 +3,14 @@ package server
import ( import (
"embed" "embed"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"math" "math"
"net" "net"
"net/http" "net/http"
"os"
"path" "path"
"runtime" "runtime"
"strings" "strings"
@ -25,6 +27,15 @@ import (
var templatesFS embed.FS var templatesFS embed.FS
var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt")) var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
func cacheDir() string {
home, err := os.UserHomeDir()
if err != nil {
panic(err)
}
return path.Join(home, ".ollama")
}
func generate(c *gin.Context) { func generate(c *gin.Context) {
var req api.GenerateRequest var req api.GenerateRequest
req.ModelOptions = api.DefaultModelOptions req.ModelOptions = api.DefaultModelOptions
@ -34,7 +45,20 @@ func generate(c *gin.Context) {
return return
} }
if remoteModel, _ := getRemote(req.Model); remoteModel != nil { remoteModel, err := getRemote(req.Model)
if err != nil {
// couldn't check the directory, proceed in offline mode
_, err := os.Stat(req.Model)
if err != nil {
if !os.IsNotExist(err) {
c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return
}
// couldn't find the model file, try setting the model to the cache directory
req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
}
}
if remoteModel != nil {
req.Model = remoteModel.FullName() req.Model = remoteModel.FullName()
} }
@ -118,6 +142,17 @@ func Serve(ln net.Listener) error {
go func() { go func() {
defer close(progressCh) defer close(progressCh)
if err := pull(req.Model, progressCh); err != nil { if err := pull(req.Model, progressCh); err != nil {
var opError *net.OpError
if errors.As(err, &opError) {
result := api.PullProgress{
Error: api.Error{
Code: http.StatusBadGateway,
Message: "failed to get models from directory",
},
}
c.JSON(http.StatusBadGateway, result)
return
}
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return return
} }