diff --git a/api/client.go b/api/client.go
index 65d36ecb..e0b9b0aa 100644
--- a/api/client.go
+++ b/api/client.go
@@ -5,6 +5,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"fmt"
 	"io"
 	"net/http"
 	"net/url"
@@ -25,6 +26,30 @@ func NewClient(hosts ...string) *Client {
 	}
 }
 
+// StatusError converts an HTTP status code into an error.
+//
+// It returns nil when status is below 400. Otherwise the error text is
+// "<code> <status text>", with ": <message>" appended when the first
+// message is non-empty; any further messages are ignored.
+func StatusError(status int, message ...string) error {
+	if status < http.StatusBadRequest {
+		return nil
+	}
+
+	// http.StatusText returns "" for unregistered codes, so only append it
+	// when present to avoid a dangling space in the error text.
+	text := fmt.Sprint(status)
+	if name := http.StatusText(status); name != "" {
+		text += " " + name
+	}
+
+	if len(message) > 0 && message[0] != "" {
+		return fmt.Errorf("%s: %s", text, message[0])
+	}
+
+	return fmt.Errorf("%s", text)
+}
+
 type options struct {
 	requestBody  io.Reader
 	responseFunc func(bts []byte) error
@@ -70,7 +95,21 @@ func (c *Client) stream(ctx context.Context, method, path string, fns ...func(*o
 	if opts.responseFunc != nil {
 		scanner := bufio.NewScanner(response.Body)
 		for scanner.Scan() {
-			if err := opts.responseFunc(scanner.Bytes()); err != nil {
+			bts := scanner.Bytes()
+
+			if response.StatusCode >= http.StatusBadRequest {
+				var errorResponse struct {
+					Error string `json:"error"`
+				}
+
+				// Best effort: a body that is not JSON (e.g. from a proxy)
+				// must not hide the HTTP status behind a decode error.
+				_ = json.Unmarshal(bts, &errorResponse)
+
+				return StatusError(response.StatusCode, errorResponse.Error)
+			}
+
+			if err := opts.responseFunc(bts); err != nil {
 				return err
 			}
 		}
diff --git a/api/types.go b/api/types.go
index fbcf570e..5dc7488e 100644
--- a/api/types.go
+++ b/api/types.go
@@ -15,6 +15,7 @@ func (e Error) Error() string {
 	if e.Message == "" {
 		return fmt.Sprintf("%d %v", e.Code, strings.ToLower(http.StatusText(int(e.Code))))
 	}
+
 	return e.Message
 }
 
diff --git a/cmd/cmd.go b/cmd/cmd.go
index ad52853e..8421b8f5 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -100,14 +100,19 @@ func generate(model, prompt string) error {
 		}
 	}()
 
-	client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(resp api.GenerateResponse) error {
+	request := api.GenerateRequest{Model: model, Prompt: prompt}
+	fn := func(resp api.GenerateResponse) error {
 		if !spinner.IsFinished() {
 			spinner.Finish()
 		}
 
 		fmt.Print(resp.Response)
 		return nil
-	})
+	}
+
+	if err := client.Generate(context.Background(), &request, fn); err != nil {
+		return err
+	}
 
 	fmt.Println()
 	fmt.Println()
diff --git a/server/routes.go b/server/routes.go
index d9e2184f..7d0fdf72 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -4,7 +4,6 @@ import (
 	"embed"
 	"encoding/json"
 	"errors"
-	"fmt"
 	"io"
 	"log"
 	"math"
@@ -46,7 +45,7 @@ func generate(c *gin.Context) {
 		req.PredictOptions = &api.DefaultPredictOptions
 	}
 	if err := c.ShouldBindJSON(&req); err != nil {
-		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -66,7 +65,7 @@ func generate(c *gin.Context) {
 
 	model, err := llama.New(req.Model, modelOpts)
 	if err != nil {
-		fmt.Println("Loading the model failed:", err.Error())
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 	defer model.Free()
@@ -80,7 +79,7 @@ func generate(c *gin.Context) {
 	if template := templates.Lookup(match); template != nil {
 		var sb strings.Builder
 		if err := template.Execute(&sb, req); err != nil {
-			fmt.Println("Prompt template failed:", err.Error())
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 		}