Add MarshalJSON to Duration (#3284)

---------

Co-authored-by: Patrick Devine <patrick@infrahq.com>
This commit is contained in:
Jackie Li 2024-05-06 23:59:18 +01:00 committed by GitHub
parent d091fe3c21
commit af47413dba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 67 additions and 1 deletions

View file

@ -436,6 +436,13 @@ type Duration struct {
time.Duration time.Duration
} }
func (d Duration) MarshalJSON() ([]byte, error) {
if d.Duration < 0 {
return []byte("-1"), nil
}
return []byte("\"" + d.Duration.String() + "\""), nil
}
func (d *Duration) UnmarshalJSON(b []byte) (err error) { func (d *Duration) UnmarshalJSON(b []byte) (err error) {
var v any var v any
if err := json.Unmarshal(b, &v); err != nil { if err := json.Unmarshal(b, &v); err != nil {
@ -449,7 +456,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
if t < 0 { if t < 0 {
d.Duration = time.Duration(math.MaxInt64) d.Duration = time.Duration(math.MaxInt64)
} else { } else {
d.Duration = time.Duration(t * float64(time.Second)) d.Duration = time.Duration(int(t) * int(time.Second))
} }
case string: case string:
d.Duration, err = time.ParseDuration(t) d.Duration, err = time.ParseDuration(t)
@ -459,6 +466,8 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
if d.Duration < 0 { if d.Duration < 0 {
d.Duration = time.Duration(math.MaxInt64) d.Duration = time.Duration(math.MaxInt64)
} }
default:
return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v))
} }
return nil return nil

View file

@ -21,6 +21,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
req: `{ "keep_alive": 42 }`, req: `{ "keep_alive": 42 }`,
exp: &Duration{42 * time.Second}, exp: &Duration{42 * time.Second},
}, },
{
name: "Positive Float",
req: `{ "keep_alive": 42.5 }`,
exp: &Duration{42 * time.Second},
},
{ {
name: "Positive Integer String", name: "Positive Integer String",
req: `{ "keep_alive": "42m" }`, req: `{ "keep_alive": "42m" }`,
@ -31,6 +36,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
req: `{ "keep_alive": -1 }`, req: `{ "keep_alive": -1 }`,
exp: &Duration{math.MaxInt64}, exp: &Duration{math.MaxInt64},
}, },
{
name: "Negative Float",
req: `{ "keep_alive": -3.14 }`,
exp: &Duration{math.MaxInt64},
},
{ {
name: "Negative Integer String", name: "Negative Integer String",
req: `{ "keep_alive": "-1m" }`, req: `{ "keep_alive": "-1m" }`,
@ -48,3 +58,50 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
}) })
} }
} }
func TestDurationMarshalUnmarshal(t *testing.T) {
tests := []struct {
name string
input time.Duration
expected time.Duration
}{
{
"negative duration",
time.Duration(-1),
time.Duration(math.MaxInt64),
},
{
"positive duration",
time.Duration(42 * time.Second),
time.Duration(42 * time.Second),
},
{
"another positive duration",
time.Duration(42 * time.Minute),
time.Duration(42 * time.Minute),
},
{
"zero duration",
time.Duration(0),
time.Duration(0),
},
{
"max duration",
time.Duration(math.MaxInt64),
time.Duration(math.MaxInt64),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
b, err := json.Marshal(Duration{test.input})
require.NoError(t, err)
var d Duration
err = json.Unmarshal(b, &d)
require.NoError(t, err)
assert.Equal(t, test.expected, d.Duration, "input %v, marshalled %v, got %v", test.input, string(b), d.Duration)
})
}
}