validate api options fields from map (#711)

This commit is contained in:
Bruce MacDonald 2023-10-12 11:18:11 -04:00 committed by GitHub
parent 56497663c8
commit 7804b8fab9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 1 deletions

View file

@ -205,6 +205,8 @@ type Options struct {
NumThread int `json:"num_thread,omitempty"` NumThread int `json:"num_thread,omitempty"`
} }
var ErrInvalidOpts = fmt.Errorf("invalid options")
func (opts *Options) FromMap(m map[string]interface{}) error { func (opts *Options) FromMap(m map[string]interface{}) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
@ -218,6 +220,7 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
} }
} }
invalidOpts := []string{}
for key, val := range m { for key, val := range m {
if opt, ok := jsonOpts[key]; ok { if opt, ok := jsonOpts[key]; ok {
field := valueOpts.FieldByName(opt.Name) field := valueOpts.FieldByName(opt.Name)
@ -281,8 +284,14 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
return fmt.Errorf("unknown type loading config params: %v", field.Kind()) return fmt.Errorf("unknown type loading config params: %v", field.Kind())
} }
} }
} else {
invalidOpts = append(invalidOpts, key)
} }
} }
if len(invalidOpts) > 0 {
return fmt.Errorf("%w: %v", ErrInvalidOpts, strings.Join(invalidOpts, ", "))
}
return nil return nil
} }

View file

@ -68,7 +68,6 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
} }
if err := opts.FromMap(reqOpts); err != nil { if err := opts.FromMap(reqOpts); err != nil {
log.Printf("could not merge model options: %v", err)
return err return err
} }
@ -186,6 +185,10 @@ func GenerateHandler(c *gin.Context) {
// TODO: set this duration from the request if specified // TODO: set this duration from the request if specified
sessionDuration := defaultSessionDuration sessionDuration := defaultSessionDuration
if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil { if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }