From af47413dbab0bad3e9985cd9978cd9235851bfed Mon Sep 17 00:00:00 2001 From: Jackie Li Date: Mon, 6 May 2024 23:59:18 +0100 Subject: [PATCH] Add MarshalJSON to Duration (#3284) --------- Co-authored-by: Patrick Devine --- api/types.go | 11 ++++++++- api/types_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/api/types.go b/api/types.go index 7cfd5ff7..70caee87 100644 --- a/api/types.go +++ b/api/types.go @@ -436,6 +436,13 @@ type Duration struct { time.Duration } +func (d Duration) MarshalJSON() ([]byte, error) { + if d.Duration < 0 { + return []byte("-1"), nil + } + return []byte("\"" + d.Duration.String() + "\""), nil +} + func (d *Duration) UnmarshalJSON(b []byte) (err error) { var v any if err := json.Unmarshal(b, &v); err != nil { @@ -449,7 +456,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) { if t < 0 { d.Duration = time.Duration(math.MaxInt64) } else { - d.Duration = time.Duration(t * float64(time.Second)) + d.Duration = time.Duration(int(t) * int(time.Second)) } case string: d.Duration, err = time.ParseDuration(t) @@ -459,6 +466,8 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) { if d.Duration < 0 { d.Duration = time.Duration(math.MaxInt64) } + default: + return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v)) } return nil diff --git a/api/types_test.go b/api/types_test.go index 5a093be2..cfe1331f 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -21,6 +21,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) { req: `{ "keep_alive": 42 }`, exp: &Duration{42 * time.Second}, }, + { + name: "Positive Float", + req: `{ "keep_alive": 42.5 }`, + exp: &Duration{42 * time.Second}, + }, { name: "Positive Integer String", req: `{ "keep_alive": "42m" }`, @@ -31,6 +36,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) { req: `{ "keep_alive": -1 }`, exp: &Duration{math.MaxInt64}, }, + { + name: "Negative Float", + req: `{ "keep_alive": -3.14 }`, + exp: &Duration{math.MaxInt64}, + }, { name: "Negative Integer String", req: `{ "keep_alive": "-1m" }`, @@ -48,3 +58,50 @@ func TestKeepAliveParsingFromJSON(t *testing.T) { }) } } + +func TestDurationMarshalUnmarshal(t *testing.T) { + tests := []struct { + name string + input time.Duration + expected time.Duration + }{ + { + "negative duration", + time.Duration(-1), + time.Duration(math.MaxInt64), + }, + { + "positive duration", + time.Duration(42 * time.Second), + time.Duration(42 * time.Second), + }, + { + "another positive duration", + time.Duration(42 * time.Minute), + time.Duration(42 * time.Minute), + }, + { + "zero duration", + time.Duration(0), + time.Duration(0), + }, + { + "max duration", + time.Duration(math.MaxInt64), + time.Duration(math.MaxInt64), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + b, err := json.Marshal(Duration{test.input}) + require.NoError(t, err) + + var d Duration + err = json.Unmarshal(b, &d) + require.NoError(t, err) + + assert.Equal(t, test.expected, d.Duration, "input %v, marshalled %v, got %v", test.input, string(b), d.Duration) + }) + } +}