diff --git a/mysql.go b/mysql.go index 24e3d86e..da69a159 100644 --- a/mysql.go +++ b/mysql.go @@ -1,6 +1,7 @@ package pop import ( + "bytes" "fmt" "io" "os" @@ -176,9 +177,7 @@ func (m *mysql) LoadSchema(r io.Reader) error { } func (m *mysql) TruncateAll(tx *Connection) error { - stmts := []struct { - Stmt string `db:"stmt"` - }{} + stmts := []string{} err := tx.RawQuery(mysqlTruncate, m.Details().Database).All(&stmts) if err != nil { return err @@ -186,11 +185,15 @@ func (m *mysql) TruncateAll(tx *Connection) error { if len(stmts) == 0 { return nil } - qs := []string{} - for _, x := range stmts { - qs = append(qs, x.Stmt) - } - return tx.RawQuery(strings.Join(qs, " ")).Exec() + + var qb bytes.Buffer + // #49: Disable foreign keys before truncation + qb.WriteString("SET SESSION FOREIGN_KEY_CHECKS = 0; ") + qb.WriteString(strings.Join(stmts, " ")) + // #49: Re-enable foreign keys after truncation + qb.WriteString(" SET SESSION FOREIGN_KEY_CHECKS = 1;") + + return tx.RawQuery(qb.String()).Exec() } func newMySQL(deets *ConnectionDetails) dialect { @@ -201,4 +204,4 @@ func newMySQL(deets *ConnectionDetails) dialect { return cd } -const mysqlTruncate = "SELECT concat('TRUNCATE TABLE `', TABLE_NAME, '`;') as stmt FROM INFORMATION_SCHEMA.TABLES where TABLE_SCHEMA = ?" +const mysqlTruncate = "SELECT concat('TRUNCATE TABLE `', TABLE_NAME, '`;') as stmt FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_type <> 'VIEW'"