check api status

This commit is contained in:
Michael Yang 2023-07-11 13:05:51 -07:00
parent 2a66a1164a
commit e243329e2e
3 changed files with 38 additions and 11 deletions

View file

@ -10,6 +10,20 @@ import (
"net/url" "net/url"
) )
type StatusError struct {
StatusCode int
Status string
Message string
}
func (e StatusError) Error() string {
if e.Message != "" {
return fmt.Sprintf("%s: %s", e.Status, e.Message)
}
return e.Status
}
type Client struct { type Client struct {
base url.URL base url.URL
} }
@ -25,7 +39,7 @@ func NewClient(hosts ...string) *Client {
} }
} }
func (c *Client) stream(ctx context.Context, method, path string, data any, callback func([]byte) error) error { func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf *bytes.Buffer var buf *bytes.Buffer
if data != nil { if data != nil {
bts, err := json.Marshal(data) bts, err := json.Marshal(data)
@ -53,7 +67,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call
scanner := bufio.NewScanner(response.Body) scanner := bufio.NewScanner(response.Body)
for scanner.Scan() { for scanner.Scan() {
var errorResponse struct { var errorResponse struct {
Error string `json:"error"` Error string `json:"error,omitempty"`
} }
bts := scanner.Bytes() bts := scanner.Bytes()
@ -61,11 +75,15 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, call
return fmt.Errorf("unmarshal: %w", err) return fmt.Errorf("unmarshal: %w", err)
} }
if len(errorResponse.Error) > 0 { if response.StatusCode >= 400 {
return fmt.Errorf("stream: %s", errorResponse.Error) return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
Message: errorResponse.Error,
}
} }
if err := callback(bts); err != nil { if err := fn(bts); err != nil {
return err return err
} }
} }

View file

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"log" "log"
"net" "net"
"net/http"
"os" "os"
"path" "path"
"strings" "strings"
@ -34,8 +35,15 @@ func RunRun(cmd *cobra.Command, args []string) error {
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
if err := pull(args[0]); err != nil { if err := pull(args[0]); err != nil {
var apiStatusError api.StatusError
if !errors.As(err, &apiStatusError) {
return err return err
} }
if apiStatusError.StatusCode != http.StatusBadGateway {
return err
}
}
case err != nil: case err != nil:
return err return err
} }
@ -50,11 +58,12 @@ func pull(model string) error {
context.Background(), context.Background(),
&api.PullRequest{Model: model}, &api.PullRequest{Model: model},
func(progress api.PullProgress) error { func(progress api.PullProgress) error {
if bar == nil && progress.Percent == 100 { if bar == nil {
if progress.Percent == 100 {
// already downloaded // already downloaded
return nil return nil
} }
if bar == nil {
bar = progressbar.DefaultBytes(progress.Total) bar = progressbar.DefaultBytes(progress.Total)
} }

View file

@ -108,7 +108,7 @@ func pull(c *gin.Context) {
remote, err := getRemote(req.Model) remote, err := getRemote(req.Model)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return return
} }