Skip to content

Commit

Permalink
Eager query method & Clean options for eager loading. (#73)
Browse files Browse the repository at this point in the history
* Return Eager method for Query
Clean eager fields slice after every SQL command execution

* [Refactor] rename eagerDisabled function to disableEager

* [Refactor] clean no sense test

* copy connection when eager is activeated on Connection
  • Loading branch information
larrymjordan authored and markbates committed Apr 25, 2018
1 parent 7e384f3 commit d02695e
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 7 deletions.
7 changes: 7 additions & 0 deletions associations/belongs_to_association.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ func (b *belongsToAssociation) BeforeInterface() interface{} {
if b.ownerModel.Kind() == reflect.Ptr {
return b.ownerModel.Interface()
}

currentVal := b.ownerModel.Interface()
zeroVal := reflect.Zero(b.ownerModel.Type()).Interface()
if reflect.DeepEqual(zeroVal, currentVal) {
return nil
}

return b.ownerModel.Addr().Interface()
}

Expand Down
7 changes: 7 additions & 0 deletions associations/has_one_association.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ func (h *hasOneAssociation) AfterInterface() interface{} {
if h.ownedModel.Kind() == reflect.Ptr {
return h.ownedModel.Interface()
}

currentVal := h.ownedModel.Interface()
zeroVal := reflect.Zero(h.ownedModel.Type()).Interface()
if reflect.DeepEqual(zeroVal, currentVal) {
return nil
}

return h.ownedModel.Addr().Interface()
}

Expand Down
15 changes: 15 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ func (c *Connection) NewTransaction() (*Connection, error) {
return cn, nil
}

func (c *Connection) copy() *Connection {
return &Connection{
ID: randx.String(30),
Store: c.TX,
Dialect: c.Dialect,
TX: c.TX,
}
}

// Rollback will open a new transaction and automatically rollback that transaction
// when the inner function returns, regardless. This can be useful for tests, etc...
func (c *Connection) Rollback(fn func(tx *Connection)) error {
Expand Down Expand Up @@ -182,6 +191,12 @@ func (c *Connection) Q() *Query {
return Q(c)
}

// disableEager disables eager mode for current connection.
func (c *Connection) disableEager() {
c.eager = false
c.eagerFields = []string{}
}

// TruncateAll truncates all data from the datasource
func (c *Connection) TruncateAll() error {
return c.Dialect.TruncateAll(c)
Expand Down
4 changes: 2 additions & 2 deletions executors_eager.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func (c *Connection) eagerCreate(model interface{}, excludeColumns ...string) er
return err
}

c.eager = false
c.disableEager()

if len(asos) == 0 {
return c.Create(model, excludeColumns...)
Expand Down Expand Up @@ -88,7 +88,7 @@ func (c *Connection) eagerValidateAndCreate(model interface{}, excludeColumns ..
}

if len(asos) == 0 {
c.eager = false
c.disableEager()
return c.ValidateAndCreate(model, excludeColumns...)
}

Expand Down
25 changes: 25 additions & 0 deletions executors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,31 @@ func Test_Eager_Create_Has_Many(t *testing.T) {
})
}

func Test_Eager_Create_Has_Many_Reset_Eager_Mode_Connection(t *testing.T) {
transaction(func(tx *pop.Connection) {
a := require.New(t)
count, _ := tx.Count(&User{})
user1 := User{
Name: nulls.NewString("Mark 'Awesome' Bates"),
Books: Books{{Title: "Pop Book", Description: "Pop Book", Isbn: "PB1"}},
}

err := tx.Eager("Books").Create(&user1)
a.NoError(err)
ctx, _ := tx.Count(&User{})
a.Equal(count+1, ctx)
ctx, _ = tx.Count(&Book{})
a.Equal(count+1, ctx)

book := Book{Title: "Pop Book", Description: "Pop Book", Isbn: "PB1"}

err = tx.Eager().Create(&book)
a.NoError(err)
ctx, _ = tx.Count(&Book{})
a.Equal(count+2, ctx)
})
}

func Test_Eager_Validate_And_Create_Has_Many(t *testing.T) {
a := require.New(t)
transaction(func(tx *pop.Connection) {
Expand Down
5 changes: 3 additions & 2 deletions finders.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,18 +180,19 @@ func (q *Query) eagerAssociations(model interface{}) error {
}

assos, err := associations.AssociationsForStruct(model, q.eagerFields...)

if err != nil {
return err
}

//disable eager mode for current connection.
q.disableEager()

for _, association := range assos {
if association.Skipped() {
continue
}

query := Q(q.Connection)
query.eager = false

whereCondition, args := association.Constraint()
query = query.Where(whereCondition, args...)
Expand Down
21 changes: 21 additions & 0 deletions finders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,27 @@ func Test_All_Eager(t *testing.T) {
})
}

func Test_All_Eager_For_Query(t *testing.T) {
transaction(func(tx *pop.Connection) {
a := require.New(t)

user := User{Name: nulls.NewString("Mark")}
err := tx.Create(&user)
a.NoError(err)

book := Book{Title: "Pop Book", Isbn: "PB1", UserID: nulls.NewInt(user.ID)}
err = tx.Create(&book)
a.NoError(err)

u := Users{}
q := tx.Q()
err = q.Eager("Books").Where("name = 'Mark'").All(&u)
a.NoError(err)
a.Equal(len(u), 1)
a.Equal(len(u[0].Books), 1)
})
}

func Test_All_Eager_Field_Not_Found_Error(t *testing.T) {
transaction(func(tx *pop.Connection) {
a := require.New(t)
Expand Down
25 changes: 22 additions & 3 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,28 @@ func (q *Query) RawQuery(stmt string, args ...interface{}) *Query {
// c.Eager().Find(model, 1) // will load all associations for model.
// c.Eager("Books").Find(model, 1) // will load only Book association for model.
func (c *Connection) Eager(fields ...string) *Connection {
c.eager = true
c.eagerFields = append(c.eagerFields, fields...)
return c
con := c.copy()
con.eager = true
con.eagerFields = append(c.eagerFields, fields...)
return con
}

// Eager will enable load associations of the model.
// by defaults loads all the associations on the model,
// but can take a variadic list of associations to load.
//
// q.Eager().Find(model, 1) // will load all associations for model.
// q.Eager("Books").Find(model, 1) // will load only Book association for model.
func (q *Query) Eager(fields ...string) *Query {
q.eager = true
q.eagerFields = append(q.eagerFields, fields...)
return q
}

// disableEager disables eager mode for current query and Connection.
func (q *Query) disableEager() {
q.Connection.eager, q.eager = false, false
q.Connection.eagerFields, q.eagerFields = []string{}, []string{}
}

// Where will append a where clause to the query. You may use `?` in place of
Expand Down

0 comments on commit d02695e

Please sign in to comment.