From 4f625671e6ce1737f0037800f98552b2995ee957 Mon Sep 17 00:00:00 2001 From: Larry Morales Jordan Date: Mon, 16 Apr 2018 13:28:18 -0500 Subject: [PATCH] Pop Eager Creation (#14) * add slices and array support for create, update, save, validateAndCreate, validateAndSave, validateAndUpdate * change return reference on Eager function from *Query To *Connection * add has_many eager creation feature * add has one association support for eager creation * add belongs_to eager creation feature * many to many association support * add SQL translator to accomplish dialect params * extract skipped function into skipable struct * add AssociationCreatable interface to define associations that can be created * Update README.md * add exclude columns for eager model creation * Fix broken test * Update README.md * add support for UUID generation before create a many to many relationship * [Fix] allow eager creation when there are not associations defined for model * remove assoc package dependency * add owner eligible for creation flag to belongs_to associations * Replace design for eager creation by using Before and After type associations * Fix Broken Test * Use store instead of TX when it has a nil reference * clear slices when eager loading from many_to_many and has_many associations * [Refactor] remove duplicated code for iterations over models * [Refactor] remove duplicated code and improve validations type * [Refactor] remove unused code --- README.md | 122 ++++++- associations/association.go | 95 ++++- associations/belongs_to_association.go | 45 ++- associations/has_many_association.go | 62 +++- associations/has_many_association_test.go | 31 +- associations/has_one_association.go | 43 ++- associations/has_one_association_test.go | 23 +- associations/many_to_many_association.go | 83 ++++- connection.go | 12 +- executors.go | 150 ++++---- executors_eager.go | 130 +++++++ executors_test.go | 341 ++++++++++++++++++ finders.go | 16 +- .../20160808213310_setup_tests2.down.fizz | 2 + .../20160808213310_setup_tests2.up.fizz | 9 + model.go | 20 + nulls/nulls.go | 51 +++ pop_test.go | 31 ++ query.go | 27 +- validations.go | 92 +++-- 20 files changed, 1204 insertions(+), 181 deletions(-) create mode 100644 executors_eager.go create mode 100644 nulls/nulls.go diff --git a/README.md b/README.md index 33206754..ff5ee853 100644 --- a/README.md +++ b/README.md @@ -261,7 +261,7 @@ user := models.User{} err := tx.Find(&user, id) ``` -#### Query +#### Query All ```go tx := models.DB query := tx.Where("id = 1").Where("name = 'Mark'") @@ -289,6 +289,80 @@ sql, args := query.ToSQL(&pop.Model{Value: models.UserRole{}}, "user_roles.*", err := models.DB.RawQuery(sql, args...).All(&roles) ``` +#### Create +```go +// Create one record. +user := models.User{} +user.Name = "Mark" +err := tx.Create(&user) + +// Create many records. +users := models.Users{ + {Name:"Mark"}, + {Name: "Larry"}, +} + +err := tx.Create(&users) +``` + +#### Save +```go +// Save one record. +user := models.User{} +user.Name = "Mark" +err := tx.Save(&user) + +// Save many records. +users := models.Users{ + {Name:"Mark"}, + {Name: "Larry"}, +} + +err := tx.Save(&users) +``` + +#### Update +```go +// Update one record. +user := models.User{} +user.Name = "Mark" +err := tx.Create(&user) + +user.Name = "Mark Bates" +err = tx.Update(&user) + +// Update many records. +users := models.Users{ + {Name:"Mark"}, + {Name: "Larry"}, +} + +err := tx.Create(&users) + +users[0].Name = "Mark Bates" +users[1].Name = "Larry Morales" +err := tx.Update(&users) +``` + +#### Destroy +```go +// Destroy one record. +user := models.User{} +user.Name = "Mark" +err := tx.Create(&user) + +err = tx.Destroy(&user) + +// Destroy many records. +users := models.Users{ + {Name:"Mark"}, + {Name: "Larry"}, +} +err := tx.Create(&users) + +err = tx.Destroy(&users) +``` + ### Eager Loading **pop** allows you to perform an eager loading for associations defined in a model. By using `pop.Connection.Eager()` function plus some fields tags predefined in your model you can extract associated data from a model. @@ -375,6 +449,52 @@ tx.Eager("Books.Writers.Book").First(&u) // will load all Books for u and for e tx.Eager("Books.Writers").Eager("FavoriteSong").First(&u) // will load all Books for u and for every Book will load all Writers. And Also it will load the favorite song for user. ``` +#### Eager Creation +pop allows you to eager create models and their associations in just one simple statement, you don't need to create every association separately anymore. + +```go +user := User{ + Name: "Mark Bates", + Books: Books{{Title: "Pop Book", Description: "Pop Book", Isbn: "PB1"}}, + FavoriteSong: Song{Title: "Don't know the title"}, + Houses: Addresses{ + Address{HouseNumber: 1, Street: "Golang"}, + }, +} +``` + +```go +err := tx.Eager().Create(&user) +``` + +The above sentence will do this: + +1. It will notice `Books` is a `has_many` association and it will realize that to actually store every book it will need to get the `User ID` first. So, it proceeds to store first `User` data so it can retrieve an **ID** and then use that ID to fill `UserID` field in every `Book` in `Books`. Later it stores all books in database. + +2. `FavoriteSong` is a `has_one` association and it uses same logic described in `has_many` association. Since `User` data was previously saved before creating all books, it already knows that `User` got an `ID` so it fills its `UserID` field with that value and `FavoriteSong` is then stored in database. + +3. `Houses` for this example is a `many_to_many` relationship and it will have to deal with two tables in this case: `users` and `addresses`. It will need to store all addresses first in `addresses` table before save them in the many to many table. Because `User` was already stored, it already have an `ID`. This is a special case to deal with, since this behavior is different to all other associations, it managed to solve it by let it implement the `AssociationCreatableStatement` interface, all other associations implement by default `AssociationCreatable` interface. + +For `belongs_to` association like shown in the example bellow, it will need first to create `User` to retrieve **ID** value and then fill its `UserID` field before be saved in database. + +```go +book := Book{ + Title: "Pop Book", + Description: "Pop Book", + Isbn: "PB1", + User: User{ + Name: nulls.NewString("Larry"), + }, +} +``` + +```go +tx.Eager().Create(&book) +``` + +All these cases are assuming that none of models and associations has previously been saved in database. + + #### Callbacks Pop provides a means to execute code before and after database operations. This is done by defining specific methods on your models. For diff --git a/associations/association.go b/associations/association.go index e1cc2f91..b65a5bd9 100644 --- a/associations/association.go +++ b/associations/association.go @@ -4,6 +4,7 @@ import ( "reflect" "github.com/gobuffalo/pop/columns" + "github.com/gobuffalo/pop/nulls" ) // Association represents a definition of a model association @@ -14,6 +15,17 @@ type Association interface { Interface() interface{} Constraint() (string, []interface{}) InnerAssociations() InnerAssociations + Skipped() bool +} + +// associationSkipable is a helper struct that helps +// to include skippable behavior in associations. +type associationSkipable struct { + skipped bool +} + +func (a *associationSkipable) Skipped() bool { + return a.skipped } // associationComposite adds the ability for a Association to @@ -43,12 +55,70 @@ type AssociationSortable interface { Association } +type AssociationBeforeCreatable interface { + BeforeInterface() interface{} + BeforeSetup() error + Association +} + +type AssociationAfterCreatable interface { + AfterInterface() interface{} + AfterSetup() error + Association +} + +// AssociationCreatableStatement a association that defines +// create statements on database. +type AssociationCreatableStatement interface { + Statements() []AssociationStatement + Association +} + +// AssociationStatement a type that represents a statement to be +// executed. +type AssociationStatement struct { + Statement string + Args []interface{} +} + // Associations a group of model associations. type Associations []Association -// SkippedAssociation an empty association used to indicate -// an association should not be queried. -var SkippedAssociation = (Association)(nil) +// AssociationsBeforeCreatable returns all associations that implement AssociationBeforeCreatable +// interface. Belongs To association is an example of this implementation. +func (a Associations) AssociationsBeforeCreatable() []AssociationBeforeCreatable { + before := []AssociationBeforeCreatable{} + for i := range a { + if _, ok := a[i].(AssociationBeforeCreatable); ok { + before = append(before, a[i].(AssociationBeforeCreatable)) + } + } + return before +} + +// AssociationsAfterCreatable returns all associations that implement AssociationAfterCreatable +// interface. Has Many and Has One associations are example of this implementation. +func (a Associations) AssociationsAfterCreatable() []AssociationAfterCreatable { + after := []AssociationAfterCreatable{} + for i := range a { + if _, ok := a[i].(AssociationAfterCreatable); ok { + after = append(after, a[i].(AssociationAfterCreatable)) + } + } + return after +} + +// AssociationsCreatableStatement returns all associations that implement AssociationCreatableStament +// interface. Many To Many association is an example of this implementation. +func (a Associations) AssociationsCreatableStatement() []AssociationCreatableStatement { + stm := []AssociationCreatableStatement{} + for i := range a { + if _, ok := a[i].(AssociationCreatableStatement); ok { + stm = append(stm, a[i].(AssociationCreatableStatement)) + } + } + return stm +} // associationParams a wrapper for associations definition // and creation. @@ -65,22 +135,17 @@ type associationParams struct { // see the builder defined in ./has_many_association.go as a guide of how to use it. type associationBuilder func(associationParams) (Association, error) -// nullable means this type is a nullable association field. -type nullable interface { - Interface() interface{} -} - // fieldIsNil validates if a field has a nil reference. Also // it validates if a field implements nullable interface and // it has a nil value. func fieldIsNil(f reflect.Value) bool { - null := (*nullable)(nil) - t := reflect.TypeOf(f.Interface()) - if t.Implements(reflect.TypeOf(null).Elem()) { - m := reflect.ValueOf(f.Interface()).MethodByName("Interface") - out := m.Call([]reflect.Value{}) - idValue := out[0].Interface() - return idValue == nil + if n := nulls.New(f.Interface()); n != nil { + return n.Interface() == nil } return f.Interface() == nil } + +func isZero(i interface{}) bool { + v := reflect.ValueOf(i) + return v.Interface() == reflect.Zero(v.Type()).Interface() +} diff --git a/associations/belongs_to_association.go b/associations/belongs_to_association.go index 265ac708..3756b6db 100644 --- a/associations/belongs_to_association.go +++ b/associations/belongs_to_association.go @@ -3,6 +3,8 @@ package associations import ( "fmt" "reflect" + + "github.com/gobuffalo/pop/nulls" ) // belongsToAssociation is the implementation for the belongs_to @@ -11,7 +13,8 @@ type belongsToAssociation struct { ownerModel reflect.Value ownerType reflect.Type ownerID reflect.Value - owner interface{} + ownedModel interface{} + *associationSkipable *associationComposite } @@ -28,16 +31,20 @@ func belongsToAssociationBuilder(p associationParams) (Association, error) { } // Validates if ownerIDField is nil, this association will be skipped. + var skipped bool f := p.modelValue.FieldByName(ownerIDField) - if fieldIsNil(f) { - return SkippedAssociation, nil + if fieldIsNil(f) || isZero(f.Interface()) { + skipped = true } return &belongsToAssociation{ - ownerModel: fval, - ownerType: fval.Type(), - ownerID: f, - owner: p.model, + ownerModel: fval, + ownerType: fval.Type(), + ownerID: f, + ownedModel: p.model, + associationSkipable: &associationSkipable{ + skipped: skipped, + }, associationComposite: &associationComposite{innerAssociations: p.innerAssociations}, }, nil } @@ -63,3 +70,27 @@ func (b *belongsToAssociation) Interface() interface{} { func (b *belongsToAssociation) Constraint() (string, []interface{}) { return "id = ?", []interface{}{b.ownerID.Interface()} } + +func (b *belongsToAssociation) BeforeInterface() interface{} { + if !b.skipped { + return nil + } + + if b.ownerModel.Kind() == reflect.Ptr { + return b.ownerModel.Interface() + } + return b.ownerModel.Addr().Interface() +} + +func (b *belongsToAssociation) BeforeSetup() error { + ownerID := reflect.Indirect(reflect.ValueOf(b.ownerModel.Interface())).FieldByName("ID").Interface() + if b.ownerID.CanSet() { + if n := nulls.New(b.ownerID.Interface()); n != nil { + b.ownerID.Set(reflect.ValueOf(n.Parse(ownerID))) + } else { + b.ownerID.Set(reflect.ValueOf(ownerID)) + } + return nil + } + return fmt.Errorf("could not set '%s' to '%s'", ownerID, b.ownerID) +} diff --git a/associations/has_many_association.go b/associations/has_many_association.go index 3709fd15..6e006bc2 100644 --- a/associations/has_many_association.go +++ b/associations/has_many_association.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" + "github.com/gobuffalo/pop/nulls" "github.com/markbates/inflect" ) @@ -15,8 +16,10 @@ type hasManyAssociation struct { value reflect.Value ownerName string ownerID interface{} + owner interface{} fkID string orderBy string + *associationSkipable *associationComposite } @@ -26,19 +29,24 @@ func init() { func hasManyAssociationBuilder(p associationParams) (Association, error) { // Validates if ownerID is nil, this association will be skipped. + var skipped bool ownerID := p.modelValue.FieldByName("ID") if fieldIsNil(ownerID) { - return SkippedAssociation, nil + skipped = true } return &hasManyAssociation{ - tableName: p.popTags.Find("has_many").Value, - field: p.field, - value: p.modelValue.FieldByName(p.field.Name), - ownerName: p.modelType.Name(), - ownerID: ownerID.Interface(), - fkID: p.popTags.Find("fk_id").Value, - orderBy: p.popTags.Find("order_by").Value, + owner: p.model, + tableName: p.popTags.Find("has_many").Value, + field: p.field, + value: p.modelValue.FieldByName(p.field.Name), + ownerName: p.modelType.Name(), + ownerID: ownerID.Interface(), + fkID: p.popTags.Find("fk_id").Value, + orderBy: p.popTags.Find("order_by").Value, + associationSkipable: &associationSkipable{ + skipped: skipped, + }, associationComposite: &associationComposite{innerAssociations: p.innerAssociations}, }, nil } @@ -56,6 +64,14 @@ func (a *hasManyAssociation) Interface() interface{} { a.value.Set(val) return a.value.Interface() } + + // This piece of code clears a slice in case it is filled with elements. + if a.value.Kind() == reflect.Slice || a.value.Kind() == reflect.Array { + valPointer := a.value.Addr() + valPointer.Elem().Set(reflect.MakeSlice(valPointer.Type().Elem(), 0, valPointer.Elem().Cap())) + return valPointer.Interface() + } + return a.value.Addr().Interface() } @@ -73,3 +89,33 @@ func (a *hasManyAssociation) Constraint() (string, []interface{}) { func (a *hasManyAssociation) OrderBy() string { return a.orderBy } + +func (a *hasManyAssociation) AfterInterface() interface{} { + if a.value.Kind() == reflect.Ptr { + return a.value.Interface() + } + return a.value.Addr().Interface() +} + +func (a *hasManyAssociation) AfterSetup() error { + ownerID := reflect.Indirect(reflect.ValueOf(a.owner)).FieldByName("ID").Interface() + + v := a.value + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + for i := 0; i < v.Len(); i++ { + fval := v.Index(i).FieldByName(a.ownerName + "ID") + if fval.CanSet() { + if n := nulls.New(fval.Interface()); n != nil { + fval.Set(reflect.ValueOf(n.Parse(ownerID))) + } else { + fval.Set(reflect.ValueOf(ownerID)) + } + } else { + return fmt.Errorf("could not set '%s' in '%s'", ownerID, fval) + } + } + return nil +} diff --git a/associations/has_many_association_test.go b/associations/has_many_association_test.go index 567f2126..e0ff4ab7 100644 --- a/associations/has_many_association_test.go +++ b/associations/has_many_association_test.go @@ -5,17 +5,18 @@ import ( "testing" "github.com/gobuffalo/pop/associations" - "github.com/gobuffalo/uuid" + "github.com/gobuffalo/pop/nulls" "github.com/stretchr/testify/require" ) -type fooHasMany struct { - ID uuid.UUID `db:"id"` - BarHasManies barHasManies `has_many:"bar_has_manies"` +type FooHasMany struct { + ID int `db:"id"` + BarHasManies *barHasManies `has_many:"bar_has_manies"` } type barHasMany struct { - FooHasManyID uuid.UUID `db:"foo_has_many_id"` + Title string `db:"title"` + FooHasManyID nulls.Int `db:"foo_has_many_id"` } type barHasManies []barHasMany @@ -23,8 +24,8 @@ type barHasManies []barHasMany func Test_Has_Many_Association(t *testing.T) { a := require.New(t) - id, _ := uuid.NewV1() - foo := fooHasMany{ID: id} + id := 1 + foo := FooHasMany{ID: 1} as, err := associations.AssociationsForStruct(&foo) @@ -34,5 +35,19 @@ func Test_Has_Many_Association(t *testing.T) { where, args := as[0].Constraint() a.Equal("foo_has_many_id = ?", where) - a.Equal(id, args[0].(uuid.UUID)) + a.Equal(id, args[0].(int)) +} + +func Test_Has_Many_SetValue(t *testing.T) { + a := require.New(t) + foo := FooHasMany{ID: 1, BarHasManies: &barHasManies{{Title: "bar"}}} + + as, _ := associations.AssociationsForStruct(&foo) + a.Equal(len(as), 1) + + ca, ok := as[0].(associations.AssociationAfterCreatable) + a.True(ok) + + ca.AfterSetup() + a.Equal(foo.ID, (*foo.BarHasManies)[0].FooHasManyID.Interface().(int)) } diff --git a/associations/has_one_association.go b/associations/has_one_association.go index ae9f11b3..e3053cc9 100644 --- a/associations/has_one_association.go +++ b/associations/has_one_association.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" + "github.com/gobuffalo/pop/nulls" "github.com/markbates/inflect" ) @@ -14,6 +15,7 @@ type hasOneAssociation struct { ownerName string owner interface{} fkID string + *associationSkipable *associationComposite } @@ -23,19 +25,23 @@ func init() { func hasOneAssociationBuilder(p associationParams) (Association, error) { // Validates if ownerIDField is nil, this association will be skipped. + var skipped bool ownerID := p.modelValue.FieldByName("ID") if fieldIsNil(ownerID) { - return SkippedAssociation, nil + skipped = true } fval := p.modelValue.FieldByName(p.field.Name) return &hasOneAssociation{ - owner: p.model, - ownedModel: fval, - ownedType: fval.Type(), - ownerID: ownerID.Interface(), - ownerName: p.modelType.Name(), - fkID: p.popTags.Find("fk_id").Value, + owner: p.model, + ownedModel: fval, + ownedType: fval.Type(), + ownerID: ownerID.Interface(), + ownerName: p.modelType.Name(), + fkID: p.popTags.Find("fk_id").Value, + associationSkipable: &associationSkipable{ + skipped: skipped, + }, associationComposite: &associationComposite{innerAssociations: p.innerAssociations}, }, nil } @@ -64,3 +70,26 @@ func (h *hasOneAssociation) Constraint() (string, []interface{}) { return condition, []interface{}{h.ownerID} } + +func (h *hasOneAssociation) AfterInterface() interface{} { + if h.ownedModel.Kind() == reflect.Ptr { + return h.ownedModel.Interface() + } + return h.ownedModel.Addr().Interface() +} + +func (h *hasOneAssociation) AfterSetup() error { + ownerID := reflect.Indirect(reflect.ValueOf(h.owner)).FieldByName("ID").Interface() + + fval := h.ownedModel.FieldByName(h.ownerName + "ID") + if fval.CanSet() { + if n := nulls.New(fval.Interface()); n != nil { + fval.Set(reflect.ValueOf(n.Parse(ownerID))) + } else { + fval.Set(reflect.ValueOf(ownerID)) + } + return nil + } + + return fmt.Errorf("could not set '%s' to '%s'", ownerID, fval) +} diff --git a/associations/has_one_association_test.go b/associations/has_one_association_test.go index 19ff8f4d..feaede51 100644 --- a/associations/has_one_association_test.go +++ b/associations/has_one_association_test.go @@ -5,24 +5,26 @@ import ( "testing" "github.com/gobuffalo/pop/associations" + "github.com/gobuffalo/pop/nulls" "github.com/gobuffalo/uuid" "github.com/stretchr/testify/require" ) -type fooHasOne struct { +type FooHasOne struct { ID uuid.UUID `db:"id"` BarHasOne barHasOne `has_one:"barHasOne"` } type barHasOne struct { - FooHasOneID uuid.UUID `db:"foo_has_one_id"` + Title string `db:"title"` + FooHasOneID nulls.UUID `db:"foo_has_one_id"` } func Test_Has_One_Association(t *testing.T) { a := require.New(t) id, _ := uuid.NewV1() - foo := fooHasOne{ID: id} + foo := FooHasOne{ID: id} as, err := associations.AssociationsForStruct(&foo) @@ -34,3 +36,18 @@ func Test_Has_One_Association(t *testing.T) { a.Equal("foo_has_one_id = ?", where) a.Equal(id, args[0].(uuid.UUID)) } + +func Test_Has_One_SetValue(t *testing.T) { + a := require.New(t) + id, _ := uuid.NewV1() + foo := FooHasOne{ID: id, BarHasOne: barHasOne{Title: "bar"}} + + as, _ := associations.AssociationsForStruct(&foo) + a.Equal(len(as), 1) + + ca, ok := as[0].(associations.AssociationAfterCreatable) + a.True(ok) + + ca.AfterSetup() + a.Equal(foo.ID, foo.BarHasOne.FooHasOneID.Interface().(uuid.UUID)) +} diff --git a/associations/many_to_many_association.go b/associations/many_to_many_association.go index 29e9407b..60aa0290 100644 --- a/associations/many_to_many_association.go +++ b/associations/many_to_many_association.go @@ -3,7 +3,9 @@ package associations import ( "fmt" "reflect" + "time" + "github.com/gobuffalo/uuid" "github.com/markbates/inflect" ) @@ -15,25 +17,30 @@ type manyToManyAssociation struct { owner interface{} fkID string orderBy string + *associationSkipable *associationComposite } func init() { associationBuilders["many_to_many"] = func(p associationParams) (Association, error) { // Validates if model.ID is nil, this association will be skipped. + var skipped bool model := p.modelValue if fieldIsNil(model.FieldByName("ID")) { - return SkippedAssociation, nil + skipped = true } return &manyToManyAssociation{ - fieldType: p.modelValue.FieldByName(p.field.Name).Type(), - fieldValue: p.modelValue.FieldByName(p.field.Name), - owner: p.model, - model: model, - manyToManyTableName: p.popTags.Find("many_to_many").Value, - fkID: p.popTags.Find("fk_id").Value, - orderBy: p.popTags.Find("order_by").Value, + fieldType: p.modelValue.FieldByName(p.field.Name).Type(), + fieldValue: p.modelValue.FieldByName(p.field.Name), + owner: p.model, + model: model, + manyToManyTableName: p.popTags.Find("many_to_many").Value, + fkID: p.popTags.Find("fk_id").Value, + orderBy: p.popTags.Find("order_by").Value, + associationSkipable: &associationSkipable{ + skipped: skipped, + }, associationComposite: &associationComposite{innerAssociations: p.innerAssociations}, }, nil } @@ -44,11 +51,19 @@ func (m *manyToManyAssociation) Kind() reflect.Kind { } func (m *manyToManyAssociation) Interface() interface{} { + val := reflect.New(m.fieldType.Elem()) if m.fieldValue.Kind() == reflect.Ptr { - val := reflect.New(m.fieldType.Elem()) m.fieldValue.Set(val) return m.fieldValue.Interface() } + + // This piece of code clears a slice in case it is filled with elements. + if m.fieldValue.Kind() == reflect.Slice || m.fieldValue.Kind() == reflect.Array { + valPointer := m.fieldValue.Addr() + valPointer.Elem().Set(reflect.MakeSlice(valPointer.Type().Elem(), 0, valPointer.Elem().Cap())) + return valPointer.Interface() + } + return m.fieldValue.Addr().Interface() } @@ -79,3 +94,53 @@ func (m *manyToManyAssociation) Constraint() (string, []interface{}) { func (m *manyToManyAssociation) OrderBy() string { return m.orderBy } + +func (m *manyToManyAssociation) BeforeInterface() interface{} { + if m.fieldValue.Kind() == reflect.Ptr { + return m.fieldValue.Interface() + } + return m.fieldValue.Addr().Interface() +} + +func (m *manyToManyAssociation) BeforeSetup() error { + return nil +} + +func (m *manyToManyAssociation) Statements() []AssociationStatement { + statements := []AssociationStatement{} + + modelColumnID := fmt.Sprintf("%s%s", inflect.Underscore(m.model.Type().Name()), "_id") + var columnFieldID string + i := reflect.Indirect(m.fieldValue) + if i.Kind() == reflect.Slice || i.Kind() == reflect.Array { + t := i.Type().Elem() + columnFieldID = fmt.Sprintf("%s%s", inflect.Underscore(t.Name()), "_id") + } else { + columnFieldID = fmt.Sprintf("%s%s", inflect.Underscore(i.Type().Name()), "_id") + } + + for i := 0; i < m.fieldValue.Len(); i++ { + v := m.fieldValue.Index(i) + manyIDValue := v.FieldByName("ID").Interface() + modelIDValue := m.model.FieldByName("ID").Interface() + stm := "INSERT INTO %s (%s,%s,%s,%s) VALUES(?,?,?,?)" + + associationStm := AssociationStatement{ + Statement: fmt.Sprintf(stm, m.manyToManyTableName, modelColumnID, columnFieldID, "created_at", "updated_at"), + Args: []interface{}{modelIDValue, manyIDValue, time.Now(), time.Now()}, + } + + if m.model.FieldByName("ID").Type().Name() == "UUID" { + stm = "INSERT INTO %s (%s,%s,%s,%s,%s) VALUES(?,?,?,?,?)" + id, _ := uuid.NewV4() + associationStm = AssociationStatement{ + Statement: fmt.Sprintf(stm, m.manyToManyTableName, "id", modelColumnID, columnFieldID, "created_at", "updated_at"), + Args: []interface{}{id, modelIDValue, manyIDValue, time.Now(), time.Now()}, + } + } + + statements = append(statements, associationStm) + } + + return statements +} diff --git a/connection.go b/connection.go index 33baebaa..1b1f0e14 100644 --- a/connection.go +++ b/connection.go @@ -16,11 +16,13 @@ var Connections = map[string]*Connection{} // Connection represents all of the necessary details for // talking with a datastore type Connection struct { - ID string - Store store - Dialect dialect - Elapsed int64 - TX *Tx + ID string + Store store + Dialect dialect + Elapsed int64 + TX *Tx + eager bool + eagerFields []string } func (c *Connection) String() string { diff --git a/executors.go b/executors.go index b5d7e8ef..e2005edd 100644 --- a/executors.go +++ b/executors.go @@ -11,7 +11,9 @@ import ( // Reload fetch fresh data for a given model, using its ID func (c *Connection) Reload(model interface{}) error { sm := Model{Value: model} - return c.Find(model, sm.ID()) + return sm.iterate(func(m *Model) error { + return c.Find(m.Value, m.ID()) + }) } // Exec runs the given query @@ -59,17 +61,22 @@ var emptyUUID = uuid.Nil.String() // or issues an Update otherwise. func (c *Connection) Save(model interface{}, excludeColumns ...string) error { sm := &Model{Value: model} - id := sm.ID() - - if fmt.Sprint(id) == "0" || fmt.Sprint(id) == emptyUUID { - return c.Create(model, excludeColumns...) - } - return c.Update(model, excludeColumns...) + return sm.iterate(func(m *Model) error { + id := m.ID() + if fmt.Sprint(id) == "0" || fmt.Sprint(id) == emptyUUID { + return c.Create(m.Value, excludeColumns...) + } + return c.Update(m.Value, excludeColumns...) + }) } // ValidateAndCreate applies validation rules on the given entry, then creates it // if the validation succeed, excluding the given columns. func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...string) (*validate.Errors, error) { + if c.eager { + return c.eagerValidateAndCreate(model, excludeColumns...) + } + sm := &Model{Value: model} verrs, err := sm.validateCreate(c) if err != nil { @@ -84,33 +91,41 @@ func (c *Connection) ValidateAndCreate(model interface{}, excludeColumns ...stri // Create add a new given entry to the database, excluding the given columns. // It updates `created_at` and `updated_at` columns automatically. func (c *Connection) Create(model interface{}, excludeColumns ...string) error { - return c.timeFunc("Create", func() error { - var err error - sm := &Model{Value: model} + if c.eager { + return c.eagerCreate(model, excludeColumns...) + } - if err = sm.beforeSave(c); err != nil { - return err - } + sm := &Model{Value: model} + return sm.iterate(func(m *Model) error { + return c.timeFunc("Create", func() error { + var err error + if err = m.beforeSave(c); err != nil { + return err + } - if err = sm.beforeCreate(c); err != nil { - return err - } + if err = m.beforeCreate(c); err != nil { + return err + } - cols := columns.ColumnsForStructWithAlias(model, sm.TableName(), sm.As) - cols.Remove(excludeColumns...) + cols := columns.ColumnsForStructWithAlias(m.Value, m.TableName(), m.As) - sm.touchCreatedAt() - sm.touchUpdatedAt() + if sm.TableName() == m.TableName() { + cols.Remove(excludeColumns...) + } - if err = c.Dialect.Create(c.Store, sm, cols); err != nil { - return err - } + m.touchCreatedAt() + m.touchUpdatedAt() - if err = sm.afterCreate(c); err != nil { - return err - } + if err = c.Dialect.Create(c.Store, m, cols); err != nil { + return err + } + + if err = m.afterCreate(c); err != nil { + return err + } - return sm.afterSave(c) + return m.afterSave(c) + }) }) } @@ -131,47 +146,54 @@ func (c *Connection) ValidateAndUpdate(model interface{}, excludeColumns ...stri // Update writes changes from an entry to the database, excluding the given columns. // It updates the `updated_at` column automatically. func (c *Connection) Update(model interface{}, excludeColumns ...string) error { - return c.timeFunc("Update", func() error { - var err error - sm := &Model{Value: model} - - if err = sm.beforeSave(c); err != nil { - return err - } - if err = sm.beforeUpdate(c); err != nil { - return err - } - - cols := columns.ColumnsForStructWithAlias(model, sm.TableName(), sm.As) - cols.Remove("id", "created_at") - cols.Remove(excludeColumns...) - - sm.touchUpdatedAt() - - if err = c.Dialect.Update(c.Store, sm, cols); err != nil { - return err - } - if err = sm.afterUpdate(c); err != nil { - return err - } - - return sm.afterSave(c) + sm := &Model{Value: model} + return sm.iterate(func(m *Model) error { + return c.timeFunc("Update", func() error { + var err error + + if err = m.beforeSave(c); err != nil { + return err + } + if err = m.beforeUpdate(c); err != nil { + return err + } + + cols := columns.ColumnsForStructWithAlias(model, m.TableName(), m.As) + cols.Remove("id", "created_at") + + if m.TableName() == sm.TableName() { + cols.Remove(excludeColumns...) + } + + m.touchUpdatedAt() + + if err = c.Dialect.Update(c.Store, m, cols); err != nil { + return err + } + if err = m.afterUpdate(c); err != nil { + return err + } + + return m.afterSave(c) + }) }) } // Destroy deletes a given entry from the database func (c *Connection) Destroy(model interface{}) error { - return c.timeFunc("Destroy", func() error { - var err error - sm := &Model{Value: model} - - if err = sm.beforeDestroy(c); err != nil { - return err - } - if err = c.Dialect.Destroy(c.Store, sm); err != nil { - return err - } - - return sm.afterDestroy(c) + sm := &Model{Value: model} + return sm.iterate(func(m *Model) error { + return c.timeFunc("Destroy", func() error { + var err error + + if err = m.beforeDestroy(c); err != nil { + return err + } + if err = c.Dialect.Destroy(c.Store, m); err != nil { + return err + } + + return m.afterDestroy(c) + }) }) } diff --git a/executors_eager.go b/executors_eager.go new file mode 100644 index 00000000..4452f79a --- /dev/null +++ b/executors_eager.go @@ -0,0 +1,130 @@ +package pop + +import ( + "github.com/gobuffalo/pop/associations" + "github.com/gobuffalo/validate" +) + +func (c *Connection) eagerCreate(model interface{}, excludeColumns ...string) error { + asos, err := associations.AssociationsForStruct(model, c.eagerFields...) + if err != nil { + return err + } + + c.eager = false + + if len(asos) == 0 { + return c.Create(model, excludeColumns...) + } + + before := asos.AssociationsBeforeCreatable() + for index := range before { + i := before[index].BeforeInterface() + if i == nil { + continue + } + + err = c.Create(i) + if err != nil { + return err + } + + err = before[index].BeforeSetup() + if err != nil { + return err + } + } + + err = c.Create(model, excludeColumns...) + if err != nil { + return err + } + + after := asos.AssociationsAfterCreatable() + for index := range after { + err = after[index].AfterSetup() + if err != nil { + return err + } + + i := after[index].AfterInterface() + if i == nil { + continue + } + + err = c.Create(i) + if err != nil { + return err + } + } + + stms := asos.AssociationsCreatableStatement() + for index := range stms { + statements := stms[index].Statements() + for _, stm := range statements { + if c.TX != nil { + _, err := c.TX.Exec(c.Dialect.TranslateSQL(stm.Statement), stm.Args...) + if err != nil { + return err + } + continue + } + _, err = c.Store.Exec(c.Dialect.TranslateSQL(stm.Statement), stm.Args...) + if err != nil { + return err + } + } + } + + return err +} + +func (c *Connection) eagerValidateAndCreate(model interface{}, excludeColumns ...string) (*validate.Errors, error) { + asos, err := associations.AssociationsForStruct(model, c.eagerFields...) + verrs := validate.NewErrors() + + if err != nil { + return verrs, err + } + + if len(asos) == 0 { + c.eager = false + return c.ValidateAndCreate(model, excludeColumns...) + } + + before := asos.AssociationsBeforeCreatable() + for index := range before { + i := before[index].BeforeInterface() + if i == nil { + continue + } + + sm := &Model{Value: i} + verrs, err := sm.validateCreate(c) + if err != nil || verrs.HasAny() { + return verrs, err + } + } + + after := asos.AssociationsAfterCreatable() + for index := range after { + i := after[index].AfterInterface() + if i == nil { + continue + } + + sm := &Model{Value: i} + verrs, err := sm.validateCreate(c) + if err != nil || verrs.HasAny() { + return verrs, err + } + } + + sm := &Model{Value: model} + verrs, err = sm.validateCreate(c) + if err != nil || verrs.HasAny() { + return verrs, err + } + + return verrs, c.eagerCreate(model, excludeColumns...) +} diff --git a/executors_test.go b/executors_test.go index 07ea4062..b30d96db 100644 --- a/executors_test.go +++ b/executors_test.go @@ -40,6 +40,49 @@ func Test_ValidateAndSave(t *testing.T) { }) } +func Test_ValidateAndSave_With_Slice(t *testing.T) { + r := require.New(t) + validationLogs = []string{} + transaction(func(tx *pop.Connection) { + car := []ValidatableCar{ + {Name: "VW"}, + {Name: "AU"}, + } + verrs, err := tx.ValidateAndSave(&car) + r.NoError(err) + r.False(verrs.HasAny()) + r.Len(validationLogs, 4) + r.Equal([]string{"Validate", "ValidateSave", "Validate", "ValidateSave"}, validationLogs) + + r.NotZero(car[0].ID) + r.NotZero(car[0].CreatedAt) + r.NotZero(car[1].ID) + r.NotZero(car[1].CreatedAt) + + validationLogs = []string{} + car = []ValidatableCar{ + {Name: ""}, + {Name: "AU"}, + } + verrs, err = tx.ValidateAndSave(&car) + r.NoError(err) + r.True(verrs.HasAny()) + r.Len(validationLogs, 2) + errs := verrs.Get("name") + r.Len(errs, 1) + + validationLogs = []string{} + ncar := []NotValidatableCar{ + {Name: ""}, + {Name: "AU"}, + } + verrs, err = tx.ValidateAndSave(&ncar) + r.NoError(err) + r.False(verrs.HasAny()) + r.Len(validationLogs, 0) + }) +} + func Test_ValidateAndCreate(t *testing.T) { r := require.New(t) validationLogs = []string{} @@ -71,6 +114,48 @@ func Test_ValidateAndCreate(t *testing.T) { }) } +func Test_ValidateAndCreate_With_Slice(t *testing.T) { + r := require.New(t) + validationLogs = []string{} + transaction(func(tx *pop.Connection) { + car := []ValidatableCar{ + {Name: "VW"}, + {Name: "AU"}, + } + verrs, err := tx.ValidateAndCreate(&car) + r.NoError(err) + r.False(verrs.HasAny()) + r.Len(validationLogs, 4) + r.Equal([]string{"Validate", "ValidateCreate", "Validate", "ValidateCreate"}, validationLogs) + r.NotZero(car[0].ID) + r.NotZero(car[0].CreatedAt) + r.NotZero(car[1].ID) + r.NotZero(car[1].CreatedAt) + + validationLogs = []string{} + car = []ValidatableCar{ + {Name: ""}, + {Name: "AU"}, + } + verrs, err = tx.ValidateAndSave(&car) + r.NoError(err) + r.True(verrs.HasAny()) + r.Len(validationLogs, 2) + errs := verrs.Get("name") + r.Len(errs, 1) + + validationLogs = []string{} + ncar := []NotValidatableCar{ + {Name: ""}, + {Name: "AU"}, + } + verrs, err = tx.ValidateAndCreate(ncar) + r.NoError(err) + r.False(verrs.HasAny()) + r.Len(validationLogs, 0) + }) +} + func Test_ValidateAndUpdate(t *testing.T) { r := require.New(t) validationLogs = []string{} @@ -109,6 +194,52 @@ func Test_ValidateAndUpdate(t *testing.T) { }) } +func Test_ValidateAndUpdate_With_Slice(t *testing.T) { + r := require.New(t) + validationLogs = []string{} + transaction(func(tx *pop.Connection) { + car := []ValidatableCar{ + {Name: "VW"}, + {Name: "AU"}, + } + verrs, err := tx.ValidateAndCreate(&car) + r.NoError(err) + r.False(verrs.HasAny()) + r.Len(validationLogs, 4) + r.Equal([]string{"Validate", "ValidateCreate", "Validate", "ValidateCreate"}, validationLogs) + r.NotZero(car[0].ID) + r.NotZero(car[0].CreatedAt) + r.NotZero(car[1].ID) + r.NotZero(car[1].CreatedAt) + + validationLogs = []string{} + car[0].Name = "" + verrs, err = tx.ValidateAndUpdate(&car) + r.NoError(err) + r.True(verrs.HasAny()) + r.Len(validationLogs, 2) + errs := verrs.Get("name") + r.Len(errs, 1) + + validationLogs = []string{} + ncar := []NotValidatableCar{ + {Name: ""}, + {Name: "AU"}, + } + verrs, err = tx.ValidateAndCreate(&ncar) + r.NoError(err) + r.False(verrs.HasAny()) + r.Len(validationLogs, 0) + + validationLogs = []string{} + ncar[1].Name = "" + verrs, err = tx.ValidateAndUpdate(&ncar) + r.NoError(err) + r.False(verrs.HasAny()) + r.Len(validationLogs, 0) + }) +} + func Test_Exec(t *testing.T) { transaction(func(tx *pop.Connection) { a := require.New(t) @@ -164,6 +295,27 @@ func Test_Save(t *testing.T) { }) } +func Test_Save_With_Slice(t *testing.T) { + r := require.New(t) + transaction(func(tx *pop.Connection) { + u := Users{ + {Name: nulls.NewString("Mark")}, + {Name: nulls.NewString("Larry")}, + } + r.Zero(u[0].ID) + r.Zero(u[1].ID) + + tx.Save(&u) + r.NotZero(u[0].ID) + r.NotZero(u[1].ID) + + uat := u[0].UpdatedAt.UnixNano() + + tx.Save(u) + r.NotEqual(uat, u[0].UpdatedAt.UnixNano()) + }) +} + func Test_Create(t *testing.T) { transaction(func(tx *pop.Connection) { a := require.New(t) @@ -185,6 +337,142 @@ func Test_Create(t *testing.T) { }) } +func Test_Create_With_Slice(t *testing.T) { + transaction(func(tx *pop.Connection) { + a := require.New(t) + + count, _ := tx.Count(&User{}) + users := Users{ + {Name: nulls.NewString("Mark Bates")}, + {Name: nulls.NewString("Larry M. Jordan")}, + {Name: nulls.NewString("Pop")}, + } + err := tx.Create(&users) + a.NoError(err) + + ctx, _ := tx.Count(&User{}) + a.Equal(count+3, ctx) + }) +} + +func Test_Eager_Create_Has_Many(t *testing.T) { + transaction(func(tx *pop.Connection) { + a := require.New(t) + count, _ := tx.Count(&User{}) + user := User{ + Name: nulls.NewString("Mark 'Awesome' Bates"), + Books: Books{{Title: "Pop Book", Description: "Pop Book", Isbn: "PB1"}}, + FavoriteSong: Song{Title: "Hook - Blues Traveler"}, + Houses: Addresses{ + Address{HouseNumber: 86, Street: "Modelo"}, + }, + } + + err := tx.Eager().Create(&user) + a.NoError(err) + a.NotEqual(user.ID, 0) + + ctx, _ := tx.Count(&User{}) + a.Equal(count+1, ctx) + + ctx, _ = tx.Count(&Book{}) + a.Equal(count+1, ctx) + + ctx, _ = tx.Count(&Song{}) + a.Equal(count+1, ctx) + + ctx, _ = tx.Count(&Address{}) + a.Equal(count+1, ctx) + + u := User{} + q := tx.Eager().Where("name = ?", "Mark 'Awesome' Bates") + err = q.First(&u) + a.NoError(err) + a.Equal(u.Name.String, "Mark 'Awesome' Bates") + a.Equal(u.Books[0].Title, "Pop Book") + a.Equal(u.FavoriteSong.Title, "Hook - Blues Traveler") + a.Equal(u.Houses[0].Street, "Modelo") + }) +} + +func Test_Eager_Validate_And_Create_Has_Many(t *testing.T) { + a := require.New(t) + transaction(func(tx *pop.Connection) { + user := User{ + Name: nulls.NewString("Mark 'Awesome' Bates"), + Books: Books{{Title: "Pop Book", Isbn: "PB1"}}, + FavoriteSong: Song{Title: "Hook - Blues Traveler"}, + Houses: Addresses{ + Address{HouseNumber: 86, Street: "Modelo"}, + }, + } + + verrs, err := tx.Eager().ValidateAndCreate(&user) + a.NoError(err) + ctx, _ := tx.Count(&User{}) + a.Zero(ctx) + a.Equal(1, verrs.Count()) // Missing Books.Description. + }) +} + +func Test_Eager_Validate_And_Create_Parental(t *testing.T) { + a := require.New(t) + transaction(func(tx *pop.Connection) { + user := User{ + Name: nulls.NewString(""), + Books: Books{{Title: "Pop Book", Isbn: "PB1", Description: "Awesome Book!"}}, + FavoriteSong: Song{Title: "Hook - Blues Traveler"}, + Houses: Addresses{ + Address{HouseNumber: 86, Street: "Modelo"}, + }, + } + + verrs, err := tx.Eager().ValidateAndCreate(&user) + a.NoError(err) + ctx, _ := tx.Count(&User{}) + a.Zero(ctx) + a.Equal(1, verrs.Count()) // Missing Books.Description. + }) +} + +func Test_Eager_Create_Belongs_To(t *testing.T) { + transaction(func(tx *pop.Connection) { + a := require.New(t) + book := Book{ + Title: "Pop Book", + Description: "Pop Book", + Isbn: "PB1", + User: User{ + Name: nulls.NewString("Larry"), + }, + } + + err := tx.Eager().Create(&book) + a.NoError(err) + + ctx, _ := tx.Count(&Book{}) + a.Equal(1, ctx) + + ctx, _ = tx.Count(&User{}) + a.Equal(1, ctx) + }) +} + +func Test_Eager_Creation_Without_Associations(t *testing.T) { + transaction(func(tx *pop.Connection) { + a := require.New(t) + code := CourseCode{ + Course: Course{}, + } + + err := tx.Eager().Create(&code) + a.NoError(err) + + ctx, _ := tx.Count(&CourseCode{}) + a.Equal(1, ctx) + }) +} + func Test_Create_UUID(t *testing.T) { transaction(func(tx *pop.Connection) { a := require.New(t) @@ -267,6 +555,34 @@ func Test_Update(t *testing.T) { }) } +func Test_Update_With_Slice(t *testing.T) { + transaction(func(tx *pop.Connection) { + a := require.New(t) + + user := Users{ + {Name: nulls.NewString("Mark")}, + {Name: nulls.NewString("Larry")}, + } + tx.Create(&user) + + a.NotZero(user[0].CreatedAt) + a.NotZero(user[0].UpdatedAt) + + a.NotZero(user[1].CreatedAt) + a.NotZero(user[1].UpdatedAt) + + user[0].Name.String = "Marky" + user[1].Name.String = "Lawrence" + + err := tx.Update(&user) + a.NoError(err) + + tx.Reload(&user) + a.Equal(user[0].Name.String, "Marky") + a.Equal(user[1].Name.String, "Lawrence") + }) +} + func Test_Update_UUID(t *testing.T) { transaction(func(tx *pop.Connection) { r := require.New(t) @@ -309,6 +625,31 @@ func Test_Destroy(t *testing.T) { }) } +func Test_Destroy_With_Slice(t *testing.T) { + transaction(func(tx *pop.Connection) { + a := require.New(t) + + count, err := tx.Count("users") + user := Users{ + {Name: nulls.NewString("Mark")}, + {Name: nulls.NewString("Larry")}, + } + err = tx.Create(&user) + a.NoError(err) + a.NotEqual(user[0].ID, 0) + a.NotEqual(user[1].ID, 0) + + ctx, err := tx.Count("users") + a.Equal(count+2, ctx) + + err = tx.Destroy(&user) + a.NoError(err) + + ctx, _ = tx.Count("users") + a.Equal(count, ctx) + }) +} + func Test_Destroy_UUID(t *testing.T) { transaction(func(tx *pop.Connection) { r := require.New(t) diff --git a/finders.go b/finders.go index e4298e92..c4d803aa 100644 --- a/finders.go +++ b/finders.go @@ -19,7 +19,8 @@ var rLimit = regexp.MustCompile("(?i)(limit [0-9]+)$") // // c.Find(&User{}, 1) func (c *Connection) Find(model interface{}, id interface{}) error { - return Q(c).Find(model, id) + q := Q(c) + return q.Find(model, id) } // Find the first record of the model in the database with a particular id. @@ -46,7 +47,8 @@ func (q *Query) Find(model interface{}, id interface{}) error { // // c.First(&User{}) func (c *Connection) First(model interface{}) error { - return Q(c).First(model) + q := Q(c) + return q.First(model) } // First record of the model in the database that matches the query. @@ -76,7 +78,8 @@ func (q *Query) First(model interface{}) error { // // c.Last(&User{}) func (c *Connection) Last(model interface{}) error { - return Q(c).Last(model) + q := Q(c) + return q.Last(model) } // Last record of the model in the database that matches the query. @@ -108,7 +111,8 @@ func (q *Query) Last(model interface{}) error { // // c.All(&[]User{}) func (c *Connection) All(models interface{}) error { - return Q(c).All(models) + q := Q(c) + return q.All(models) } // All retrieves all of the records in the database that match the query. @@ -182,11 +186,13 @@ func (q *Query) eagerAssociations(model interface{}) error { } for _, association := range assos { - if association == associations.SkippedAssociation { + if association.Skipped() { continue } query := Q(q.Connection) + query.eager = false + whereCondition, args := association.Constraint() query = query.Where(whereCondition, args...) diff --git a/migrations/20160808213310_setup_tests2.down.fizz b/migrations/20160808213310_setup_tests2.down.fizz index 490dfe92..cda31523 100644 --- a/migrations/20160808213310_setup_tests2.down.fizz +++ b/migrations/20160808213310_setup_tests2.down.fizz @@ -1,4 +1,6 @@ drop_table("songs") +drop_table("course_codes") +drop_table("courses") {{ if eq .Dialect "postgres" }} drop_table("cakes") {{ end }} diff --git a/migrations/20160808213310_setup_tests2.up.fizz b/migrations/20160808213310_setup_tests2.up.fizz index 02066022..613b2fd9 100644 --- a/migrations/20160808213310_setup_tests2.up.fizz +++ b/migrations/20160808213310_setup_tests2.up.fizz @@ -24,6 +24,15 @@ create_table("users_addresses", func(t) { t.Column("address_id", "int", {}) }) +create_table("courses", func(t) { + t.Column("id", "uuid", {"primary": true}) +}) + +create_table("course_codes", func(t) { + t.Column("id", "uuid", {"primary": true}) + t.Column("course_id", "uuid", {}) +}) + {{ if eq .Dialect "postgres" }} create_table("cakes", func(t) { t.Column("int_slice", "int[]", {"null": true}) diff --git a/model.go b/model.go index b1f9db6a..e81af67e 100644 --- a/model.go +++ b/model.go @@ -17,6 +17,8 @@ var tableMapMu = sync.RWMutex{} // Value is the contents of a `Model`. type Value interface{} +type modelIterable func(*Model) error + // Model is used throughout Pop to wrap the end user interface // that is passed in to many functions. type Model struct { @@ -162,3 +164,21 @@ func (m *Model) whereID() string { } return value } + +func (m *Model) iterate(fn modelIterable) error { + v := reflect.Indirect(reflect.ValueOf(m.Value)) + if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { + for i := 0; i < v.Len(); i++ { + val := v.Index(i) + newModel := &Model{Value: val.Addr().Interface()} + err := fn(newModel) + + if err != nil { + return err + } + } + return nil + } + + return fn(m) +} diff --git a/nulls/nulls.go b/nulls/nulls.go new file mode 100644 index 00000000..4f547cb8 --- /dev/null +++ b/nulls/nulls.go @@ -0,0 +1,51 @@ +package nulls + +import ( + "database/sql/driver" + + "github.com/gobuffalo/uuid" +) + +// nullable a generic representation of nulls type. +type nullable interface { + Interface() interface{} + Value() (driver.Value, error) +} + +// Nulls a generic nulls type. something that implements +// nullable interface. can be any of nulls.Int, nulls.uuid.UUID +// nulls.String, etc. +type Nulls struct { + Value interface{} +} + +// Interface calls Interface function for value. +func (nulls *Nulls) Interface() interface{} { + n := nulls.Value.(nullable) + return n.Interface() +} + +// Parse parses the specified value to the corresponding +// nullable type. value is one of the inner value hold +// by a nullable type. i.e int, string, uuid.UUID etc. +func (nulls *Nulls) Parse(value interface{}) interface{} { + switch nulls.Value.(type) { + case Int: + return NewInt(value.(int)) + case Int64: + return NewInt64(value.(int64)) + case UUID: + return NewUUID(value.(uuid.UUID)) + default: + return value + } +} + +// New returns a wrapper called nulls for the +// interface passed as a param. +func New(i interface{}) *Nulls { + if _, ok := i.(nullable); !ok { + return nil + } + return &Nulls{Value: i} +} diff --git a/pop_test.go b/pop_test.go index 6e62a9ac..642b1c13 100644 --- a/pop_test.go +++ b/pop_test.go @@ -84,6 +84,14 @@ type User struct { Houses Addresses `many_to_many:"users_addresses"` } +// Validate gets run every time you call a "pop.Validate*" (pop.ValidateAndSave, pop.ValidateAndCreate, pop.ValidateAndUpdate) method. +// This method is not required and may be deleted. +func (u *User) Validate(tx *pop.Connection) (*validate.Errors, error) { + return validate.Validate( + &validators.StringIsPresent{Field: u.Name.String, Name: "Name"}, + ), nil +} + type Users []User type Book struct { @@ -98,6 +106,14 @@ type Book struct { UpdatedAt time.Time `db:"updated_at"` } +// Validate gets run every time you call a "pop.Validate*" (pop.ValidateAndSave, pop.ValidateAndCreate, pop.ValidateAndUpdate) method. +// This method is not required and may be deleted. +func (b *Book) Validate(tx *pop.Connection) (*validate.Errors, error) { + return validate.Validate( + &validators.StringIsPresent{Field: b.Description, Name: "Description"}, + ), nil +} + type Books []Book type Writer struct { @@ -179,6 +195,21 @@ type Composer struct { UpdatedAt time.Time `db:"updated_at"` } +type Course struct { + ID uuid.UUID `json:"id" db:"id"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +type CourseCode struct { + ID uuid.UUID `json:"id" db:"id"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + CourseID uuid.UUID `json:"course_id" db:"course_id"` + Course Course `json:"-" db:"-"` + // Course Course `belongs_to:"course"` +} + type ValidatableCar struct { ID int64 `db:"id"` Name string `db:"name"` diff --git a/query.go b/query.go index 96095488..0649fac2 100644 --- a/query.go +++ b/query.go @@ -69,20 +69,10 @@ func (q *Query) RawQuery(stmt string, args ...interface{}) *Query { // // c.Eager().Find(model, 1) // will load all associations for model. // c.Eager("Books").Find(model, 1) // will load only Book association for model. -func (c *Connection) Eager(fields ...string) *Query { - return Q(c).Eager(fields...) -} - -// Eager will enable load associations of the model. -// by defaults loads all the associations on the model, -// but can take a variadic list of associations to load. -// -// q.Eager().Find(model, 1) // will load all associations for model. -// q.Eager("Books").Find(model, 1) // will load only Book association for model. -func (q *Query) Eager(fields ...string) *Query { - q.eager = true - q.eagerFields = append(q.eagerFields, fields...) - return q +func (c *Connection) Eager(fields ...string) *Connection { + c.eager = true + c.eagerFields = append(c.eagerFields, fields...) + return c } // Where will append a where clause to the query. You may use `?` in place of @@ -91,7 +81,8 @@ func (q *Query) Eager(fields ...string) *Query { // c.Where("id = ?", 1) // q.Where("id in (?)", 1, 2, 3) func (c *Connection) Where(stmt string, args ...interface{}) *Query { - return Q(c).Where(stmt, args...) + q := Q(c) + return q.Where(stmt, args...) } // Where will append a where clause to the query. You may use `?` in place of @@ -141,8 +132,10 @@ func (q *Query) Limit(limit int) *Query { // Q will create a new "empty" query from the current connection. func Q(c *Connection) *Query { return &Query{ - RawSQL: &clause{}, - Connection: c, + RawSQL: &clause{}, + Connection: c, + eager: c.eager, + eagerFields: c.eagerFields, } } diff --git a/validations.go b/validations.go index 21ccad5d..21f130e2 100644 --- a/validations.go +++ b/validations.go @@ -1,6 +1,8 @@ package pop import ( + "reflect" + "github.com/gobuffalo/validate" "github.com/pkg/errors" ) @@ -13,6 +15,8 @@ type validateable interface { Validate(*Connection) (*validate.Errors, error) } +type modelIterableValidator func(*Model) (*validate.Errors, error) + func (m *Model) validate(c *Connection) (*validate.Errors, error) { if x, ok := m.Value.(beforeValidatable); ok { if err := x.BeforeValidations(c); err != nil { @@ -30,21 +34,23 @@ type validateCreateable interface { } func (m *Model) validateCreate(c *Connection) (*validate.Errors, error) { - verrs, err := m.validate(c) - if err != nil { - return verrs, errors.WithStack(err) - } - if x, ok := m.Value.(validateCreateable); ok { - vs, err := x.ValidateCreate(c) - if vs != nil { - verrs.Append(vs) - } + return m.iterateAndValidate(func(model *Model) (*validate.Errors, error) { + verrs, err := model.validate(c) if err != nil { return verrs, errors.WithStack(err) } - } + if x, ok := model.Value.(validateCreateable); ok { + vs, err := x.ValidateCreate(c) + if vs != nil { + verrs.Append(vs) + } + if err != nil { + return verrs, errors.WithStack(err) + } + } - return verrs, err + return verrs, err + }) } type validateSaveable interface { @@ -52,21 +58,23 @@ type validateSaveable interface { } func (m *Model) validateSave(c *Connection) (*validate.Errors, error) { - verrs, err := m.validate(c) - if err != nil { - return verrs, errors.WithStack(err) - } - if x, ok := m.Value.(validateSaveable); ok { - vs, err := x.ValidateSave(c) - if vs != nil { - verrs.Append(vs) - } + return m.iterateAndValidate(func(model *Model) (*validate.Errors, error) { + verrs, err := model.validate(c) if err != nil { return verrs, errors.WithStack(err) } - } + if x, ok := model.Value.(validateSaveable); ok { + vs, err := x.ValidateSave(c) + if vs != nil { + verrs.Append(vs) + } + if err != nil { + return verrs, errors.WithStack(err) + } + } - return verrs, err + return verrs, err + }) } type validateUpdateable interface { @@ -74,19 +82,39 @@ type validateUpdateable interface { } func (m *Model) validateUpdate(c *Connection) (*validate.Errors, error) { - verrs, err := m.validate(c) - if err != nil { - return verrs, errors.WithStack(err) - } - if x, ok := m.Value.(validateUpdateable); ok { - vs, err := x.ValidateUpdate(c) - if vs != nil { - verrs.Append(vs) - } + return m.iterateAndValidate(func(model *Model) (*validate.Errors, error) { + verrs, err := model.validate(c) if err != nil { return verrs, errors.WithStack(err) } + if x, ok := model.Value.(validateUpdateable); ok { + vs, err := x.ValidateUpdate(c) + if vs != nil { + verrs.Append(vs) + } + if err != nil { + return verrs, errors.WithStack(err) + } + } + + return verrs, err + }) +} + +func (m *Model) iterateAndValidate(fn modelIterableValidator) (*validate.Errors, error) { + v := reflect.Indirect(reflect.ValueOf(m.Value)) + if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { + for i := 0; i < v.Len(); i++ { + val := v.Index(i) + newModel := &Model{Value: val.Addr().Interface()} + verrs, err := fn(newModel) + + if err != nil || verrs.HasAny() { + return verrs, err + } + } + return validate.NewErrors(), nil } - return verrs, err + return fn(m) }