improve api error handling (#781)

- remove newlines from llama.cpp error messages relayed to the client
- check API option types and return an error on a wrong type
- reduce the number of GPU layers loaded from 95% to 92% of the layers that fit in free VRAM
This commit is contained in:
Bruce MacDonald 2023-10-13 16:57:10 -04:00 committed by GitHub
parent d890890f66
commit 6fe178134d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 16 deletions

View file

@ -3,7 +3,6 @@ package api
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"math" "math"
"os" "os"
"reflect" "reflect"
@ -238,44 +237,39 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
// when JSON unmarshals numbers, it uses float64, not int // when JSON unmarshals numbers, it uses float64, not int
field.SetInt(int64(t)) field.SetInt(int64(t))
default: default:
log.Printf("could not convert model parameter %v of type %T to int, skipped", key, val) return fmt.Errorf("option %q must be of type integer", key)
} }
case reflect.Bool: case reflect.Bool:
val, ok := val.(bool) val, ok := val.(bool)
if !ok { if !ok {
log.Printf("could not convert model parameter %v of type %T to bool, skipped", key, val) return fmt.Errorf("option %q must be of type boolean", key)
continue
} }
field.SetBool(val) field.SetBool(val)
case reflect.Float32: case reflect.Float32:
// JSON unmarshals to float64 // JSON unmarshals to float64
val, ok := val.(float64) val, ok := val.(float64)
if !ok { if !ok {
log.Printf("could not convert model parameter %v of type %T to float32, skipped", key, val) return fmt.Errorf("option %q must be of type float32", key)
continue
} }
field.SetFloat(val) field.SetFloat(val)
case reflect.String: case reflect.String:
val, ok := val.(string) val, ok := val.(string)
if !ok { if !ok {
log.Printf("could not convert model parameter %v of type %T to string, skipped", key, val) return fmt.Errorf("option %q must be of type string", key)
continue
} }
field.SetString(val) field.SetString(val)
case reflect.Slice: case reflect.Slice:
// JSON unmarshals to []interface{}, not []string // JSON unmarshals to []interface{}, not []string
val, ok := val.([]interface{}) val, ok := val.([]interface{})
if !ok { if !ok {
log.Printf("could not convert model parameter %v of type %T to slice, skipped", key, val) return fmt.Errorf("option %q must be of type array", key)
continue
} }
// convert []interface{} to []string // convert []interface{} to []string
slice := make([]string, len(val)) slice := make([]string, len(val))
for i, item := range val { for i, item := range val {
str, ok := item.(string) str, ok := item.(string)
if !ok { if !ok {
log.Printf("could not convert model parameter %v of type %T to slice of strings, skipped", key, item) return fmt.Errorf("option %q must be of an array of strings", key)
continue
} }
slice[i] = str slice[i] = str
} }

View file

@ -238,8 +238,8 @@ func NumGPU(numLayer, fileSizeBytes int64, opts api.Options) int {
// TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size // TODO: this is a rough heuristic, better would be to calculate this based on number of layers and context size
bytesPerLayer := fileSizeBytes / numLayer bytesPerLayer := fileSizeBytes / numLayer
// max number of layers we can fit in VRAM, subtract 5% to prevent consuming all available VRAM and running out of memory // max number of layers we can fit in VRAM, subtract 8% to prevent consuming all available VRAM and running out of memory
layers := int(freeVramBytes/bytesPerLayer) * 95 / 100 layers := int(freeVramBytes/bytesPerLayer) * 92 / 100
log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, layers) log.Printf("%d MiB VRAM available, loading up to %d GPU layers", vramMib, layers)
return layers return layers
@ -261,8 +261,7 @@ func NewStatusWriter() *StatusWriter {
func (w *StatusWriter) Write(b []byte) (int, error) { func (w *StatusWriter) Write(b []byte) (int, error) {
if _, after, ok := bytes.Cut(b, []byte("error:")); ok { if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
err := fmt.Errorf("llama runner: %s", after) w.ErrCh <- fmt.Errorf("llama runner: %s", bytes.TrimSpace(after))
w.ErrCh <- err
} }
return os.Stderr.Write(b) return os.Stderr.Write(b)
} }