Skip to content

Commit

Permalink
Fix belongs_to tag value to support underscored. (#546)
Browse files Browse the repository at this point in the history
* Reset to development

* Fix belongs_to tag value to support underscored, also support pointer fields for belongs_to

* add break clause when finding a match
  • Loading branch information
larrymjordan committed May 8, 2020
1 parent 9024edb commit e747e61
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 10 deletions.
18 changes: 17 additions & 1 deletion associations/belongs_to_association.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,23 @@ func belongsToAssociationBuilder(p associationParams) (Association, error) {
ownerVal := p.modelValue.FieldByName(p.field.Name)
tags := p.popTags
primaryIDField := defaults.String(tags.Find("primary_id").Value, "ID")
ownerIDField := defaults.String(tags.Find("fk_id").Value, fmt.Sprintf("%s%s", p.field.Name, "ID"))
ownerIDField := fmt.Sprintf("%s%s", p.field.Name, "ID")

if tags.Find("fk_id").Value != "" {
dbTag := tags.Find("fk_id").Value
if _, found := p.modelType.FieldByName(dbTag); !found {
t := p.modelValue.Type()
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Tag.Get("db") == dbTag {
ownerIDField = f.Name
break
}
}
} else {
ownerIDField = dbTag
}
}

// belongs_to requires an holding field for the foreign model ID.
if _, found := p.modelType.FieldByName(ownerIDField); !found {
Expand Down
4 changes: 2 additions & 2 deletions executors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,7 @@ func Test_Eager_Create_Belongs_To(t *testing.T) {

car := Taxi{
Model: "Fancy car",
Driver: User{
Driver: &User{
Name: nulls.NewString("Larry 2"),
},
}
Expand Down Expand Up @@ -1101,7 +1101,7 @@ func Test_Flat_Create_Belongs_To(t *testing.T) {

car := Taxi{
Model: "Fancy car",
Driver: user,
Driver: &user,
}

err = tx.Create(&car)
Expand Down
2 changes: 1 addition & 1 deletion pop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ type Taxi struct {
ID int `db:"id"`
Model string `db:"model"`
UserID nulls.Int `db:"user_id"`
Driver User `belongs_to:"user" fk_id:"UserID"`
Driver *User `belongs_to:"user" fk_id:"user_id"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}
Expand Down
20 changes: 14 additions & 6 deletions preload_associations.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,15 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf
asocValue := slice.Elem().Index(i)
if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() ||
reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) {
if modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array {

switch {
case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array:
modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
continue
case modelAssociationField.Kind() == reflect.Ptr:
modelAssociationField.Elem().Set(reflect.Append(modelAssociationField.Elem(), asocValue))
default:
modelAssociationField.Set(asocValue)
}
modelAssociationField.Set(asocValue)
}
}
})
Expand Down Expand Up @@ -380,11 +384,15 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI
asocValue := slice.Elem().Index(i)
if mmi.mapper.FieldByName(mvalue, fi.Path).Interface() == mmi.mapper.FieldByName(asocValue, "ID").Interface() ||
reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, fi.Path), mmi.mapper.FieldByName(asocValue, "ID")) {
if modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array {

switch {
case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array:
modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
continue
case modelAssociationField.Kind() == reflect.Ptr:
modelAssociationField.Elem().Set(asocValue)
default:
modelAssociationField.Set(asocValue)
}
modelAssociationField.Set(asocValue)
}
}
})
Expand Down

0 comments on commit e747e61

Please sign in to comment.