From 7804b8fab9595ea074f97c693467d9ef524e6f15 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Thu, 12 Oct 2023 11:18:11 -0400 Subject: [PATCH] validate api options fields from map (#711) --- api/types.go | 9 +++++++++ server/routes.go | 5 ++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/api/types.go b/api/types.go index 002db000..2164bec5 100644 --- a/api/types.go +++ b/api/types.go @@ -205,6 +205,8 @@ type Options struct { NumThread int `json:"num_thread,omitempty"` } +var ErrInvalidOpts = fmt.Errorf("invalid options") + func (opts *Options) FromMap(m map[string]interface{}) error { valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct @@ -218,6 +220,7 @@ func (opts *Options) FromMap(m map[string]interface{}) error { } } + invalidOpts := []string{} for key, val := range m { if opt, ok := jsonOpts[key]; ok { field := valueOpts.FieldByName(opt.Name) @@ -281,8 +284,14 @@ func (opts *Options) FromMap(m map[string]interface{}) error { return fmt.Errorf("unknown type loading config params: %v", field.Kind()) } } + } else { + invalidOpts = append(invalidOpts, key) } } + + if len(invalidOpts) > 0 { + return fmt.Errorf("%w: %v", ErrInvalidOpts, strings.Join(invalidOpts, ", ")) + } return nil } diff --git a/server/routes.go b/server/routes.go index 34cbb05e..9d342602 100644 --- a/server/routes.go +++ b/server/routes.go @@ -68,7 +68,6 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string] } if err := opts.FromMap(reqOpts); err != nil { - log.Printf("could not merge model options: %v", err) return err } @@ -186,6 +185,10 @@ func GenerateHandler(c *gin.Context) { // TODO: set this duration from the request if specified sessionDuration := defaultSessionDuration if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil { + if errors.Is(err, api.ErrInvalidOpts) { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return }