Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding context timeouts for management queries based on reload interval #390

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
306 changes: 221 additions & 85 deletions client/gosqldriver/connection.go

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions client/gosqldriver/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ func (st *stmt) NumInput() int {
// Implements driver.Stmt.
// Exec executes a query that doesn't return rows, such as an INSERT or UPDATE.
func (st *stmt) Exec(args []driver.Value) (driver.Result, error) {
defer st.hera.finish()
sk := 0
if len(st.hera.shardKeyPayload) > 0 {
sk = 1
Expand Down Expand Up @@ -167,6 +168,10 @@ func (st *stmt) Exec(args []driver.Value) (driver.Result, error) {
func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
//TODO: refactor ExecContext / Exec to reuse code
//TODO: honor the context timeout and return when it is canceled
if err := st.hera.watchCancel(ctx); err != nil {
return nil, err
}
defer st.hera.finish()
sk := 0
if len(st.hera.shardKeyPayload) > 0 {
sk = 1
Expand Down Expand Up @@ -255,6 +260,7 @@ func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driv
// Implements driver.Stmt.
// Query executes a query that may return rows, such as a SELECT.
func (st *stmt) Query(args []driver.Value) (driver.Rows, error) {
defer st.hera.finish()
sk := 0
if len(st.hera.shardKeyPayload) > 0 {
sk = 1
Expand Down Expand Up @@ -359,6 +365,10 @@ Loop:
func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
// TODO: refactor Query/QueryContext to reuse code
// TODO: honor the context timeout and return when it is canceled
if err := st.hera.watchCancel(ctx); err != nil {
return nil, err
}
defer st.hera.finish()
sk := 0
if len(st.hera.shardKeyPayload) > 0 {
sk = 1
Expand Down
67 changes: 67 additions & 0 deletions client/gosqldriver/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package gosqldriver

import (
"errors"
"sync"
"sync/atomic"
)

// ErrInvalidConn is a sentinel error whose message is "invalid connection".
// NOTE(review): presumably reported when a driver connection can no longer be
// used — confirm against the callers in connection.go.
var ErrInvalidConn = errors.New("invalid connection")

// atomicError holds an error value that can be written and read from multiple
// goroutines.
type atomicError struct {
	// value stores a *error rather than the error itself: atomic.Value panics
	// when successive Store calls use different concrete types, and two errors
	// frequently do not share one (for example context.Canceled and
	// context.DeadlineExceeded). A *error always has the same concrete type.
	value atomic.Value
	// mu serializes writers in Set; readers only go through value.Load.
	mu sync.Mutex
}

// Set records err as the current error value, replacing any previous one.
// Errors of different concrete types may be stored one after another.
// The value must not be nil; Set panics if it is.
func (ae *atomicError) Set(err error) {
	if err == nil {
		panic("atomicError: nil error value")
	}
	ae.mu.Lock()
	defer ae.mu.Unlock()
	// Store the address of the parameter copy so the dynamic type handed to
	// atomic.Value is *error on every call.
	ae.value.Store(&err)
}

// Value returns the current error value, or nil if none is set.
func (ae *atomicError) Value() error {
	// Load yields a nil interface before the first Set; the checked assertion
	// turns that case into a plain nil error.
	stored, ok := ae.value.Load().(*error)
	if !ok {
		return nil
	}
	return *stored
}

// atomicBool is a boolean flag for use from multiple goroutines. The flag is
// kept in a uint32 (0 means false, 1 means true) that is accessed through
// sync/atomic while mu is held.
type atomicBool struct {
	value uint32
	mu    sync.Mutex
}

// Store sets the flag to value, whatever it held before.
func (ab *atomicBool) Store(value bool) {
	var raw uint32
	if value {
		raw = 1
	}

	ab.mu.Lock()
	atomic.StoreUint32(&ab.value, raw)
	ab.mu.Unlock()
}

// Load reports whether the flag is currently true.
func (ab *atomicBool) Load() bool {
	ab.mu.Lock()
	raw := atomic.LoadUint32(&ab.value)
	ab.mu.Unlock()

	return raw != 0
}

// Swap sets the flag to value and reports what it held before the call.
func (ab *atomicBool) Swap(value bool) bool {
	var next uint32
	if value {
		next = 1
	}

	ab.mu.Lock()
	prev := atomic.SwapUint32(&ab.value, next)
	ab.mu.Unlock()

	return prev != 0
}
9 changes: 6 additions & 3 deletions lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ package lib
import (
"errors"
"fmt"
"github.com/paypal/hera/config"
"github.com/paypal/hera/utility/logger"
"os"
"path/filepath"
"strings"
"sync/atomic"

"github.com/paypal/hera/config"
"github.com/paypal/hera/utility/logger"
)

//The Config contains all the static configuration
Expand Down Expand Up @@ -176,6 +175,9 @@ type Config struct {

// Max desired percentage of healthy workers for the worker pool
MaxDesiredHealthyWorkerPct int

//Timeout for management queries.
ManagementQueriesTimeoutInUs int
}

// The OpsConfig contains the configuration that can be modified during run time
Expand Down Expand Up @@ -461,6 +463,7 @@ func InitConfig() error {
gAppConfig.MaxDesiredHealthyWorkerPct = 90
}

gAppConfig.ManagementQueriesTimeoutInUs = cdb.GetOrDefaultInt("management_queries_timeout_us", 200000) //200 milliseconds
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion lib/querybindblocker.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func InitQueryBindBlocker(modName string) {
}

func loadBlockQueryBind(db *sql.DB) {
ctx, cancel := context.WithTimeout(context.Background(), 5000*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(GetConfig().ManagementQueriesTimeoutInUs)*time.Microsecond)
defer cancel()
conn, err := db.Conn(ctx)
if err != nil {
Expand Down
28 changes: 22 additions & 6 deletions lib/racmaint.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,32 +102,48 @@ func racMaintMain(shard int, interval int, cmdLineModuleName string) {
binds[0], err = os.Hostname()
binds[0] = strings.ToUpper(binds[0])
binds[1] = strings.ToUpper(cmdLineModuleName) // */
//First time data loading
racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, GetConfig().ManagementQueriesTimeoutInUs)

timeTicker := time.NewTicker(time.Second * time.Duration(interval))
defer timeTicker.Stop()
for {
racMaint(ctx, shard, db, racSQL, cmdLineModuleName, prev)
time.Sleep(time.Second * time.Duration(interval))
select {
case <-ctx.Done():
logger.GetLogger().Log(logger.Alert, "Application main context has been closed, so exiting from racmaint data reload.")
return
case <-timeTicker.C:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with this but why do we need to introduce the timeTicker? I feel the existing code is simpler.

//Periodic data loading
racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, GetConfig().ManagementQueriesTimeoutInUs)
timeTicker.Reset(time.Second * time.Duration(interval))
}
}
}

/*
racMaint is the main function for RAC maintenance processing, being called regularly.
When maintenance is planned, it calls workerpool.RacMaint to start the actual processing
*/
func racMaint(ctx context.Context, shard int, db *sql.DB, racSQL string, cmdLineModuleName string, prev map[racCfgKey]racCfg) {
func racMaint(ctx *context.Context, shard int, db *sql.DB, racSQL string, cmdLineModuleName string, prev map[racCfgKey]racCfg, queryTimeoutInUs int) {
//
// print this log for unittesting
//
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "Rac maint check, shard =", shard)
}
conn, err := db.Conn(ctx)
//create cancellable context
queryContext, cancel := context.WithTimeout(*ctx, time.Duration(queryTimeoutInUs)*time.Microsecond)
defer cancel() // Always call cancel to release resources associated with the context

conn, err := db.Conn(queryContext)
if err != nil {
if logger.GetLogger().V(logger.Info) {
logger.GetLogger().Log(logger.Info, "Error (conn) rac maint for shard =", shard, ",err :", err)
}
return
}
defer conn.Close()
stmt, err := conn.PrepareContext(ctx, racSQL)
stmt, err := conn.PrepareContext(queryContext, racSQL)
if err != nil {
if logger.GetLogger().V(logger.Info) {
logger.GetLogger().Log(logger.Info, "Error (stmt) rac maint for shard =", shard, ",err :", err)
Expand All @@ -139,7 +155,7 @@ func racMaint(ctx context.Context, shard int, db *sql.DB, racSQL string, cmdLine
hostname = strings.ToUpper(hostname)
module := strings.ToUpper(cmdLineModuleName)
module_taf := fmt.Sprintf("%s_TAF", module)
rows, err := stmt.QueryContext(ctx, hostname, module_taf, module)
rows, err := stmt.QueryContext(queryContext, hostname, module_taf, module)
if err != nil {
if logger.GetLogger().V(logger.Info) {
logger.GetLogger().Log(logger.Info, "Error (query) rac maint for shard =", shard, ",err :", err)
Expand Down
73 changes: 41 additions & 32 deletions lib/shardingcfg.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func getSQL() string {
/*
load the physical to logical mapping
*/
func loadMap(ctx context.Context, db *sql.DB) error {
func loadMap(ctx *context.Context, db *sql.DB, queryTimeoutInterval int) error {
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "Begin loading shard map")
}
Expand All @@ -109,17 +109,18 @@ func loadMap(ctx context.Context, db *sql.DB) error {
logger.GetLogger().Log(logger.Verbose, "Done loading shard map")
}()
}

conn, err := db.Conn(ctx)
queryContext, cancel := context.WithTimeout(*ctx, time.Duration(queryTimeoutInterval)*time.Microsecond)
defer cancel()
conn, err := db.Conn(queryContext)
if err != nil {
return fmt.Errorf("Error (conn) loading shard map: %s", err.Error())
}
defer conn.Close()
stmt, err := conn.PrepareContext(ctx, getSQL())
stmt, err := conn.PrepareContext(queryContext, getSQL())
if err != nil {
return fmt.Errorf("Error (stmt) loading shard map: %s", err.Error())
}
rows, err := stmt.QueryContext(ctx)
rows, err := stmt.QueryContext(queryContext)
if err != nil {
return fmt.Errorf("Error (query) loading shard map: %s", err.Error())
}
Expand Down Expand Up @@ -216,7 +217,7 @@ func getWLSQL() string {
/*
load the whitelist mapping
*/
func loadWhitelist(ctx context.Context, db *sql.DB) {
func loadWhitelist(ctx *context.Context, db *sql.DB, timeoutInMs int) {
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "Begin loading whitelist")
}
Expand All @@ -225,19 +226,20 @@ func loadWhitelist(ctx context.Context, db *sql.DB) {
logger.GetLogger().Log(logger.Verbose, "Done loading whitelist")
}()
}

conn, err := db.Conn(ctx)
queryContext, cancel := context.WithTimeout(*ctx, time.Duration(timeoutInMs)*time.Microsecond)
defer cancel()
conn, err := db.Conn(queryContext)
if err != nil {
logger.GetLogger().Log(logger.Alert, "Error (conn) loading whitelist:", err)
return
}
defer conn.Close()
stmt, err := conn.PrepareContext(ctx, getWLSQL())
stmt, err := conn.PrepareContext(queryContext, getWLSQL())
if err != nil {
logger.GetLogger().Log(logger.Alert, "Error (stmt) loading whitelist:", err)
return
}
rows, err := stmt.QueryContext(ctx)
rows, err := stmt.QueryContext(queryContext)
if err != nil {
logger.GetLogger().Log(logger.Alert, "Error (query) loading whitelist:", err)
return
Expand Down Expand Up @@ -291,7 +293,10 @@ func InitShardingCfg() error {
ctx := context.Background()
var db *sql.DB
var err error

reloadInterval := time.Second * time.Duration(GetConfig().ShardingCfgReloadInterval)
if reloadInterval < 100*time.Millisecond {
reloadInterval = 100 * time.Millisecond
}
i := 0
for ; i < 60; i++ {
for shard := 0; shard < GetConfig().NumOfShards; shard++ {
Expand All @@ -300,13 +305,13 @@ func InitShardingCfg() error {
}
db, err = openDb(shard)
if err == nil {
err = loadMap(ctx, db)
err = loadMap(&ctx, db, GetConfig().ManagementQueriesTimeoutInUs)
if err == nil {
break
}
}
logger.GetLogger().Log(logger.Warning, "Error <", err, "> loading the shard map from shard", shard)
evt := cal.NewCalEvent(cal.EventTypeError, "no_shard_map", cal.TransOK, "Error loading shard map")
evt := cal.NewCalEvent(cal.EventTypeError, "no_shard_map", cal.TransOK, fmt.Sprintf("Error loading shard map %v", err))
evt.Completed()
}
if err == nil {
Expand All @@ -319,32 +324,36 @@ func InitShardingCfg() error {
return errors.New("Failed to load shard map, no more retry")
}
if GetConfig().EnableWhitelistTest {
loadWhitelist(ctx, db)
loadWhitelist(&ctx, db, GetConfig().ManagementQueriesTimeoutInUs)
}
go func() {
reloadTimer := time.NewTimer(reloadInterval) //Periodic reload timer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be replaced with NewTicker as the timer fires just once.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added reset and clean up options for timer

defer reloadTimer.Stop()
for {
reloadInterval := time.Second * time.Duration(GetConfig().ShardingCfgReloadInterval)
if reloadInterval < 100 * time.Millisecond {
reloadInterval = 100 * time.Millisecond
}
time.Sleep(reloadInterval)
for shard := 0; shard < GetConfig().NumOfShards; shard++ {
if db != nil {
db.Close()
}
db, err = openDb(shard)
if err == nil {
err = loadMap(ctx, db)
select {
case <-ctx.Done():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let us also add timeout-related tests for validation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added tests.

logger.GetLogger().Log(logger.Alert, "Application main context has been closed, so exiting from shard-config data reload.")
return
case <-reloadTimer.C:
for shard := 0; shard < GetConfig().NumOfShards; shard++ {
if db != nil {
db.Close()
}
db, err = openDb(shard)
if err == nil {
if shard == 0 && GetConfig().EnableWhitelistTest {
loadWhitelist(ctx, db)
err = loadMap(&ctx, db, GetConfig().ManagementQueriesTimeoutInUs)
if err == nil {
if shard == 0 && GetConfig().EnableWhitelistTest {
loadWhitelist(&ctx, db, GetConfig().ManagementQueriesTimeoutInUs)
}
break
}
break
}
logger.GetLogger().Log(logger.Warning, "Error <", err, "> loading the shard map from shard", shard)
evt := cal.NewCalEvent(cal.EventTypeError, "no_shard_map", cal.TransOK, err.Error())
evt.Completed()
}
logger.GetLogger().Log(logger.Warning, "Error <", err, "> loading the shard map from shard", shard)
evt := cal.NewCalEvent(cal.EventTypeError, "no_shard_map", cal.TransOK, "Error loading shard map")
evt.Completed()
reloadTimer.Reset(reloadInterval) //Reset timer
}
}
}()
Expand Down
Loading
Loading