diff --git a/binary.go b/binary.go index 3d479cc..5ad8f5d 100644 --- a/binary.go +++ b/binary.go @@ -108,6 +108,20 @@ func (b *Encoder) Encode(v interface{}) (err error) { } for _, key := range rv.MapKeys() { value := rv.MapIndex(key) + if key.CanAddr() { + key = key.Addr() + } else { + k := reflect.New(key.Type()).Elem() + k.Set(key) + key = k.Addr() + } + if value.CanAddr() { + value = value.Addr() + } else { + v := reflect.New(value.Type()).Elem() + v.Set(value) + value = v.Addr() + } if err = b.Encode(key.Interface()); err != nil { return err } diff --git a/binary_test.go b/binary_test.go index 19aee37..a926e5f 100644 --- a/binary_test.go +++ b/binary_test.go @@ -390,3 +390,40 @@ func BenchmarkDecodeStructI3(b *testing.B) { } } + +func TestMapOfStructWithStruct(t *testing.T) { + type T1 struct { + ID uint64 + Name string + } + type T2 uint64 + type Struct struct { + V1 T1 + V2 T2 + V3 T1 + } + + k1 := Struct{V1: T1{1, "1"}, V2: 2, V3: T1{3, "3"}} + v1 := Struct{V1: T1{1, "1"}, V2: 2, V3: T1{3, "3"}} + s := map[Struct]Struct{ + k1: v1, + } + buf := new(bytes.Buffer) + enc := NewEncoder(buf) + err := enc.Encode(&s) + if err != nil { + t.Fatalf("error: %v\n", err) + } + + v := make(map[Struct]Struct) + dec := NewDecoder(buf) + err = dec.Decode(&v) + if err != nil { + t.Fatalf("error: %v\n", err) + } + + if !reflect.DeepEqual(s, v) { + t.Fatalf("got= %#v\nwant=%#v\n", v, s) + } + +}