Merge pull request #5205 from dhiltgen/modelfile_use_mmap
Fix use_mmap parsing for modelfiles
This commit is contained in:
commit
ccef9431c8
2 changed files with 76 additions and 0 deletions
13
api/types.go
13
api/types.go
|
@ -608,6 +608,19 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
|
||||||
} else {
|
} else {
|
||||||
field := valueOpts.FieldByName(opt.Name)
|
field := valueOpts.FieldByName(opt.Name)
|
||||||
if field.IsValid() && field.CanSet() {
|
if field.IsValid() && field.CanSet() {
|
||||||
|
if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
|
||||||
|
boolVal, err := strconv.ParseBool(vals[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid bool value %s", vals)
|
||||||
|
}
|
||||||
|
if boolVal {
|
||||||
|
out[key] = TriStateTrue
|
||||||
|
} else {
|
||||||
|
out[key] = TriStateFalse
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
switch field.Kind() {
|
switch field.Kind() {
|
||||||
case reflect.Float32:
|
case reflect.Float32:
|
||||||
floatVal, err := strconv.ParseFloat(vals[0], 32)
|
floatVal, err := strconv.ParseFloat(vals[0], 32)
|
||||||
|
|
|
@ -2,6 +2,7 @@ package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -141,3 +142,65 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUseMmapFormatParams(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req map[string][]string
|
||||||
|
exp TriState
|
||||||
|
err error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "True",
|
||||||
|
req: map[string][]string{
|
||||||
|
"use_mmap": []string{"true"},
|
||||||
|
},
|
||||||
|
exp: TriStateTrue,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "False",
|
||||||
|
req: map[string][]string{
|
||||||
|
"use_mmap": []string{"false"},
|
||||||
|
},
|
||||||
|
exp: TriStateFalse,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Numeric True",
|
||||||
|
req: map[string][]string{
|
||||||
|
"use_mmap": []string{"1"},
|
||||||
|
},
|
||||||
|
exp: TriStateTrue,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Numeric False",
|
||||||
|
req: map[string][]string{
|
||||||
|
"use_mmap": []string{"0"},
|
||||||
|
},
|
||||||
|
exp: TriStateFalse,
|
||||||
|
err: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid string",
|
||||||
|
req: map[string][]string{
|
||||||
|
"use_mmap": []string{"foo"},
|
||||||
|
},
|
||||||
|
exp: TriStateUndefined,
|
||||||
|
err: fmt.Errorf("invalid bool value [foo]"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
resp, err := FormatParams(test.req)
|
||||||
|
require.Equal(t, err, test.err)
|
||||||
|
respVal, ok := resp["use_mmap"]
|
||||||
|
if test.exp != TriStateUndefined {
|
||||||
|
assert.True(t, ok, "resp: %v", resp)
|
||||||
|
assert.Equal(t, test.exp, respVal)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue