if directory cannot be resolved, do not fail
This commit is contained in:
parent
b24be8c6b3
commit
61dd87bd90
3 changed files with 43 additions and 2 deletions
|
@ -26,6 +26,7 @@ type PullProgress struct {
|
||||||
Total int64 `json:"total"`
|
Total int64 `json:"total"`
|
||||||
Completed int64 `json:"completed"`
|
Completed int64 `json:"completed"`
|
||||||
Percent float64 `json:"percent"`
|
Percent float64 `json:"percent"`
|
||||||
|
Error Error `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GenerateRequest struct {
|
type GenerateRequest struct {
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -50,6 +51,10 @@ func pull(model string) error {
|
||||||
context.Background(),
|
context.Background(),
|
||||||
&api.PullRequest{Model: model},
|
&api.PullRequest{Model: model},
|
||||||
func(progress api.PullProgress) error {
|
func(progress api.PullProgress) error {
|
||||||
|
if progress.Error.Code == http.StatusBadGateway {
|
||||||
|
// couldn't pull the model from the directory, proceed in offline mode
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if bar == nil && progress.Percent == 100 {
|
if bar == nil && progress.Percent == 100 {
|
||||||
// already downloaded
|
// already downloaded
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -3,12 +3,14 @@ package server
|
||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -25,6 +27,15 @@ import (
|
||||||
var templatesFS embed.FS
|
var templatesFS embed.FS
|
||||||
var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
|
var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
|
||||||
|
|
||||||
|
func cacheDir() string {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return path.Join(home, ".ollama")
|
||||||
|
}
|
||||||
|
|
||||||
func generate(c *gin.Context) {
|
func generate(c *gin.Context) {
|
||||||
var req api.GenerateRequest
|
var req api.GenerateRequest
|
||||||
req.ModelOptions = api.DefaultModelOptions
|
req.ModelOptions = api.DefaultModelOptions
|
||||||
|
@ -34,7 +45,20 @@ func generate(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
|
remoteModel, err := getRemote(req.Model)
|
||||||
|
if err != nil {
|
||||||
|
// couldn't check the directory, proceed in offline mode
|
||||||
|
_, err := os.Stat(req.Model)
|
||||||
|
if err != nil {
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// couldn't find the model file, try setting the model to the cache directory
|
||||||
|
req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if remoteModel != nil {
|
||||||
req.Model = remoteModel.FullName()
|
req.Model = remoteModel.FullName()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,6 +142,17 @@ func Serve(ln net.Listener) error {
|
||||||
go func() {
|
go func() {
|
||||||
defer close(progressCh)
|
defer close(progressCh)
|
||||||
if err := pull(req.Model, progressCh); err != nil {
|
if err := pull(req.Model, progressCh); err != nil {
|
||||||
|
var opError *net.OpError
|
||||||
|
if errors.As(err, &opError) {
|
||||||
|
result := api.PullProgress{
|
||||||
|
Error: api.Error{
|
||||||
|
Code: http.StatusBadGateway,
|
||||||
|
Message: "failed to get models from directory",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusBadGateway, result)
|
||||||
|
return
|
||||||
|
}
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue