diff --git a/common/structure/structure.go b/common/structure/structure.go index fde223090..99c980bc4 100644 --- a/common/structure/structure.go +++ b/common/structure/structure.go @@ -3,6 +3,7 @@ package structure // references: https://github.com/mitchellh/mapstructure import ( + "encoding" "encoding/base64" "fmt" "reflect" @@ -86,35 +87,41 @@ func (d *Decoder) Decode(src map[string]any, dst any) error { } func (d *Decoder) decode(name string, data any, val reflect.Value) error { - kind := val.Kind() - switch { - case isInt(kind): - return d.decodeInt(name, data, val) - case isUint(kind): - return d.decodeUint(name, data, val) - case isFloat(kind): - return d.decodeFloat(name, data, val) - } - switch kind { - case reflect.Pointer: - if val.IsNil() { + for { + kind := val.Kind() + if kind == reflect.Pointer && val.IsNil() { val.Set(reflect.New(val.Type().Elem())) } - return d.decode(name, data, val.Elem()) - case reflect.String: - return d.decodeString(name, data, val) - case reflect.Bool: - return d.decodeBool(name, data, val) - case reflect.Slice: - return d.decodeSlice(name, data, val) - case reflect.Map: - return d.decodeMap(name, data, val) - case reflect.Interface: - return d.setInterface(name, data, val) - case reflect.Struct: - return d.decodeStruct(name, data, val) - default: - return fmt.Errorf("type %s not support", val.Kind().String()) + if ok, err := d.decodeTextUnmarshaller(name, data, val); ok { + return err + } + switch { + case isInt(kind): + return d.decodeInt(name, data, val) + case isUint(kind): + return d.decodeUint(name, data, val) + case isFloat(kind): + return d.decodeFloat(name, data, val) + } + switch kind { + case reflect.Pointer: + val = val.Elem() + continue + case reflect.String: + return d.decodeString(name, data, val) + case reflect.Bool: + return d.decodeBool(name, data, val) + case reflect.Slice: + return d.decodeSlice(name, data, val) + case reflect.Map: + return d.decodeMap(name, data, val) + case reflect.Interface: + return d.setInterface(name, data, val) + case reflect.Struct: + return d.decodeStruct(name, data, val) + default: + return fmt.Errorf("type %s not support", val.Kind().String()) + } } } @@ -553,3 +560,25 @@ func (d *Decoder) setInterface(name string, data any, val reflect.Value) (err er val.Set(dataVal) return nil } + +func (d *Decoder) decodeTextUnmarshaller(name string, data any, val reflect.Value) (bool, error) { + if !val.CanAddr() { + return false, nil + } + valAddr := val.Addr() + if !valAddr.CanInterface() { + return false, nil + } + unmarshaller, ok := valAddr.Interface().(encoding.TextUnmarshaler) + if !ok { + return false, nil + } + var str string + if err := d.decodeString(name, data, reflect.Indirect(reflect.ValueOf(&str))); err != nil { + return false, err + } + if err := unmarshaller.UnmarshalText([]byte(str)); err != nil { + return true, fmt.Errorf("cannot parse '%s' as %s: %s", name, val.Type(), err) + } + return true, nil +} diff --git a/common/structure/structure_test.go b/common/structure/structure_test.go index 9f31d3d11..e5fe693d7 100644 --- a/common/structure/structure_test.go +++ b/common/structure/structure_test.go @@ -1,6 +1,7 @@ package structure import ( + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -179,3 +180,90 @@ func TestStructure_SliceNilValueComplex(t *testing.T) { err = decoder.Decode(rawMap, ss) assert.NotNil(t, err) } + +func TestStructure_SliceCap(t *testing.T) { + rawMap := map[string]any{ + "foo": []string{}, + } + + s := &struct { + Foo []string `test:"foo,omitempty"` + Bar []string `test:"bar,omitempty"` + }{} + + err := decoder.Decode(rawMap, s) + assert.Nil(t, err) + assert.NotNil(t, s.Foo) // structure's Decode will ensure value not nil when input has value even it was set an empty array + assert.Nil(t, s.Bar) +} + +func TestStructure_Base64(t *testing.T) { + rawMap := map[string]any{ + "foo": "AQID", + } + + s := &struct { + Foo []byte `test:"foo"` + }{} + + err := decoder.Decode(rawMap, s) + assert.Nil(t, err) + assert.Equal(t, []byte{1, 2, 3}, s.Foo) +} + +func TestStructure_Pointer(t *testing.T) { + rawMap := map[string]any{ + "foo": "foo", + } + + s := &struct { + Foo *string `test:"foo,omitempty"` + Bar *string `test:"bar,omitempty"` + }{} + + err := decoder.Decode(rawMap, s) + assert.Nil(t, err) + assert.NotNil(t, s.Foo) + assert.Equal(t, "foo", *s.Foo) + assert.Nil(t, s.Bar) +} + +type num struct { + a int +} + +func (n *num) UnmarshalText(text []byte) (err error) { + n.a, err = strconv.Atoi(string(text)) + return +} + +func TestStructure_TextUnmarshaller(t *testing.T) { + rawMap := map[string]any{ + "num": "255", + "num_p": "127", + } + + s := &struct { + Num num `test:"num"` + NumP *num `test:"num_p"` + }{} + + err := decoder.Decode(rawMap, s) + assert.Nil(t, err) + assert.Equal(t, 255, s.Num.a) + assert.NotNil(t, s.NumP) + assert.Equal(t, s.NumP.a, 127) + + // test WeaklyTypedInput + rawMap["num"] = 256 + err = decoder.Decode(rawMap, s) + assert.NotNilf(t, err, "should throw error: %#v", s) + err = weakTypeDecoder.Decode(rawMap, s) + assert.Nil(t, err) + assert.Equal(t, 256, s.Num.a) + + // test invalid input + rawMap["num_p"] = "abc" + err = decoder.Decode(rawMap, s) + assert.NotNilf(t, err, "should throw error: %#v", s) +}