diff --git a/encode.go b/encode.go index 73366c0..bde6535 100644 --- a/encode.go +++ b/encode.go @@ -402,31 +402,30 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) { // Sort keys so that we have deterministic output. And write keys directly // underneath this key first, before writing sub-structs or sub-maps. - var mapKeysDirect, mapKeysSub []string + var mapKeysDirect, mapKeysSub []reflect.Value for _, mapKey := range rv.MapKeys() { - k := mapKey.String() if typeIsTable(tomlTypeOfGo(eindirect(rv.MapIndex(mapKey)))) { - mapKeysSub = append(mapKeysSub, k) + mapKeysSub = append(mapKeysSub, mapKey) } else { - mapKeysDirect = append(mapKeysDirect, k) + mapKeysDirect = append(mapKeysDirect, mapKey) } } - var writeMapKeys = func(mapKeys []string, trailC bool) { - sort.Strings(mapKeys) + writeMapKeys := func(mapKeys []reflect.Value, trailC bool) { + sort.Slice(mapKeys, func(i, j int) bool { return mapKeys[i].String() < mapKeys[j].String() }) for i, mapKey := range mapKeys { - val := eindirect(rv.MapIndex(reflect.ValueOf(mapKey))) + val := eindirect(rv.MapIndex(mapKey)) if isNil(val) { continue } if inline { - enc.writeKeyValue(Key{mapKey}, val, true) + enc.writeKeyValue(Key{mapKey.String()}, val, true) if trailC || i != len(mapKeys)-1 { enc.wf(", ") } } else { - enc.encode(key.add(mapKey), val) + enc.encode(key.add(mapKey.String()), val) } } } diff --git a/encode_test.go b/encode_test.go index fbf17de..a5c3a7e 100644 --- a/encode_test.go +++ b/encode_test.go @@ -837,10 +837,13 @@ func TestEncodeJSONNumber(t *testing.T) { } func TestEncode(t *testing.T) { - type Embedded struct { - Int int `toml:"_int"` - } - type NonStruct int + type ( + Embedded struct { + Int int `toml:"_int"` + } + NonStruct int + MyInt int + ) date := time.Date(2014, 5, 11, 19, 30, 40, 0, time.UTC) dateStr := "2014-05-11T19:30:40Z" @@ -1165,6 +1168,10 @@ ArrayOfMixedSlices = [[1, 2], ["a", "b"]] input: map[int]string{1: ""}, wantError: errNonString, }, + "(error) map no string key indirect": { + input: map[MyInt]string{1: ""}, + wantError: errNonString, + }, "tbl-in-arr-struct": { input: struct { @@ -1279,7 +1286,7 @@ func encodeExpected(t *testing.T, label string, val any, want string, wantErr er if wantErr == errAnything && err != nil { return } - t.Errorf("want Encode error %v, got %v", wantErr, err) + t.Errorf("wrong error:\nwant: %v\nhave: %v", wantErr, err) } else { t.Errorf("Encode failed: %s", err) } @@ -1297,3 +1304,40 @@ func encodeExpected(t *testing.T, label string, val any, want string, wantErr er } }) } + +func TestMapCustomKeytype(t *testing.T) { + type ( + MyString string + MyMap map[MyString]any + ) + + m := MyMap{ + "k1": "a", + "nested": MyMap{"k2": "b"}, + } + have := new(bytes.Buffer) + err := NewEncoder(have).Encode(m) + if err != nil { + t.Fatal(err) + } + + want := ` +k1 = "a" + +[nested] + k2 = "b" +`[1:] + + if have.String() != want { + t.Fatalf("\nhave: %q\nwant: %q", have, want) + } + + var m2 MyMap + _, err = Decode(have.String(), &m2) + if err != nil { + t.Fatal(err) + } + if h, w := fmt.Sprintf("%s", m2), fmt.Sprintf("%s", m); h != w { + t.Fatalf("\nhave: %s\nwant: %s", h, w) + } +}