diff --git a/associations/belongs_to_association.go b/associations/belongs_to_association.go index 2d70667c..a28e5d0a 100644 --- a/associations/belongs_to_association.go +++ b/associations/belongs_to_association.go @@ -31,7 +31,23 @@ func belongsToAssociationBuilder(p associationParams) (Association, error) { ownerVal := p.modelValue.FieldByName(p.field.Name) tags := p.popTags primaryIDField := defaults.String(tags.Find("primary_id").Value, "ID") - ownerIDField := defaults.String(tags.Find("fk_id").Value, fmt.Sprintf("%s%s", p.field.Name, "ID")) + ownerIDField := fmt.Sprintf("%s%s", p.field.Name, "ID") + + if tags.Find("fk_id").Value != "" { + dbTag := tags.Find("fk_id").Value + if _, found := p.modelType.FieldByName(dbTag); !found { + t := p.modelValue.Type() + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Tag.Get("db") == dbTag { + ownerIDField = f.Name + break + } + } + } else { + ownerIDField = dbTag + } + } // belongs_to requires an holding field for the foreign model ID. if _, found := p.modelType.FieldByName(ownerIDField); !found { diff --git a/executors_test.go b/executors_test.go index 5af9304b..8b263c72 100644 --- a/executors_test.go +++ b/executors_test.go @@ -972,7 +972,7 @@ func Test_Eager_Create_Belongs_To(t *testing.T) { car := Taxi{ Model: "Fancy car", - Driver: User{ + Driver: &User{ Name: nulls.NewString("Larry 2"), }, } @@ -1101,7 +1101,7 @@ func Test_Flat_Create_Belongs_To(t *testing.T) { car := Taxi{ Model: "Fancy car", - Driver: user, + Driver: &user, } err = tx.Create(&car) diff --git a/pop_test.go b/pop_test.go index 6dd812af..20e4eda5 100644 --- a/pop_test.go +++ b/pop_test.go @@ -124,7 +124,7 @@ type Taxi struct { ID int `db:"id"` Model string `db:"model"` UserID nulls.Int `db:"user_id"` - Driver User `belongs_to:"user" fk_id:"UserID"` + Driver *User `belongs_to:"user" fk_id:"user_id"` CreatedAt time.Time `db:"created_at"` UpdatedAt time.Time `db:"updated_at"` } diff --git a/preload_associations.go b/preload_associations.go index e78039a2..5d3be99a 100644 --- a/preload_associations.go +++ b/preload_associations.go @@ -159,7 +159,8 @@ func (ami *AssociationMetaInfo) fkName() string { t = reflectx.Deref(t.Elem()) } fkName := fmt.Sprintf("%s%s", flect.Underscore(flect.Singularize(t.Name())), "_id") - return defaults.String(ami.Field.Tag.Get("fk_id"), fkName) + fkNameTag := flect.Underscore(ami.Field.Tag.Get("fk_id")) + return defaults.String(fkNameTag, fkName) } // preload is the query mode used to load associations from database @@ -269,11 +270,15 @@ func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInf asocValue := slice.Elem().Index(i) if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() || reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) { - if modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array { + + switch { + case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) - continue + case modelAssociationField.Kind() == reflect.Ptr: + modelAssociationField.Elem().Set(reflect.Append(modelAssociationField.Elem(), asocValue)) + default: + modelAssociationField.Set(asocValue) } - modelAssociationField.Set(asocValue) } } }) @@ -339,6 +344,10 @@ func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { // 1) get all associations ids. fi := mmi.getDBFieldTaggedWith(asoc.fkName()) + if fi == nil { + fi = mmi.getDBFieldTaggedWith(fmt.Sprintf("%s%s", flect.Underscore(asoc.Path), "_id")) + } + fkids := []interface{}{} mmi.iterate(func(val reflect.Value) { fkids = append(fkids, mmi.mapper.FieldByName(val, fi.Path).Interface()) @@ -375,11 +384,15 @@ func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaI asocValue := slice.Elem().Index(i) if mmi.mapper.FieldByName(mvalue, fi.Path).Interface() == mmi.mapper.FieldByName(asocValue, "ID").Interface() || reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, fi.Path), mmi.mapper.FieldByName(asocValue, "ID")) { - if modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array { + + switch { + case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) - continue + case modelAssociationField.Kind() == reflect.Ptr: + modelAssociationField.Elem().Set(asocValue) + default: + modelAssociationField.Set(asocValue) } - modelAssociationField.Set(asocValue) } } }) diff --git a/preload_associations_test.go b/preload_associations_test.go index ed4f7ae0..a7fe4a59 100644 --- a/preload_associations_test.go +++ b/preload_associations_test.go @@ -182,3 +182,24 @@ func Test_New_Implementation_For_Nplus1_Nested(t *testing.T) { SetEagerMode(EagerDefault) }) } + +func Test_New_Implementation_For_Nplus1_BelongsTo_Not_Underscore(t *testing.T) { + if PDB == nil { + t.Skip("skipping integration tests") + } + transaction(func(tx *Connection) { + a := require.New(t) + user := User{Name: nulls.NewString("Mark")} + a.NoError(tx.Create(&user)) + + taxi := Taxi{UserID: nulls.NewInt(user.ID)} + a.NoError(tx.Create(&taxi)) + + SetEagerMode(EagerPreload) + taxis := []Taxi{} + a.NoError(tx.EagerPreload().All(&taxis)) + a.Len(taxis, 1) + a.Equal("Mark", taxis[0].Driver.Name.String) + SetEagerMode(EagerDefault) + }) +} diff --git a/soda/cmd/version.go b/soda/cmd/version.go index 119bcf37..3e207cb3 100644 --- a/soda/cmd/version.go +++ b/soda/cmd/version.go @@ -1,4 +1,4 @@ package cmd // Version defines the current Pop version. -const Version = "v5.1.1" +const Version = "v5.1.2"