From 9cb81ce7c7c17fbdb31da047a04b28025e9cae9e Mon Sep 17 00:00:00 2001 From: Rajesh S Date: Thu, 21 Mar 2024 17:28:48 +0530 Subject: [PATCH 1/6] adding context timeouts for management queries --- lib/racmaint.go | 27 ++++++++++++++---- lib/shardingcfg.go | 69 +++++++++++++++++++++++++--------------------- 2 files changed, 59 insertions(+), 37 deletions(-) diff --git a/lib/racmaint.go b/lib/racmaint.go index 27c3bbea..2aef6912 100644 --- a/lib/racmaint.go +++ b/lib/racmaint.go @@ -102,9 +102,20 @@ func racMaintMain(shard int, interval int, cmdLineModuleName string) { binds[0], err = os.Hostname() binds[0] = strings.ToUpper(binds[0]) binds[1] = strings.ToUpper(cmdLineModuleName) // */ + waitTime := time.Second * time.Duration(interval) + //First time data loading + racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, waitTime/2) + + timeTicker := time.NewTicker(waitTime) for { - racMaint(ctx, shard, db, racSQL, cmdLineModuleName, prev) - time.Sleep(time.Second * time.Duration(interval)) + select { + case <-ctx.Done(): + logger.GetLogger().Log(logger.Alert, "Application main context has been closed, so exiting from racmaint data reload.") + return + case <-timeTicker.C: + //Periodic data loading + racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, waitTime/2) + } } } @@ -112,14 +123,18 @@ func racMaintMain(shard int, interval int, cmdLineModuleName string) { racMaint is the main function for RAC maintenance processing, being called regularly. When maintenance is planned, it calls workerpool.RacMaint to start the actuall processing */ -func racMaint(ctx context.Context, shard int, db *sql.DB, racSQL string, cmdLineModuleName string, prev map[racCfgKey]racCfg) { +func racMaint(ctx *context.Context, shard int, db *sql.DB, racSQL string, cmdLineModuleName string, prev map[racCfgKey]racCfg, queryTimeout time.Duration) { // // print this log for unittesting // if logger.GetLogger().V(logger.Verbose) { logger.GetLogger().Log(logger.Verbose, "Rac maint check, shard =", shard) } - conn, err := db.Conn(ctx) + //create cancellable context + queryContext, cancel := context.WithTimeout(*ctx, queryTimeout) + defer cancel() // Always call cancel to release resources associated with the context + + conn, err := db.Conn(queryContext) if err != nil { if logger.GetLogger().V(logger.Info) { logger.GetLogger().Log(logger.Info, "Error (conn) rac maint for shard =", shard, ",err :", err) @@ -127,7 +142,7 @@ func racMaint(ctx context.Context, shard int, db *sql.DB, racSQL string, cmdLine return } defer conn.Close() - stmt, err := conn.PrepareContext(ctx, racSQL) + stmt, err := conn.PrepareContext(queryContext, racSQL) if err != nil { if logger.GetLogger().V(logger.Info) { logger.GetLogger().Log(logger.Info, "Error (stmt) rac maint for shard =", shard, ",err :", err) @@ -139,7 +154,7 @@ func racMaint(ctx context.Context, shard int, db *sql.DB, racSQL string, cmdLine hostname = strings.ToUpper(hostname) module := strings.ToUpper(cmdLineModuleName) module_taf := fmt.Sprintf("%s_TAF", module) - rows, err := stmt.QueryContext(ctx, hostname, module_taf, module) + rows, err := stmt.QueryContext(queryContext, hostname, module_taf, module) if err != nil { if logger.GetLogger().V(logger.Info) { logger.GetLogger().Log(logger.Info, "Error (query) rac maint for shard =", shard, ",err :", err) diff --git a/lib/shardingcfg.go b/lib/shardingcfg.go index c6dac50c..995f53e0 100644 --- a/lib/shardingcfg.go +++ b/lib/shardingcfg.go @@ -100,7 +100,7 @@ func getSQL() string { /* load the physical to logical maping */ -func 
loadMap(ctx context.Context, db *sql.DB) error { +func loadMap(ctx *context.Context, db *sql.DB, queryTimeoutInterval time.Duration) error { if logger.GetLogger().V(logger.Verbose) { logger.GetLogger().Log(logger.Verbose, "Begin loading shard map") } @@ -109,17 +109,18 @@ func loadMap(ctx context.Context, db *sql.DB) error { logger.GetLogger().Log(logger.Verbose, "Done loading shard map") }() } - - conn, err := db.Conn(ctx) + queryContext, cancel := context.WithTimeout(*ctx, queryTimeoutInterval) + defer cancel() + conn, err := db.Conn(queryContext) if err != nil { return fmt.Errorf("Error (conn) loading shard map: %s", err.Error()) } defer conn.Close() - stmt, err := conn.PrepareContext(ctx, getSQL()) + stmt, err := conn.PrepareContext(queryContext, getSQL()) if err != nil { return fmt.Errorf("Error (stmt) loading shard map: %s", err.Error()) } - rows, err := stmt.QueryContext(ctx) + rows, err := stmt.QueryContext(queryContext) if err != nil { return fmt.Errorf("Error (query) loading shard map: %s", err.Error()) } @@ -216,7 +217,7 @@ func getWLSQL() string { /* load the whitelist mapping */ -func loadWhitelist(ctx context.Context, db *sql.DB) { +func loadWhitelist(ctx *context.Context, db *sql.DB, timeout time.Duration) { if logger.GetLogger().V(logger.Verbose) { logger.GetLogger().Log(logger.Verbose, "Begin loading whitelist") } @@ -225,19 +226,20 @@ func loadWhitelist(ctx context.Context, db *sql.DB) { logger.GetLogger().Log(logger.Verbose, "Done loading whitelist") }() } - - conn, err := db.Conn(ctx) + queryContext, cancel := context.WithTimeout(*ctx, timeout) + defer cancel() + conn, err := db.Conn(queryContext) if err != nil { logger.GetLogger().Log(logger.Alert, "Error (conn) loading whitelist:", err) return } defer conn.Close() - stmt, err := conn.PrepareContext(ctx, getWLSQL()) + stmt, err := conn.PrepareContext(queryContext, getWLSQL()) if err != nil { logger.GetLogger().Log(logger.Alert, "Error (stmt) loading whitelist:", err) return } - rows, err := stmt.QueryContext(ctx) + rows, err := stmt.QueryContext(queryContext) if err != nil { logger.GetLogger().Log(logger.Alert, "Error (query) loading whitelist:", err) return @@ -291,7 +293,10 @@ func InitShardingCfg() error { ctx := context.Background() var db *sql.DB var err error - + reloadInterval := time.Second * time.Duration(GetConfig().ShardingCfgReloadInterval) + if reloadInterval < 100*time.Millisecond { + reloadInterval = 100 * time.Millisecond + } i := 0 for ; i < 60; i++ { for shard := 0; shard < GetConfig().NumOfShards; shard++ { @@ -300,7 +305,7 @@ func InitShardingCfg() error { } db, err = openDb(shard) if err == nil { - err = loadMap(ctx, db) + err = loadMap(&ctx, db, reloadInterval/2) if err == nil { break } @@ -319,32 +324,34 @@ func InitShardingCfg() error { return errors.New("Failed to load shard map, no more retry") } if GetConfig().EnableWhitelistTest { - loadWhitelist(ctx, db) + loadWhitelist(&ctx, db, reloadInterval/2) } go func() { + reloadTimer := time.NewTimer(reloadInterval) //Periodic reload timer for { - reloadInterval := time.Second * time.Duration(GetConfig().ShardingCfgReloadInterval) - if reloadInterval < 100 * time.Millisecond { - reloadInterval = 100 * time.Millisecond - } - time.Sleep(reloadInterval) - for shard := 0; shard < GetConfig().NumOfShards; shard++ { - if db != nil { - db.Close() - } - db, err = openDb(shard) - if err == nil { - err = loadMap(ctx, db) + select { + case <-ctx.Done(): + logger.GetLogger().Log(logger.Alert, "Application main context has been closed, so exiting from 
shard-config data reload.") + return + case <-reloadTimer.C: + for shard := 0; shard < GetConfig().NumOfShards; shard++ { + if db != nil { + db.Close() + } + db, err = openDb(shard) if err == nil { - if shard == 0 && GetConfig().EnableWhitelistTest { - loadWhitelist(ctx, db) + err = loadMap(&ctx, db, reloadInterval/2) + if err == nil { + if shard == 0 && GetConfig().EnableWhitelistTest { + loadWhitelist(&ctx, db, reloadInterval/2) + } + break } - break } + logger.GetLogger().Log(logger.Warning, "Error <", err, "> loading the shard map from shard", shard) + evt := cal.NewCalEvent(cal.EventTypeError, "no_shard_map", cal.TransOK, "Error loading shard map") + evt.Completed() } - logger.GetLogger().Log(logger.Warning, "Error <", err, "> loading the shard map from shard", shard) - evt := cal.NewCalEvent(cal.EventTypeError, "no_shard_map", cal.TransOK, "Error loading shard map") - evt.Completed() } } }() From a8f3724447988b9cc0a3b20b4a8035c3c9e948cc Mon Sep 17 00:00:00 2001 From: Rajesh S Date: Fri, 22 Mar 2024 11:35:51 +0530 Subject: [PATCH 2/6] fixing shard key atuo discovery test --- .../sharding_tests/shard_basic/main_test.go | 233 +++++++++--------- 1 file changed, 115 insertions(+), 118 deletions(-) diff --git a/tests/functionaltest/sharding_tests/shard_basic/main_test.go b/tests/functionaltest/sharding_tests/shard_basic/main_test.go index a85fd48e..eef25f43 100644 --- a/tests/functionaltest/sharding_tests/shard_basic/main_test.go +++ b/tests/functionaltest/sharding_tests/shard_basic/main_test.go @@ -2,11 +2,11 @@ package main import ( "bytes" - "os/exec" "context" "database/sql" "fmt" "os" + "os/exec" "testing" "time" @@ -19,84 +19,82 @@ var mx testutil.Mux var tableName string func cfg() (map[string]string, map[string]string, testutil.WorkerType) { - fmt.Println ("setup() begin") - appcfg := make(map[string]string) - appcfg["bind_port"] = "31002" - appcfg["log_level"] = "5" - appcfg["log_file"] = "hera.log" - appcfg["rac_sql_interval"] = "0" + fmt.Println("setup() begin") + appcfg := make(map[string]string) + appcfg["bind_port"] = "31002" + appcfg["log_level"] = "5" + appcfg["log_file"] = "hera.log" + appcfg["rac_sql_interval"] = "0" //For sharding - appcfg["enable_sharding"] = "true" - appcfg["num_shards"] = "5" - appcfg["max_scuttle"] = "128" - appcfg["shard_key_name"] = "accountID" - appcfg["sharding_algo"] = "mod" + appcfg["enable_sharding"] = "true" + appcfg["num_shards"] = "5" + appcfg["max_scuttle"] = "128" + appcfg["shard_key_name"] = "accountid" + appcfg["sharding_algo"] = "mod" - opscfg := make(map[string]string) - opscfg["opscfg.default.server.max_connections"] = "3" - opscfg["opscfg.default.server.log_level"] = "5" + opscfg := make(map[string]string) + opscfg["opscfg.default.server.max_connections"] = "3" + opscfg["opscfg.default.server.log_level"] = "5" if os.Getenv("WORKER") == "postgres" { return appcfg, opscfg, testutil.PostgresWorker - } + } return appcfg, opscfg, testutil.MySQLWorker } //Helper function to delete and populate shard map with 128 scuttles -func populate_cam_shard_map() (string,error) { - cmd := exec.Command("mysql","-h",os.Getenv("mysql_ip"),"-p1-testDb","-uroot", "heratestdb", " < populate_cam_shard_map.sql") - //cmd.Stdin = strings.NewReader(sql) - var cmdOutBuf bytes.Buffer - cmd.Stdout = &cmdOutBuf - cmd.Run() - return cmdOutBuf.String(), nil +func populate_cam_shard_map() (string, error) { + cmd := exec.Command("mysql", "-h", os.Getenv("mysql_ip"), "-p1-testDb", "-uroot", "heratestdb", " < populate_cam_shard_map.sql") + //cmd.Stdin = 
strings.NewReader(sql) + var cmdOutBuf bytes.Buffer + cmd.Stdout = &cmdOutBuf + cmd.Run() + return cmdOutBuf.String(), nil } - func setupDb() error { testutil.RunDML("DROP TABLE IF EXISTS test_simple_table_2") - testutil.RunDML("CREATE TABLE test_simple_table_2 (accountID VARCHAR(64) PRIMARY KEY, NAME VARCHAR(64), STATUS VARCHAR(64), CONDN VARCHAR(64))") + testutil.RunDML("CREATE TABLE test_simple_table_2 (accountid VARCHAR(64) PRIMARY KEY, NAME VARCHAR(64), STATUS VARCHAR(64), CONDN VARCHAR(64))") testutil.RunDML("DROP TABLE IF EXISTS hera_shard_map") if os.Getenv("WORKER") == "postgres" { - testutil.RunDML("CREATE TABLE hera_shard_map (SCUTTLE_ID BIGINT, SHARD_ID BIGINT, STATUS CHAR(1), READ_STATUS CHAR(1), WRITE_STATUS CHAR(1), REMARKS VARCHAR(500))"); - } else { - testutil.RunDML("CREATE TABLE hera_shard_map (SCUTTLE_ID INT, SHARD_ID INT, STATUS CHAR(1), READ_STATUS CHAR(1), WRITE_STATUS CHAR(1), REMARKS VARCHAR(500))"); + testutil.RunDML("CREATE TABLE hera_shard_map (SCUTTLE_ID BIGINT, SHARD_ID BIGINT, STATUS CHAR(1), READ_STATUS CHAR(1), WRITE_STATUS CHAR(1), REMARKS VARCHAR(500))") + } else { + testutil.RunDML("CREATE TABLE hera_shard_map (SCUTTLE_ID INT, SHARD_ID INT, STATUS CHAR(1), READ_STATUS CHAR(1), WRITE_STATUS CHAR(1), REMARKS VARCHAR(500))") } - max_scuttle := 128; - err := testutil.PopulateShardMap(max_scuttle); - return err + max_scuttle := 128 + err := testutil.PopulateShardMap(max_scuttle) + return err } - -func TestMain (m *testing.M) { +func TestMain(m *testing.M) { os.Exit(testutil.UtilMain(m, cfg, setupDb)) } /* ########################################################################################## - # Sharding enabled with num_shards > 0 - # Sending two DMLs, first DML with shard key pass and 2nd one no shard key - # Verify 2nd request fails due to no shard key - # Verify Log, CAL events - # Send update, fetch requests with auto discovery - # Veriy update is sent to correct shard - # Veriy fetch is sent to correct shard and fields are updated correctly - # Verify Log - # - #############################################################################################*/ +# Sharding enabled with num_shards > 0 +# Sending two DMLs, first DML with shard key pass and 2nd one no shard key +# Verify 2nd request fails due to no shard key +# Verify Log, CAL events +# Send update, fetch requests with auto discovery +# Veriy update is sent to correct shard +# Veriy fetch is sent to correct shard and fields are updated correctly +# Verify Log +# +#############################################################################################*/ func TestShardBasic(t *testing.T) { - fmt.Println ("TestShardBasic begin +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") + fmt.Println("TestShardBasic begin +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++") logger.GetLogger().Log(logger.Debug, "TestShardBasic begin +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") - time.Sleep(8 * time.Second); - + time.Sleep(8 * time.Second) + hostname := testutil.GetHostname() - fmt.Println ("Hostname: ", hostname); - db, err := sql.Open("hera", hostname + ":31002") - if err != nil { - t.Fatal("Error starting Mux:", err) - return - } + fmt.Println("Hostname: ", hostname) + db, err := sql.Open("hera", hostname+":31002") + if err != nil { + t.Fatal("Error starting Mux:", err) + return + } db.SetMaxIdleConns(0) defer db.Close() @@ -108,8 +106,8 @@ func TestShardBasic(t *testing.T) { t.Fatalf("Error getting connection %s\n", err.Error()) } tx, _ := 
conn.BeginTx(ctx, nil) - stmt, _ := tx.PrepareContext(ctx, "/*cmd*/insert into test_simple_table_2 (accountID, Name, Status) VALUES(:accountID, :Name, :Status)") - _, err = stmt.Exec(sql.Named("accountID", "12346"), sql.Named("Name", "Steve"), sql.Named("Status", "done")) + stmt, _ := tx.PrepareContext(ctx, "/*cmd*/insert into test_simple_table_2 (accountid, Name, Status) VALUES(:accountid, :Name, :Status)") + _, err = stmt.Exec(sql.Named("accountid", "12346"), sql.Named("Name", "Steve"), sql.Named("Status", "done")) if err != nil { t.Fatalf("Error preparing test (create row in table) %s\n", err.Error()) } @@ -117,64 +115,64 @@ func TestShardBasic(t *testing.T) { if err != nil { t.Fatalf("Error commit %s\n", err.Error()) } - - fmt.Println ("Send an update request without shard key") + + fmt.Println("Send an update request without shard key") stmt, _ = conn.PrepareContext(ctx, "/*cmd*/Update test_simple_table_2 set Status = 'progess' where Name=?") stmt.Exec("Steve") stmt.Close() cancel() - conn.Close() - - fmt.Println ("Verify insert request is sent to shard 3") - count := testutil.RegexCount ("WORKER shd3.*Preparing.*insert into test_simple_table_2") - if (count < 1) { - t.Fatalf ("Error: Insert Query does NOT go to shd3"); - } - - fmt.Println ("Verify no shard key error is thrown for fetch request") - count = testutil.RegexCount ("Error preprocessing sharding, hangup: HERA-373: no shard key or more than one or bad logical") - if (count < 1) { - t.Fatalf ("Error: No Shard key error should be thrown for fetch request"); - } - cal_count := testutil.RegexCountFile ("SHARDING.*shard_key_not_found.*0.*sql=1093137600", "cal.log") - if (cal_count < 1) { - t.Fatalf ("Error: No Shard key event for fetch request in CAL"); - } - - fmt.Println ("Check log for shard key auto discovery"); - count = testutil.RegexCount ("shard key auto discovery: shardkey=accountid|12346&shardid=3&scuttleid=") - if (count < 1) { - t.Fatalf ("Error: Did NOT get shard key auto discovery in log"); - } - - fmt.Println ("Check CAL log for correct events") - cal_count = testutil.RegexCountFile ("T.*API.*CLIENT_SESSION_3", "cal.log") - if (cal_count < 1) { - t.Fatalf ("Error: Request is not executed by shard 3 as expected"); - } - - fmt.Println ("Open new connection as previous connection is already closed"); + conn.Close() + + fmt.Println("Verify insert request is sent to shard 3") + count := testutil.RegexCount("WORKER shd3.*Preparing.*insert into test_simple_table_2") + if count < 1 { + t.Fatalf("Error: Insert Query does NOT go to shd3") + } + + fmt.Println("Verify no shard key error is thrown for fetch request") + count = testutil.RegexCount("Error preprocessing sharding, hangup: HERA-373: no shard key or more than one or bad logical") + if count < 1 { + t.Fatalf("Error: No Shard key error should be thrown for fetch request") + } + cal_count := testutil.RegexCountFile("SHARDING.*shard_key_not_found.*0.*sql=1093137600", "cal.log") + if cal_count < 1 { + t.Fatalf("Error: No Shard key event for fetch request in CAL") + } + + fmt.Println("Check log for shard key auto discovery") + count = testutil.RegexCount("shard key auto discovery: shardkey=accountid|12346&shardid=3&scuttleid=") + if count < 1 { + t.Fatalf("Error: Did NOT get shard key auto discovery in log") + } + + fmt.Println("Check CAL log for correct events") + cal_count = testutil.RegexCountFile("T.*API.*CLIENT_SESSION_3", "cal.log") + if cal_count < 1 { + t.Fatalf("Error: Request is not executed by shard 3 as expected") + } + + fmt.Println("Open new connection as 
previous connection is already closed") ctx1, cancel1 := context.WithTimeout(context.Background(), 10*time.Second) conn1, err := db.Conn(ctx1) - if err != nil { - t.Fatalf("Error getting connection %s\n", err.Error()) - } - tx1, _ := conn1.BeginTx(ctx1, nil) - fmt.Println ("Update table with shard key passed"); - stmt1, _ := tx1.PrepareContext(ctx1, "/*cmd*/ update test_simple_table_2 set Status = 'In Progress' where accountID in (:accountID)") - _, err = stmt1.Exec(sql.Named("accountID", "12346")) - if err != nil { - t.Fatalf("Error updating row in table %s\n", err.Error()) - } - err = tx1.Commit() - if err != nil { - t.Fatalf("Error commit %s\n", err.Error()) - } - stmt1, _ = conn1.PrepareContext(ctx1, "/*TestShardingBasic*/Select name, status from test_simple_table_2 where accountID=:accountID") - rows1, _ := stmt1.Query(sql.Named("accountID", "12346")) - if !rows1.Next() { + if err != nil { + t.Fatalf("Error getting connection %s\n", err.Error()) + } + tx1, _ := conn1.BeginTx(ctx1, nil) + fmt.Println("Update table with shard key passed") + stmt1, _ := tx1.PrepareContext(ctx1, "/*cmd*/ update test_simple_table_2 set Status = 'In Progress' where accountid in (:accountid)") + _, err = stmt1.Exec(sql.Named("accountid", "12346")) + if err != nil { + t.Fatalf("Error updating row in table %s\n", err.Error()) + } + err = tx1.Commit() + if err != nil { + t.Fatalf("Error commit %s\n", err.Error()) + } + stmt1, _ = conn1.PrepareContext(ctx1, "/*TestShardingBasic*/Select name, status from test_simple_table_2 where accountid=:accountid") + rows1, _ := stmt1.Query(sql.Named("accountid", "12346")) + if !rows1.Next() { t.Fatalf("Expected 1 row") } var name, status string @@ -182,7 +180,7 @@ func TestShardBasic(t *testing.T) { if err != nil { t.Fatalf("Expected values %s", err.Error()) } - if (name != "Steve" || status != "In Progress") { + if name != "Steve" || status != "In Progress" { t.Fatalf("***Error: name= %s, status=%s", name, status) } rows1.Close() @@ -191,19 +189,18 @@ func TestShardBasic(t *testing.T) { cancel1() conn1.Close() - fmt.Println ("Verify update request is sent to shard 3") - count1 := testutil.RegexCount ("WORKER shd3.*Preparing.*update test_simple_table_2") - if (count1 < 1) { - t.Fatalf ("Error: Update Query does NOT go to shd3"); - } - - fmt.Println ("Verify select request is sent to shard 3") - count1 = testutil.RegexCount ("WORKER shd3.*Preparing.*TestShardingBasic.*Select name, status from test_simple_table_2") - if (count1 < 1) { - t.Fatalf ("Error: Select Query does NOT go to shd3"); - } + fmt.Println("Verify update request is sent to shard 3") + count1 := testutil.RegexCount("WORKER shd3.*Preparing.*update test_simple_table_2") + if count1 < 1 { + t.Fatalf("Error: Update Query does NOT go to shd3") + } + + fmt.Println("Verify select request is sent to shard 3") + count1 = testutil.RegexCount("WORKER shd3.*Preparing.*TestShardingBasic.*Select name, status from test_simple_table_2") + if count1 < 1 { + t.Fatalf("Error: Select Query does NOT go to shd3") + } testutil.DoDefaultValidation(t) - time.Sleep (time.Duration(2 * time.Second)) + time.Sleep(time.Duration(2 * time.Second)) logger.GetLogger().Log(logger.Debug, "TestShardBasic done -------------------------------------------------------------") } - From bb44de9c35d404d382f8c19ccc339535c1a7736e Mon Sep 17 00:00:00 2001 From: Rajesh S Date: Thu, 16 May 2024 10:04:25 +0530 Subject: [PATCH 3/6] changes for incorporate review comments for management query timeouts --- lib/config.go | 4 + lib/querybindblocker.go | 52 +++--- 
lib/racmaint.go | 13 +- lib/shardingcfg.go | 20 ++- .../main_test.go | 170 ++++++++++++++++++ 5 files changed, 217 insertions(+), 42 deletions(-) create mode 100644 tests/unittest/coordinator_sharding_mgmt_query_timeout/main_test.go diff --git a/lib/config.go b/lib/config.go index f0283cfa..3fc6c69e 100644 --- a/lib/config.go +++ b/lib/config.go @@ -176,6 +176,9 @@ type Config struct { // Max desired percentage of healthy workers for the worker pool MaxDesiredHealthyWorkerPct int + + //Timeout for management queries. + ManagementQueriesTimeoutInMs int } // The OpsConfig contains the configuration that can be modified during run time @@ -463,6 +466,7 @@ func InitConfig() error { gAppConfig.MaxDesiredHealthyWorkerPct = 90 } + gAppConfig.ManagementQueriesTimeoutInMs = cdb.GetOrDefaultInt("management_queries_timeout_ms", 200) return nil } diff --git a/lib/querybindblocker.go b/lib/querybindblocker.go index 001e9749..638f3c7c 100644 --- a/lib/querybindblocker.go +++ b/lib/querybindblocker.go @@ -31,14 +31,13 @@ import ( "github.com/paypal/hera/utility/logger" ) - type QueryBindBlockerEntry struct { - Herasqlhash uint32 - Herasqltext string // prefix since some sql is too long - Bindvarname string // prefix for in clause + Herasqlhash uint32 + Herasqltext string // prefix since some sql is too long + Bindvarname string // prefix for in clause Bindvarvalue string // when set to "BLOCKALLVALUES" should block all sqltext queries - Blockperc int - Heramodule string + Blockperc int + Heramodule string } type QueryBindBlockerCfg struct { @@ -48,7 +47,7 @@ type QueryBindBlockerCfg struct { // check by sqltext prefix (delay to end) } -func (cfg * QueryBindBlockerCfg) IsBlocked(sqltext string, bindPairs []string) (bool,string) { +func (cfg *QueryBindBlockerCfg) IsBlocked(sqltext string, bindPairs []string) (bool, string) { sqlhash := uint32(utility.GetSQLHash(sqltext)) if logger.GetLogger().V(logger.Verbose) { logger.GetLogger().Log(logger.Verbose, fmt.Sprintf("query bind blocker sqlhash and text %d %s", sqlhash, sqltext)) @@ -70,7 +69,7 @@ func (cfg * QueryBindBlockerCfg) IsBlocked(sqltext string, bindPairs []string) ( byBindValue, ok := byBindName[bindPairs[i]] if !ok { // strip numeric suffix to try to match - withoutNumSuffix := regexp.MustCompile("[_0-9]*$").ReplaceAllString(bindPairs[i],"") + withoutNumSuffix := regexp.MustCompile("[_0-9]*$").ReplaceAllString(bindPairs[i], "") byBindValue, ok = byBindName[withoutNumSuffix] if !ok { continue @@ -118,27 +117,26 @@ func (cfg * QueryBindBlockerCfg) IsBlocked(sqltext string, bindPairs []string) ( var g_module string var gQueryBindBlockerCfg atomic.Value -func GetQueryBindBlockerCfg() (*QueryBindBlockerCfg) { - cfg := gQueryBindBlockerCfg.Load() - if cfg == nil { - return nil - } - return cfg.(*QueryBindBlockerCfg) +func GetQueryBindBlockerCfg() *QueryBindBlockerCfg { + cfg := gQueryBindBlockerCfg.Load() + if cfg == nil { + return nil + } + return cfg.(*QueryBindBlockerCfg) } - func InitQueryBindBlocker(modName string) { g_module = modName - db, err := sql.Open("heraloop", fmt.Sprintf("0:0:0")) - if err != nil { + db, err := sql.Open("heraloop", fmt.Sprintf("0:0:0")) + if err != nil { logger.GetLogger().Log(logger.Alert, "Loading query bind blocker - conn err ", err) - return - } - db.SetMaxIdleConns(0) + return + } + db.SetMaxIdleConns(0) go func() { - time.Sleep(4*time.Second) + time.Sleep(4 * time.Second) logger.GetLogger().Log(logger.Info, "Loading query bind blocker - initial") loadBlockQueryBind(db) c := time.Tick(11 * time.Second) @@ -150,9 +148,9 
@@ func InitQueryBindBlocker(modName string) { } func loadBlockQueryBind(db *sql.DB) { - ctx, cancel := context.WithTimeout(context.Background(), 5000*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(GetConfig().ManagementQueriesTimeoutInMs)*time.Millisecond) defer cancel() - conn, err := db.Conn(ctx); + conn, err := db.Conn(ctx) if err != nil { logger.GetLogger().Log(logger.Alert, "Error (conn) loading query bind blocker:", err) return @@ -172,7 +170,7 @@ func loadBlockQueryBind(db *sql.DB) { } defer rows.Close() - cfgLoad := QueryBindBlockerCfg{BySqlHash:make(map[uint32]map[string]map[string][]QueryBindBlockerEntry)} + cfgLoad := QueryBindBlockerCfg{BySqlHash: make(map[uint32]map[string]map[string][]QueryBindBlockerEntry)} rowCount := 0 for rows.Next() { @@ -182,9 +180,9 @@ func loadBlockQueryBind(db *sql.DB) { logger.GetLogger().Log(logger.Alert, "Error (row scan) loading query bind blocker:", err) continue } - + if len(entry.Herasqltext) < GetConfig().QueryBindBlockerMinSqlPrefix { - logger.GetLogger().Log(logger.Alert, "Error (row scan) loading query bind blocker - sqltext must be ", GetConfig().QueryBindBlockerMinSqlPrefix," bytes or more - sqlhash:", entry.Herasqlhash) + logger.GetLogger().Log(logger.Alert, "Error (row scan) loading query bind blocker - sqltext must be ", GetConfig().QueryBindBlockerMinSqlPrefix, " bytes or more - sqlhash:", entry.Herasqlhash) continue } rowCount++ @@ -200,7 +198,7 @@ func loadBlockQueryBind(db *sql.DB) { } bindVal, ok := bindName[entry.Bindvarvalue] if !ok { - bindVal = make([]QueryBindBlockerEntry,0) + bindVal = make([]QueryBindBlockerEntry, 0) bindName[entry.Bindvarvalue] = bindVal } bindName[entry.Bindvarvalue] = append(bindName[entry.Bindvarvalue], entry) diff --git a/lib/racmaint.go b/lib/racmaint.go index 2aef6912..9026ed8a 100644 --- a/lib/racmaint.go +++ b/lib/racmaint.go @@ -102,11 +102,11 @@ func racMaintMain(shard int, interval int, cmdLineModuleName string) { binds[0], err = os.Hostname() binds[0] = strings.ToUpper(binds[0]) binds[1] = strings.ToUpper(cmdLineModuleName) // */ - waitTime := time.Second * time.Duration(interval) //First time data loading - racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, waitTime/2) + racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, GetConfig().ManagementQueriesTimeoutInMs) - timeTicker := time.NewTicker(waitTime) + timeTicker := time.NewTicker(time.Second * time.Duration(interval)) + defer timeTicker.Stop() for { select { case <-ctx.Done(): @@ -114,7 +114,8 @@ func racMaintMain(shard int, interval int, cmdLineModuleName string) { return case <-timeTicker.C: //Periodic data loading - racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, waitTime/2) + racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, GetConfig().ManagementQueriesTimeoutInMs) + timeTicker.Reset(time.Second * time.Duration(interval)) } } } @@ -123,7 +124,7 @@ func racMaintMain(shard int, interval int, cmdLineModuleName string) { racMaint is the main function for RAC maintenance processing, being called regularly. 
When maintenance is planned, it calls workerpool.RacMaint to start the actuall processing */ -func racMaint(ctx *context.Context, shard int, db *sql.DB, racSQL string, cmdLineModuleName string, prev map[racCfgKey]racCfg, queryTimeout time.Duration) { +func racMaint(ctx *context.Context, shard int, db *sql.DB, racSQL string, cmdLineModuleName string, prev map[racCfgKey]racCfg, queryTimeoutInMs int) { // // print this log for unittesting // @@ -131,7 +132,7 @@ func racMaint(ctx *context.Context, shard int, db *sql.DB, racSQL string, cmdLin logger.GetLogger().Log(logger.Verbose, "Rac maint check, shard =", shard) } //create cancellable context - queryContext, cancel := context.WithTimeout(*ctx, queryTimeout) + queryContext, cancel := context.WithTimeout(*ctx, time.Duration(queryTimeoutInMs)*time.Millisecond) defer cancel() // Always call cancel to release resources associated with the context conn, err := db.Conn(queryContext) diff --git a/lib/shardingcfg.go b/lib/shardingcfg.go index 995f53e0..276fca52 100644 --- a/lib/shardingcfg.go +++ b/lib/shardingcfg.go @@ -100,7 +100,7 @@ func getSQL() string { /* load the physical to logical maping */ -func loadMap(ctx *context.Context, db *sql.DB, queryTimeoutInterval time.Duration) error { +func loadMap(ctx *context.Context, db *sql.DB, queryTimeoutInterval int) error { if logger.GetLogger().V(logger.Verbose) { logger.GetLogger().Log(logger.Verbose, "Begin loading shard map") } @@ -109,7 +109,7 @@ func loadMap(ctx *context.Context, db *sql.DB, queryTimeoutInterval time.Duratio logger.GetLogger().Log(logger.Verbose, "Done loading shard map") }() } - queryContext, cancel := context.WithTimeout(*ctx, queryTimeoutInterval) + queryContext, cancel := context.WithTimeout(*ctx, time.Duration(queryTimeoutInterval)*time.Millisecond) defer cancel() conn, err := db.Conn(queryContext) if err != nil { @@ -217,7 +217,7 @@ func getWLSQL() string { /* load the whitelist mapping */ -func loadWhitelist(ctx *context.Context, db *sql.DB, timeout time.Duration) { +func loadWhitelist(ctx *context.Context, db *sql.DB, timeoutInMs int) { if logger.GetLogger().V(logger.Verbose) { logger.GetLogger().Log(logger.Verbose, "Begin loading whitelist") } @@ -226,7 +226,7 @@ func loadWhitelist(ctx *context.Context, db *sql.DB, timeout time.Duration) { logger.GetLogger().Log(logger.Verbose, "Done loading whitelist") }() } - queryContext, cancel := context.WithTimeout(*ctx, timeout) + queryContext, cancel := context.WithTimeout(*ctx, time.Duration(timeoutInMs)*time.Millisecond) defer cancel() conn, err := db.Conn(queryContext) if err != nil { @@ -305,7 +305,7 @@ func InitShardingCfg() error { } db, err = openDb(shard) if err == nil { - err = loadMap(&ctx, db, reloadInterval/2) + err = loadMap(&ctx, db, GetConfig().ManagementQueriesTimeoutInMs) if err == nil { break } @@ -324,10 +324,11 @@ func InitShardingCfg() error { return errors.New("Failed to load shard map, no more retry") } if GetConfig().EnableWhitelistTest { - loadWhitelist(&ctx, db, reloadInterval/2) + loadWhitelist(&ctx, db, GetConfig().ManagementQueriesTimeoutInMs) } go func() { reloadTimer := time.NewTimer(reloadInterval) //Periodic reload timer + defer reloadTimer.Stop() for { select { case <-ctx.Done(): @@ -340,18 +341,19 @@ func InitShardingCfg() error { } db, err = openDb(shard) if err == nil { - err = loadMap(&ctx, db, reloadInterval/2) + err = loadMap(&ctx, db, GetConfig().ManagementQueriesTimeoutInMs) if err == nil { if shard == 0 && GetConfig().EnableWhitelistTest { - loadWhitelist(&ctx, db, reloadInterval/2) + 
loadWhitelist(&ctx, db, GetConfig().ManagementQueriesTimeoutInMs) } break } } logger.GetLogger().Log(logger.Warning, "Error <", err, "> loading the shard map from shard", shard) - evt := cal.NewCalEvent(cal.EventTypeError, "no_shard_map", cal.TransOK, "Error loading shard map") + evt := cal.NewCalEvent(cal.EventTypeError, "no_shard_map", cal.TransOK, err.Error()) evt.Completed() } + reloadTimer.Reset(reloadInterval) //Reset timer } } }() diff --git a/tests/unittest/coordinator_sharding_mgmt_query_timeout/main_test.go b/tests/unittest/coordinator_sharding_mgmt_query_timeout/main_test.go new file mode 100644 index 00000000..988e7aab --- /dev/null +++ b/tests/unittest/coordinator_sharding_mgmt_query_timeout/main_test.go @@ -0,0 +1,170 @@ +package main + +import ( + "context" + "database/sql" + + "fmt" + "os" + "strings" + "testing" + "time" + + _ "github.com/paypal/hera/client/gosqldriver/tcp" + "github.com/paypal/hera/tests/unittest/testutil" + "github.com/paypal/hera/utility/logger" +) + +var mx testutil.Mux +var tableName string + +func cfg() (map[string]string, map[string]string, testutil.WorkerType) { + + appcfg := make(map[string]string) + // best to chose an "unique" port in case golang runs tests in paralel + appcfg["bind_port"] = "31003" + appcfg["log_level"] = "5" + appcfg["log_file"] = "hera.log" + appcfg["enable_sharding"] = "true" + appcfg["num_shards"] = "3" + appcfg["max_scuttle"] = "9" + appcfg["shard_key_name"] = "id" + pfx := os.Getenv("MGMT_TABLE_PREFIX") + if pfx != "" { + appcfg["management_table_prefix"] = pfx + } + appcfg["sharding_cfg_reload_interval"] = "2" + appcfg["rac_sql_interval"] = "0" + appcfg["management_queries_timeout_ms"] = "2" + + opscfg := make(map[string]string) + opscfg["opscfg.default.server.max_connections"] = "3" + opscfg["opscfg.default.server.log_level"] = "5" + + return appcfg, opscfg, testutil.MySQLWorker +} + +func setupShardMap() { + twoTask := os.Getenv("TWO_TASK") + if !strings.HasPrefix(twoTask, "tcp") { + // not mysql + return + } + shard := 0 + db, err := sql.Open("heraloop", fmt.Sprintf("%d:0:0", shard)) + if err != nil { + testutil.Fatal("Error starting Mux:", err) + return + } + db.SetMaxIdleConns(0) + defer db.Close() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, err := db.Conn(ctx) + if err != nil { + testutil.Fatalf("Error getting connection %s\n", err.Error()) + } + defer conn.Close() + + testutil.RunDML("create table hera_shard_map ( scuttle_id smallint not null, shard_id tinyint not null, status char(1) , read_status char(1), write_status char(1), remarks varchar(500))") + + for i := 0; i < 1024; i++ { + shard := 0 + if i <= 8 { + shard = i % 3 + } + testutil.RunDML(fmt.Sprintf("insert into hera_shard_map ( scuttle_id, shard_id, status, read_status, write_status ) values ( %d, %d, 'Y', 'Y', 'Y' )", i, shard)) + } +} + +func before() error { + tableName = os.Getenv("TABLE_NAME") + if tableName == "" { + tableName = "jdbc_hera_test2" + } + if strings.HasPrefix(os.Getenv("TWO_TASK"), "tcp") { + // mysql + testutil.RunDML("create table jdbc_hera_test2 ( ID BIGINT, INT_VAL BIGINT, STR_VAL VARCHAR(500))") + } + return nil +} + +func TestMain(m *testing.M) { + os.Exit(testutil.UtilMain(m, cfg, before)) +} + +func cleanup(ctx context.Context, conn *sql.Conn) error { + tx, _ := conn.BeginTx(ctx, nil) + stmt, _ := tx.PrepareContext(ctx, "/*Cleanup*/delete from "+tableName+" where id != :id") + _, err := stmt.Exec(sql.Named("id", -123)) + if err != nil { + return err + } + err = 
tx.Commit() + return nil +} + +func TestShardingWithContextTimeout(t *testing.T) { + logger.GetLogger().Log(logger.Debug, "TestShardingWithContextTimeout setup") + setupShardMap() + logger.GetLogger().Log(logger.Debug, "TestShardingWithContextTimeout begin +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") + time.Sleep(25 * time.Second) + hostname, _ := os.Hostname() + db, err := sql.Open("hera", hostname+":31003") + if err != nil { + t.Fatal("Error starting Mux:", err) + return + } + db.SetMaxIdleConns(0) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + conn, err := db.Conn(ctx) + if err != nil { + t.Fatalf("Error getting connection %s\n", err.Error()) + } + cleanup(ctx, conn) + // insert one row in the table + tx, _ := conn.BeginTx(ctx, nil) + stmt, _ := tx.PrepareContext(ctx, "/*TestShardingWithContextTimeout*/insert into "+tableName+" (id, int_val, str_val) VALUES(:id, :int_val, :str_val)") + _, err = stmt.Exec(sql.Named("id", 1), sql.Named("int_val", time.Now().Unix()), sql.Named("str_val", "val 1")) + if err != nil { + t.Fatalf("Error preparing test (create row in table) %s\n", err.Error()) + } + err = tx.Commit() + if err != nil { + t.Fatalf("Error commit %s\n", err.Error()) + } + + stmt, _ = conn.PrepareContext(ctx, "/*TestShardingWithContextTimeout*/Select id, int_val, str_val from "+tableName+" where id=:id") + rows, _ := stmt.Query(sql.Named("id", 1)) + if !rows.Next() { + t.Fatalf("Expected 1 row") + } + var id, int_val uint64 + var str_val sql.NullString + err = rows.Scan(&id, &int_val, &str_val) + if err != nil { + t.Fatalf("Expected values %s", err.Error()) + } + if str_val.String != "val 1" { + t.Fatalf("Expected val 1 , got: %s", str_val.String) + } + + rows.Close() + stmt.Close() + + cancel() + conn.Close() + + out, err := testutil.BashCmd("grep 'Preparing: /\\*TestShardingWithContextTimeout\\*/' hera.log | grep 'WORKER shd2' | wc -l") + if (err != nil) || (len(out) == 0) { + err = nil + t.Fatalf("Request did not run on shard 2. 
err = %v, len(out) = %d", err, len(out)) + } + if out[0] != '2' { + t.Fatalf("Expected 2 excutions on shard 2, instead got %d", int(out[0]-'0')) + } + + logger.GetLogger().Log(logger.Debug, "TestShardingWithContextTimeout done -------------------------------------------------------------") +} From 9c1df4f5695f5b62955a2919e092200280319da8 Mon Sep 17 00:00:00 2001 From: Rajesh S Date: Wed, 22 May 2024 17:17:05 +0530 Subject: [PATCH 4/6] incorporate review comments and adding client side context support --- client/gosqldriver/connection.go | 306 +++++++++++++----- client/gosqldriver/statement.go | 10 + client/gosqldriver/utils.go | 67 ++++ lib/config.go | 35 +- lib/querybindblocker.go | 2 +- lib/racmaint.go | 8 +- lib/shardingcfg.go | 14 +- .../main_test.go | 71 +--- .../querybindblocker_timeout/main_test.go | 66 ++++ .../rac_maint_mgmt_query_timeout/main_test.go | 68 ++++ 10 files changed, 469 insertions(+), 178 deletions(-) create mode 100644 client/gosqldriver/utils.go create mode 100644 tests/unittest/querybindblocker_timeout/main_test.go create mode 100644 tests/unittest/rac_maint_mgmt_query_timeout/main_test.go diff --git a/client/gosqldriver/connection.go b/client/gosqldriver/connection.go index e2186b22..eae1ba53 100644 --- a/client/gosqldriver/connection.go +++ b/client/gosqldriver/connection.go @@ -19,15 +19,15 @@ package gosqldriver import ( + "context" "database/sql/driver" "errors" "fmt" - "net" - "os" - "github.com/paypal/hera/common" "github.com/paypal/hera/utility/encoding/netstring" "github.com/paypal/hera/utility/logger" + "net" + "os" ) var corrIDUnsetCmd = netstring.NewNetstringFrom(common.CmdClientCalCorrelationID, []byte("CorrId=NotSet")) @@ -39,26 +39,108 @@ type heraConnection struct { // for the sharding extension shardKeyPayload []byte // correlation id - corrID *netstring.Netstring + corrID *netstring.Netstring clientinfo *netstring.Netstring + + // for context support (Go 1.8+) + watching bool + watcher chan<- context.Context + closech chan struct{} + finished chan<- struct{} + canceled atomicError // set non-nil if conn is canceled + closed atomicBool // set when conn is closed, before closech is closed } // NewHeraConnection creates a structure implementing a driver.Con interface func NewHeraConnection(conn net.Conn) driver.Conn { - hera := &heraConnection{conn: conn, id: conn.RemoteAddr().String(), reader: netstring.NewNetstringReader(conn), corrID: corrIDUnsetCmd} + hera := &heraConnection{conn: conn, + id: conn.RemoteAddr().String(), + reader: netstring.NewNetstringReader(conn), + corrID: corrIDUnsetCmd, + closech: make(chan struct{}), + } + + hera.startWatcher() if logger.GetLogger().V(logger.Info) { logger.GetLogger().Log(logger.Info, hera.id, "create driver connection") } return hera } +func (heraConn *heraConnection) watchCancel(ctx context.Context) error { + if heraConn.watching { + // Reach here if canceled, + // so the connection is already invalid + heraConn.cleanup() + return nil + } + // When ctx is already cancelled, don't watch it. + if err := ctx.Err(); err != nil { + return err + } + // When ctx is not cancellable, don't watch it. + if ctx.Done() == nil { + return nil + } + // When watcher is not alive, can't watch it. + if heraConn.watcher == nil { + return nil + } + + heraConn.watching = true + heraConn.watcher <- ctx + return nil +} + +// Closes the network connection and unsets internal variables. Do not call this +// function after successfully authentication, call Close instead. 
This function +// is called before auth or on auth failure because HERA will have already +// closed the network connection. +func (heraConn *heraConnection) cleanup() { + if heraConn.closed.Swap(true) { + return + } + + // Makes cleanup idempotent + close(heraConn.closech) + if heraConn.conn == nil { + return + } + heraConn.finish() + if err := heraConn.conn.Close(); err != nil { + logger.GetLogger().Log(logger.Alert, err) + } +} + +//error +func (heraConn *heraConnection) error() error { + if heraConn.closed.Load() { + if err := heraConn.canceled.Value(); err != nil { + return err + } + return ErrInvalidConn + } + return nil +} + +// finish is called when the query has succeeded. +func (heraConn *heraConnection) finish() { + if !heraConn.watching || heraConn.finished == nil { + return + } + select { + case heraConn.finished <- struct{}{}: + heraConn.watching = false + case <-heraConn.closech: + } +} // Prepare returns a prepared statement, bound to this connection. -func (c *heraConnection) Prepare(query string) (driver.Stmt, error) { +func (heraConn *heraConnection) Prepare(query string) (driver.Stmt, error) { if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, c.id, "prepare SQL:", query) + logger.GetLogger().Log(logger.Debug, heraConn.id, "prepare SQL:", query) } - return newStmt(c, query), nil + return newStmt(heraConn, query), nil } // Close invalidates and potentially stops any current @@ -69,46 +151,94 @@ func (c *heraConnection) Prepare(query string) (driver.Stmt, error) { // connections and only calls Close when there's a surplus of // idle connections, it shouldn't be necessary for drivers to // do their own connection caching. -func (c *heraConnection) Close() error { +func (heraConn *heraConnection) Close() error { if logger.GetLogger().V(logger.Info) { - logger.GetLogger().Log(logger.Info, c.id, "close driver connection") + logger.GetLogger().Log(logger.Info, heraConn.id, "close driver connection") } - c.conn.Close() + heraConn.cleanup() return nil } +//Start watcher for connection +func (heraConn *heraConnection) startWatcher() { + watcher := make(chan context.Context, 1) + heraConn.watcher = watcher + finished := make(chan struct{}) + heraConn.finished = finished + go func() { + for { + var ctx context.Context + select { + case ctx = <-watcher: + case <-heraConn.closech: + return + } + + select { + case <-ctx.Done(): + heraConn.cancel(ctx.Err()) + case <-finished: + case <-heraConn.closech: + return + } + } + }() +} + +// finish is called when the query has canceled. +func (heraConn *heraConnection) cancel(err error) { + heraConn.canceled.Set(err) + heraConn.cleanup() +} + // Begin starts and returns a new transaction. 
-func (c *heraConnection) Begin() (driver.Tx, error) { +func (heraConn *heraConnection) Begin() (driver.Tx, error) { if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, c.id, "begin txn") + logger.GetLogger().Log(logger.Debug, heraConn.id, "begin txn") + } + if heraConn.closed.Load() { + logger.GetLogger().Log(logger.Alert, ErrInvalidConn) + return nil, driver.ErrBadConn } - return &tx{hera: c}, nil + return &tx{hera: heraConn}, nil } // internal function to execute commands -func (c *heraConnection) exec(cmd int, payload []byte) error { - return c.execNs(netstring.NewNetstringFrom(cmd, payload)) +func (heraConn *heraConnection) exec(cmd int, payload []byte) error { + if heraConn.closed.Load() { + logger.GetLogger().Log(logger.Alert, ErrInvalidConn) + return driver.ErrBadConn + } + return heraConn.execNs(netstring.NewNetstringFrom(cmd, payload)) } // internal function to execute commands -func (c *heraConnection) execNs(ns *netstring.Netstring) error { +func (heraConn *heraConnection) execNs(ns *netstring.Netstring) error { + if heraConn.closed.Load() { + logger.GetLogger().Log(logger.Alert, ErrInvalidConn) + return driver.ErrBadConn + } if logger.GetLogger().V(logger.Verbose) { payload := string(ns.Payload) if len(payload) > 1000 { payload = payload[:1000] } - logger.GetLogger().Log(logger.Verbose, c.id, "send command:", ns.Cmd, ", payload:", payload) + logger.GetLogger().Log(logger.Verbose, heraConn.id, "send command:", ns.Cmd, ", payload:", payload) } - _, err := c.conn.Write(ns.Serialized) + _, err := heraConn.conn.Write(ns.Serialized) return err } // returns the next message from the connection -func (c *heraConnection) getResponse() (*netstring.Netstring, error) { - ns, err := c.reader.ReadNext() +func (heraConn *heraConnection) getResponse() (*netstring.Netstring, error) { + if heraConn.closed.Load() { + logger.GetLogger().Log(logger.Alert, ErrInvalidConn) + return nil, driver.ErrBadConn + } + ns, err := heraConn.reader.ReadNext() if err != nil { if logger.GetLogger().V(logger.Warning) { - logger.GetLogger().Log(logger.Warning, c.id, "Failed to read response") + logger.GetLogger().Log(logger.Warning, heraConn.id, "Failed to read response") } return nil, errors.New("Failed to read response") } @@ -123,9 +253,9 @@ func (c *heraConnection) getResponse() (*netstring.Netstring, error) { } // implementing the extension HeraConn interface -func (c *heraConnection) SetShardID(shard int) error { - c.exec(common.CmdSetShardID, []byte(fmt.Sprintf("%d", shard))) - ns, err := c.getResponse() +func (heraConn *heraConnection) SetShardID(shard int) error { + heraConn.exec(common.CmdSetShardID, []byte(fmt.Sprintf("%d", shard))) + ns, err := heraConn.getResponse() if err != nil { return err } @@ -139,14 +269,14 @@ func (c *heraConnection) SetShardID(shard int) error { } // implementing the extension HeraConn interface -func (c *heraConnection) ResetShardID() error { - return c.SetShardID(-1) +func (heraConn *heraConnection) ResetShardID() error { + return heraConn.SetShardID(-1) } // implementing the extension HeraConn interface -func (c *heraConnection) GetNumShards() (int, error) { - c.exec(common.CmdGetNumShards, nil) - ns, err := c.getResponse() +func (heraConn *heraConnection) GetNumShards() (int, error) { + heraConn.exec(common.CmdGetNumShards, nil) + ns, err := heraConn.getResponse() if err != nil { return -1, err } @@ -162,81 +292,87 @@ func (c *heraConnection) GetNumShards() (int, error) { } // implementing the extension HeraConn interface -func (c *heraConnection) 
SetShardKeyPayload(payload string) { - c.shardKeyPayload = []byte(payload) +func (heraConn *heraConnection) SetShardKeyPayload(payload string) { + heraConn.shardKeyPayload = []byte(payload) } // implementing the extension HeraConn interface -func (c *heraConnection) ResetShardKeyPayload() { - c.SetShardKeyPayload("") +func (heraConn *heraConnection) ResetShardKeyPayload() { + heraConn.SetShardKeyPayload("") } // implementing the extension HeraConn interface -func (c *heraConnection) SetCalCorrID(corrID string) { - c.corrID = netstring.NewNetstringFrom(common.CmdClientCalCorrelationID, []byte(fmt.Sprintf("CorrId=%s", corrID))) +func (heraConn *heraConnection) SetCalCorrID(corrID string) { + heraConn.corrID = netstring.NewNetstringFrom(common.CmdClientCalCorrelationID, []byte(fmt.Sprintf("CorrId=%s", corrID))) } // SetClientInfo actually sends it over to Hera server -func (c *heraConnection) SetClientInfo(poolName string, host string)(error){ +func (heraConn *heraConnection) SetClientInfo(poolName string, host string) error { if len(poolName) <= 0 && len(host) <= 0 { return nil } - + if heraConn.closed.Load() { + logger.GetLogger().Log(logger.Alert, ErrInvalidConn) + return driver.ErrBadConn + } pid := os.Getpid() data := fmt.Sprintf("PID: %d, HOST: %s, Poolname: %s, Command: SetClientInfo,", pid, host, poolName) - c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data))) - if logger.GetLogger().V(logger.Verbose) { - logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized) - } - - _, err := c.conn.Write(c.clientinfo.Serialized) - if err != nil { - if logger.GetLogger().V(logger.Warning) { - logger.GetLogger().Log(logger.Warning, "Failed to send client info") - } - return errors.New("Failed custom auth, failed to send client info") - } - ns, err := c.reader.ReadNext() - if err != nil { - if logger.GetLogger().V(logger.Warning) { - logger.GetLogger().Log(logger.Warning, "Failed to read server info") - } - return errors.New("Failed to read server info") - } - if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload)) - } + heraConn.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data))) + if logger.GetLogger().V(logger.Verbose) { + logger.GetLogger().Log(logger.Verbose, "SetClientInfo", heraConn.clientinfo.Serialized) + } + + _, err := heraConn.conn.Write(heraConn.clientinfo.Serialized) + if err != nil { + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, "Failed to send client info") + } + return errors.New("Failed custom auth, failed to send client info") + } + ns, err := heraConn.reader.ReadNext() + if err != nil { + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, "Failed to read server info") + } + return errors.New("Failed to read server info") + } + if logger.GetLogger().V(logger.Debug) { + logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload)) + } return nil } -func (c *heraConnection) SetClientInfoWithPoolStack(poolName string, host string, poolStack string)(error){ +func (heraConn *heraConnection) SetClientInfoWithPoolStack(poolName string, host string, poolStack string) error { if len(poolName) <= 0 && len(host) <= 0 && len(poolStack) <= 0 { return nil } - + if heraConn.closed.Load() { + logger.GetLogger().Log(logger.Alert, ErrInvalidConn) + return driver.ErrBadConn + } pid := os.Getpid() data := fmt.Sprintf("PID: %d, HOST: %s, Poolname: %s, PoolStack: %s, Command: 
SetClientInfo,", pid, host, poolName, poolStack) - c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data))) - if logger.GetLogger().V(logger.Verbose) { - logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized) - } - - _, err := c.conn.Write(c.clientinfo.Serialized) - if err != nil { - if logger.GetLogger().V(logger.Warning) { - logger.GetLogger().Log(logger.Warning, "Failed to send client info") - } - return errors.New("Failed custom auth, failed to send client info") - } - ns, err := c.reader.ReadNext() - if err != nil { - if logger.GetLogger().V(logger.Warning) { - logger.GetLogger().Log(logger.Warning, "Failed to read server info") - } - return errors.New("Failed to read server info") - } - if logger.GetLogger().V(logger.Debug) { - logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload)) - } + heraConn.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data))) + if logger.GetLogger().V(logger.Verbose) { + logger.GetLogger().Log(logger.Verbose, "SetClientInfo", heraConn.clientinfo.Serialized) + } + + _, err := heraConn.conn.Write(heraConn.clientinfo.Serialized) + if err != nil { + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, "Failed to send client info") + } + return errors.New("Failed custom auth, failed to send client info") + } + ns, err := heraConn.reader.ReadNext() + if err != nil { + if logger.GetLogger().V(logger.Warning) { + logger.GetLogger().Log(logger.Warning, "Failed to read server info") + } + return errors.New("Failed to read server info") + } + if logger.GetLogger().V(logger.Debug) { + logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload)) + } return nil -} \ No newline at end of file +} diff --git a/client/gosqldriver/statement.go b/client/gosqldriver/statement.go index 90217b60..3445b425 100644 --- a/client/gosqldriver/statement.go +++ b/client/gosqldriver/statement.go @@ -82,6 +82,7 @@ func (st *stmt) NumInput() int { // Implements driver.Stmt. // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE. func (st *stmt) Exec(args []driver.Value) (driver.Result, error) { + defer st.hera.finish() sk := 0 if len(st.hera.shardKeyPayload) > 0 { sk = 1 @@ -167,6 +168,10 @@ func (st *stmt) Exec(args []driver.Value) (driver.Result, error) { func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { //TODO: refactor ExecContext / Exec to reuse code //TODO: honor the context timeout and return when it is canceled + if err := st.hera.watchCancel(ctx); err != nil { + return nil, err + } + defer st.hera.finish() sk := 0 if len(st.hera.shardKeyPayload) > 0 { sk = 1 @@ -255,6 +260,7 @@ func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driv // Implements driver.Stmt. // Query executes a query that may return rows, such as a SELECT. 
func (st *stmt) Query(args []driver.Value) (driver.Rows, error) { + defer st.hera.finish() sk := 0 if len(st.hera.shardKeyPayload) > 0 { sk = 1 @@ -359,6 +365,10 @@ Loop: func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { // TODO: refactor Query/QueryContext to reuse code // TODO: honor the context timeout and return when it is canceled + if err := st.hera.watchCancel(ctx); err != nil { + return nil, err + } + defer st.hera.finish() sk := 0 if len(st.hera.shardKeyPayload) > 0 { sk = 1 diff --git a/client/gosqldriver/utils.go b/client/gosqldriver/utils.go new file mode 100644 index 00000000..98eb5131 --- /dev/null +++ b/client/gosqldriver/utils.go @@ -0,0 +1,67 @@ +package gosqldriver + +import ( + "errors" + "sync" + "sync/atomic" +) + +var ErrInvalidConn = errors.New("invalid connection") + +// atomicError provides thread-safe error handling +type atomicError struct { + value atomic.Value + mu sync.Mutex +} + +// Set sets the error value atomically. The value must not be nil. +func (ae *atomicError) Set(err error) { + if err == nil { + panic("atomicError: nil error value") + } + ae.mu.Lock() + defer ae.mu.Unlock() + ae.value.Store(err) +} + +// Value returns the current error value, or nil if none is set. +func (ae *atomicError) Value() error { + v := ae.value.Load() + if v == nil { + return nil + } + return v.(error) +} + +type atomicBool struct { + value uint32 + mu sync.Mutex +} + +// Store sets the value of the bool regardless of the previous value +func (ab *atomicBool) Store(value bool) { + ab.mu.Lock() + defer ab.mu.Unlock() + if value { + atomic.StoreUint32(&ab.value, 1) + } else { + atomic.StoreUint32(&ab.value, 0) + } +} + +// Load returns whether the current boolean value is true +func (ab *atomicBool) Load() bool { + ab.mu.Lock() + defer ab.mu.Unlock() + return atomic.LoadUint32(&ab.value) > 0 +} + +// Swap sets the value of the bool and returns the old value. 
+func (ab *atomicBool) Swap(value bool) bool { + ab.mu.Lock() + defer ab.mu.Unlock() + if value { + return atomic.SwapUint32(&ab.value, 1) > 0 + } + return atomic.SwapUint32(&ab.value, 0) > 0 +} diff --git a/lib/config.go b/lib/config.go index 3fc6c69e..635b7a48 100644 --- a/lib/config.go +++ b/lib/config.go @@ -20,13 +20,12 @@ package lib import ( "errors" "fmt" + "github.com/paypal/hera/config" + "github.com/paypal/hera/utility/logger" "os" "path/filepath" "strings" "sync/atomic" - - "github.com/paypal/hera/config" - "github.com/paypal/hera/utility/logger" ) //The Config contains all the static configuration @@ -80,11 +79,11 @@ type Config struct { // time_skew_threshold_error(15) TimeSkewThresholdErrorSec int // max_stranded_time_interval(2000) - StrandedWorkerTimeoutMs int + StrandedWorkerTimeoutMs int HighLoadStrandedWorkerTimeoutMs int - HighLoadSkipInitiateRecoverPct int - HighLoadPct int - InitLimitPct int + HighLoadSkipInitiateRecoverPct int + HighLoadPct int + InitLimitPct int // the worker scheduler policy LifoScheduler bool @@ -110,7 +109,7 @@ type Config struct { HostnamePrefix map[string]string ShardingCrossKeysErr bool - CfgFromTns bool + CfgFromTns bool CfgFromTnsOverrideNumShards int // -1 no-override CfgFromTnsOverrideTaf int // -1 no-override, 0 override-false, 1 override-true CfgFromTnsOverrideRWSplit int // -1 no-override, readChildPct @@ -156,8 +155,8 @@ type Config struct { // when numWorkers changes, it will write to this channel, for worker manager to update numWorkersCh chan int - EnableConnLimitCheck bool - EnableQueryBindBlocker bool + EnableConnLimitCheck bool + EnableQueryBindBlocker bool QueryBindBlockerMinSqlPrefix int // taf testing @@ -169,7 +168,7 @@ type Config struct { EnableDanglingWorkerRecovery bool GoStatsInterval int - RandomStartMs int + RandomStartMs int // The max number of database connections to be established per second MaxDbConnectsPerSec int @@ -178,7 +177,7 @@ type Config struct { MaxDesiredHealthyWorkerPct int //Timeout for management queries. 
- ManagementQueriesTimeoutInMs int + ManagementQueriesTimeoutInUs int } // The OpsConfig contains the configuration that can be modified during run time @@ -277,10 +276,9 @@ func InitConfig() error { gAppConfig.StrandedWorkerTimeoutMs = cdb.GetOrDefaultInt("max_stranded_time_interval", 2000) gAppConfig.HighLoadStrandedWorkerTimeoutMs = cdb.GetOrDefaultInt("high_load_max_stranded_time_interval", 600111) gAppConfig.HighLoadSkipInitiateRecoverPct = cdb.GetOrDefaultInt("high_load_skip_initiate_recover_pct", 80) - gAppConfig.HighLoadPct = cdb.GetOrDefaultInt("high_load_pct", 130) // >100 disabled + gAppConfig.HighLoadPct = cdb.GetOrDefaultInt("high_load_pct", 130) // >100 disabled gAppConfig.InitLimitPct = cdb.GetOrDefaultInt("init_limit_pct", 125) // >100 disabled - gAppConfig.StateLogInterval = cdb.GetOrDefaultInt("state_log_interval", 1) if gAppConfig.StateLogInterval <= 0 { gAppConfig.StateLogInterval = 1 @@ -303,7 +301,7 @@ func InitConfig() error { gAppConfig.ChildExecutable = "postgresworker" } } else { - // db type is not supported + // db type is not supported return errors.New("database type must be either Oracle or MySQL") } @@ -428,9 +426,8 @@ func InitConfig() error { fmt.Sscanf(cdb.GetOrDefaultString("bind_eviction_decr_per_sec", "10.0"), "%f", &gAppConfig.BindEvictionDecrPerSec) - gAppConfig.SkipEvictRegex= cdb.GetOrDefaultString("skip_eviction_host_prefix","") - gAppConfig.EvictRegex= cdb.GetOrDefaultString("eviction_host_prefix", "") - + gAppConfig.SkipEvictRegex = cdb.GetOrDefaultString("skip_eviction_host_prefix", "") + gAppConfig.EvictRegex = cdb.GetOrDefaultString("eviction_host_prefix", "") gAppConfig.BouncerEnabled = cdb.GetOrDefaultBool("bouncer_enabled", true) gAppConfig.BouncerStartupDelay = cdb.GetOrDefaultInt("bouncer_startup_delay", 10) @@ -466,7 +463,7 @@ func InitConfig() error { gAppConfig.MaxDesiredHealthyWorkerPct = 90 } - gAppConfig.ManagementQueriesTimeoutInMs = cdb.GetOrDefaultInt("management_queries_timeout_ms", 200) + gAppConfig.ManagementQueriesTimeoutInUs = cdb.GetOrDefaultInt("management_queries_timeout_us", 200000) //200 milliseconds return nil } diff --git a/lib/querybindblocker.go b/lib/querybindblocker.go index 638f3c7c..8df02c30 100644 --- a/lib/querybindblocker.go +++ b/lib/querybindblocker.go @@ -148,7 +148,7 @@ func InitQueryBindBlocker(modName string) { } func loadBlockQueryBind(db *sql.DB) { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(GetConfig().ManagementQueriesTimeoutInMs)*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(GetConfig().ManagementQueriesTimeoutInUs)*time.Microsecond) defer cancel() conn, err := db.Conn(ctx) if err != nil { diff --git a/lib/racmaint.go b/lib/racmaint.go index 9026ed8a..6da59a91 100644 --- a/lib/racmaint.go +++ b/lib/racmaint.go @@ -103,7 +103,7 @@ func racMaintMain(shard int, interval int, cmdLineModuleName string) { binds[0] = strings.ToUpper(binds[0]) binds[1] = strings.ToUpper(cmdLineModuleName) // */ //First time data loading - racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, GetConfig().ManagementQueriesTimeoutInMs) + racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, GetConfig().ManagementQueriesTimeoutInUs) timeTicker := time.NewTicker(time.Second * time.Duration(interval)) defer timeTicker.Stop() @@ -114,7 +114,7 @@ func racMaintMain(shard int, interval int, cmdLineModuleName string) { return case <-timeTicker.C: //Periodic data loading - racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, 
GetConfig().ManagementQueriesTimeoutInMs) + racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, GetConfig().ManagementQueriesTimeoutInUs) timeTicker.Reset(time.Second * time.Duration(interval)) } } @@ -124,7 +124,7 @@ func racMaintMain(shard int, interval int, cmdLineModuleName string) { racMaint is the main function for RAC maintenance processing, being called regularly. When maintenance is planned, it calls workerpool.RacMaint to start the actuall processing */ -func racMaint(ctx *context.Context, shard int, db *sql.DB, racSQL string, cmdLineModuleName string, prev map[racCfgKey]racCfg, queryTimeoutInMs int) { +func racMaint(ctx *context.Context, shard int, db *sql.DB, racSQL string, cmdLineModuleName string, prev map[racCfgKey]racCfg, queryTimeoutInUs int) { // // print this log for unittesting // @@ -132,7 +132,7 @@ func racMaint(ctx *context.Context, shard int, db *sql.DB, racSQL string, cmdLin logger.GetLogger().Log(logger.Verbose, "Rac maint check, shard =", shard) } //create cancellable context - queryContext, cancel := context.WithTimeout(*ctx, time.Duration(queryTimeoutInMs)*time.Millisecond) + queryContext, cancel := context.WithTimeout(*ctx, time.Duration(queryTimeoutInUs)*time.Microsecond) defer cancel() // Always call cancel to release resources associated with the context conn, err := db.Conn(queryContext) diff --git a/lib/shardingcfg.go b/lib/shardingcfg.go index 276fca52..ea6f0f25 100644 --- a/lib/shardingcfg.go +++ b/lib/shardingcfg.go @@ -109,7 +109,7 @@ func loadMap(ctx *context.Context, db *sql.DB, queryTimeoutInterval int) error { logger.GetLogger().Log(logger.Verbose, "Done loading shard map") }() } - queryContext, cancel := context.WithTimeout(*ctx, time.Duration(queryTimeoutInterval)*time.Millisecond) + queryContext, cancel := context.WithTimeout(*ctx, time.Duration(queryTimeoutInterval)*time.Microsecond) defer cancel() conn, err := db.Conn(queryContext) if err != nil { @@ -226,7 +226,7 @@ func loadWhitelist(ctx *context.Context, db *sql.DB, timeoutInMs int) { logger.GetLogger().Log(logger.Verbose, "Done loading whitelist") }() } - queryContext, cancel := context.WithTimeout(*ctx, time.Duration(timeoutInMs)*time.Millisecond) + queryContext, cancel := context.WithTimeout(*ctx, time.Duration(timeoutInMs)*time.Microsecond) defer cancel() conn, err := db.Conn(queryContext) if err != nil { @@ -305,13 +305,13 @@ func InitShardingCfg() error { } db, err = openDb(shard) if err == nil { - err = loadMap(&ctx, db, GetConfig().ManagementQueriesTimeoutInMs) + err = loadMap(&ctx, db, GetConfig().ManagementQueriesTimeoutInUs) if err == nil { break } } logger.GetLogger().Log(logger.Warning, "Error <", err, "> loading the shard map from shard", shard) - evt := cal.NewCalEvent(cal.EventTypeError, "no_shard_map", cal.TransOK, "Error loading shard map") + evt := cal.NewCalEvent(cal.EventTypeError, "no_shard_map", cal.TransOK, fmt.Sprintf("Error loading shard map %v", err)) evt.Completed() } if err == nil { @@ -324,7 +324,7 @@ func InitShardingCfg() error { return errors.New("Failed to load shard map, no more retry") } if GetConfig().EnableWhitelistTest { - loadWhitelist(&ctx, db, GetConfig().ManagementQueriesTimeoutInMs) + loadWhitelist(&ctx, db, GetConfig().ManagementQueriesTimeoutInUs) } go func() { reloadTimer := time.NewTimer(reloadInterval) //Periodic reload timer @@ -341,10 +341,10 @@ func InitShardingCfg() error { } db, err = openDb(shard) if err == nil { - err = loadMap(&ctx, db, GetConfig().ManagementQueriesTimeoutInMs) + err = loadMap(&ctx, db, 
GetConfig().ManagementQueriesTimeoutInUs) if err == nil { if shard == 0 && GetConfig().EnableWhitelistTest { - loadWhitelist(&ctx, db, GetConfig().ManagementQueriesTimeoutInMs) + loadWhitelist(&ctx, db, GetConfig().ManagementQueriesTimeoutInUs) } break } diff --git a/tests/unittest/coordinator_sharding_mgmt_query_timeout/main_test.go b/tests/unittest/coordinator_sharding_mgmt_query_timeout/main_test.go index 988e7aab..fca222aa 100644 --- a/tests/unittest/coordinator_sharding_mgmt_query_timeout/main_test.go +++ b/tests/unittest/coordinator_sharding_mgmt_query_timeout/main_test.go @@ -35,7 +35,7 @@ func cfg() (map[string]string, map[string]string, testutil.WorkerType) { } appcfg["sharding_cfg_reload_interval"] = "2" appcfg["rac_sql_interval"] = "0" - appcfg["management_queries_timeout_ms"] = "2" + appcfg["management_queries_timeout_us"] = "400" opscfg := make(map[string]string) opscfg["opscfg.default.server.max_connections"] = "3" @@ -66,14 +66,14 @@ func setupShardMap() { } defer conn.Close() - testutil.RunDML("create table hera_shard_map ( scuttle_id smallint not null, shard_id tinyint not null, status char(1) , read_status char(1), write_status char(1), remarks varchar(500))") + testutil.DBDirect("create table hera_shard_map ( scuttle_id smallint not null, shard_id tinyint not null, status char(1) , read_status char(1), write_status char(1), remarks varchar(500))", os.Getenv("MYSQL_IP"), "heratestdb", testutil.MySQL) - for i := 0; i < 1024; i++ { + for i := 0; i < 9; i++ { shard := 0 - if i <= 8 { + if i >= 3 { shard = i % 3 } - testutil.RunDML(fmt.Sprintf("insert into hera_shard_map ( scuttle_id, shard_id, status, read_status, write_status ) values ( %d, %d, 'Y', 'Y', 'Y' )", i, shard)) + testutil.DBDirect(fmt.Sprintf("insert into hera_shard_map ( scuttle_id, shard_id, status, read_status, write_status ) values ( %d, %d, 'Y', 'Y', 'Y' )", i, shard), os.Getenv("MYSQL_IP"), "heratestdb", testutil.MySQL) } } @@ -84,7 +84,7 @@ func before() error { } if strings.HasPrefix(os.Getenv("TWO_TASK"), "tcp") { // mysql - testutil.RunDML("create table jdbc_hera_test2 ( ID BIGINT, INT_VAL BIGINT, STR_VAL VARCHAR(500))") + testutil.DBDirect("create table jdbc_hera_test2 ( ID BIGINT, INT_VAL BIGINT, STR_VAL VARCHAR(500))", os.Getenv("MYSQL_IP"), "heratestdb", testutil.MySQL) } return nil } @@ -93,17 +93,6 @@ func TestMain(m *testing.M) { os.Exit(testutil.UtilMain(m, cfg, before)) } -func cleanup(ctx context.Context, conn *sql.Conn) error { - tx, _ := conn.BeginTx(ctx, nil) - stmt, _ := tx.PrepareContext(ctx, "/*Cleanup*/delete from "+tableName+" where id != :id") - _, err := stmt.Exec(sql.Named("id", -123)) - if err != nil { - return err - } - err = tx.Commit() - return nil -} - func TestShardingWithContextTimeout(t *testing.T) { logger.GetLogger().Log(logger.Debug, "TestShardingWithContextTimeout setup") setupShardMap() @@ -118,52 +107,10 @@ func TestShardingWithContextTimeout(t *testing.T) { db.SetMaxIdleConns(0) defer db.Close() - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - conn, err := db.Conn(ctx) - if err != nil { - t.Fatalf("Error getting connection %s\n", err.Error()) - } - cleanup(ctx, conn) - // insert one row in the table - tx, _ := conn.BeginTx(ctx, nil) - stmt, _ := tx.PrepareContext(ctx, "/*TestShardingWithContextTimeout*/insert into "+tableName+" (id, int_val, str_val) VALUES(:id, :int_val, :str_val)") - _, err = stmt.Exec(sql.Named("id", 1), sql.Named("int_val", time.Now().Unix()), sql.Named("str_val", "val 1")) - if err != nil { - t.Fatalf("Error preparing 
test (create row in table) %s\n", err.Error()) - } - err = tx.Commit() - if err != nil { - t.Fatalf("Error commit %s\n", err.Error()) - } - - stmt, _ = conn.PrepareContext(ctx, "/*TestShardingWithContextTimeout*/Select id, int_val, str_val from "+tableName+" where id=:id") - rows, _ := stmt.Query(sql.Named("id", 1)) - if !rows.Next() { - t.Fatalf("Expected 1 row") - } - var id, int_val uint64 - var str_val sql.NullString - err = rows.Scan(&id, &int_val, &str_val) - if err != nil { - t.Fatalf("Expected values %s", err.Error()) - } - if str_val.String != "val 1" { - t.Fatalf("Expected val 1 , got: %s", str_val.String) - } - - rows.Close() - stmt.Close() - - cancel() - conn.Close() - - out, err := testutil.BashCmd("grep 'Preparing: /\\*TestShardingWithContextTimeout\\*/' hera.log | grep 'WORKER shd2' | wc -l") - if (err != nil) || (len(out) == 0) { + out := testutil.RegexCountFile("loading shard map: context deadline exceeded", "cal.log") + if out < 2 { err = nil - t.Fatalf("Request did not run on shard 2. err = %v, len(out) = %d", err, len(out)) - } - if out[0] != '2' { - t.Fatalf("Expected 2 excutions on shard 2, instead got %d", int(out[0]-'0')) + t.Fatalf("sharding management query should fail with context timeout") } logger.GetLogger().Log(logger.Debug, "TestShardingWithContextTimeout done -------------------------------------------------------------") diff --git a/tests/unittest/querybindblocker_timeout/main_test.go b/tests/unittest/querybindblocker_timeout/main_test.go new file mode 100644 index 00000000..7d1b007a --- /dev/null +++ b/tests/unittest/querybindblocker_timeout/main_test.go @@ -0,0 +1,66 @@ +package main + +import ( + "database/sql" + "fmt" + "os" + "testing" + "time" + + "github.com/paypal/hera/tests/unittest/testutil" + "github.com/paypal/hera/utility/logger" +) + +var mx testutil.Mux + +func cfg() (map[string]string, map[string]string, testutil.WorkerType) { + fmt.Println("setup() begin") + appcfg := make(map[string]string) + // best to chose an "unique" port in case golang runs tests in paralel + appcfg["bind_port"] = "31002" + appcfg["log_level"] = "5" + appcfg["log_file"] = "hera.log" + appcfg["rac_sql_interval"] = "0" + appcfg["enable_query_bind_blocker"] = "true" + appcfg["management_queries_timeout_us"] = "400" + + opscfg := make(map[string]string) + opscfg["opscfg.default.server.max_connections"] = "3" + opscfg["opscfg.default.server.log_level"] = "5" + if os.Getenv("WORKER") == "postgres" { + return appcfg, opscfg, testutil.PostgresWorker + } + return appcfg, opscfg, testutil.MySQLWorker +} + +func teardown() { + mx.StopServer() +} + +func TestMain(m *testing.M) { + os.Exit(testutil.UtilMain(m, cfg, nil)) +} + +func TestQueryBindBlockerWithTimeout(t *testing.T) { + testutil.DBDirect("create table hera_rate_limiter (herasqlhash numeric not null, herasqltext varchar(4000) not null, bindvarname varchar(200) not null, bindvarvalue varchar(200) not null, blockperc numeric not null, heramodule varchar(100) not null, end_time numeric not null, remarks varchar(200) not null)", os.Getenv("MYSQL_IP"), "heratestdb", testutil.MySQL) + + logger.GetLogger().Log(logger.Debug, "TestQueryBindBlockerWithTimeout begin +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") + time.Sleep(16 * time.Second) + hostname, _ := os.Hostname() + db, err := sql.Open("heraloop", hostname+":31002") + if err != nil { + t.Fatal("Error starting Mux:", err) + return + } + db.SetMaxIdleConns(0) + defer db.Close() + time.Sleep(2 * time.Second) + out := testutil.RegexCountFile("loading query 
bind blocker: context deadline exceeded", "hera.log") + if out < 1 { + err = nil + t.Fatalf("query bind blocker management query should fail with context timeout") + } + + logger.GetLogger().Log(logger.Debug, "TestQueryBindBlockerWithTimeout done -------------------------------------------------------------") + +} diff --git a/tests/unittest/rac_maint_mgmt_query_timeout/main_test.go b/tests/unittest/rac_maint_mgmt_query_timeout/main_test.go new file mode 100644 index 00000000..7fef725d --- /dev/null +++ b/tests/unittest/rac_maint_mgmt_query_timeout/main_test.go @@ -0,0 +1,68 @@ +package main + +import ( + "database/sql" + "github.com/paypal/hera/tests/unittest/testutil" + "github.com/paypal/hera/utility/logger" + "os" + "testing" + "time" +) + +var mx testutil.Mux +var tableName string + +func cfg() (map[string]string, map[string]string, testutil.WorkerType) { + + appcfg := make(map[string]string) + // best to chose an "unique" port in case golang runs tests in paralel + appcfg["bind_port"] = "31002" + appcfg["log_level"] = "5" + appcfg["log_file"] = "hera.log" + appcfg["sharding_cfg_reload_interval"] = "0" + appcfg["rac_sql_interval"] = "1" + + opscfg := make(map[string]string) + opscfg["opscfg.default.server.max_connections"] = "3" + opscfg["opscfg.default.server.log_level"] = "5" + appcfg["management_queries_timeout_us"] = "200" + + //return appcfg, opscfg, testutil.OracleWorker + return appcfg, opscfg, testutil.MySQLWorker +} + +func before() error { + os.Setenv("PARALLEL", "1") + pfx := os.Getenv("MGMT_TABLE_PREFIX") + if pfx == "" { + pfx = "hera" + } + tableName = pfx + "_maint" + return nil +} + +func TestMain(m *testing.M) { + os.Exit(testutil.UtilMain(m, cfg, before)) +} + +func TestRacMaintWithWithTimeout(t *testing.T) { + + logger.GetLogger().Log(logger.Debug, "TestRacMaintWithWithTimeout begin +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") + time.Sleep(16 * time.Second) + hostname, _ := os.Hostname() + db, err := sql.Open("heraloop", hostname+":31002") + if err != nil { + t.Fatal("Error starting Mux:", err) + return + } + db.SetMaxIdleConns(0) + defer db.Close() + time.Sleep(2 * time.Second) + out := testutil.RegexCountFile("rac maint for shard = 0 ,err : context deadline exceeded", "hera.log") + if out < 1 { + err = nil + t.Fatalf("rac maint management query should fail with context timeout") + } + + logger.GetLogger().Log(logger.Debug, "TestRacMaintWithWithTimeout done -------------------------------------------------------------") +} From af57938d88e734376d34205fa17b9941df4ac882 Mon Sep 17 00:00:00 2001 From: Rajesh S Date: Wed, 22 May 2024 23:06:06 +0530 Subject: [PATCH 5/6] changes for fixing timeout tests --- tests/unittest/bindEvict/main_test.go | 6 +++--- tests/unittest/bindThrottle/main_test.go | 2 +- tests/unittest/querybindblocker_timeout/main_test.go | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unittest/bindEvict/main_test.go b/tests/unittest/bindEvict/main_test.go index 3b8b87e0..3306fb63 100644 --- a/tests/unittest/bindEvict/main_test.go +++ b/tests/unittest/bindEvict/main_test.go @@ -114,8 +114,8 @@ func fastAndSlowBinds() error { // client threads of slow queries var stop2 int var badCliErr string - mkClients(1+int(max_conn*1.6), &stop2, 29001111, "badClient", &badCliErr, db) - time.Sleep(3100 * time.Millisecond) + mkClients(1+int(max_conn*1.2), &stop2, 29001111, "badClient", &badCliErr, db) + time.Sleep(5100 * time.Millisecond) /* if (testutil.RegexCountFile("BIND_THROTTLE", "cal.log") == 0) { return 
fmt.Errorf("BIND_THROTTLE was not triggered")
 	}
@@ -208,7 +208,7 @@ func TestBindEvict(t *testing.T) {
 	logger.GetLogger().Log(logger.Debug, "TestBindEvict +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n")
 	err := fastAndSlowBinds()
 	if err != nil {
-		t.Fatalf("main step function returned err %s", err.Error())
+		t.Errorf("main step function returned err %s", err.Error())
 	}
 	if testutil.RegexCountFile("BIND_THROTTLE", "cal.log") == 0 {
 		t.Fatalf("BIND_THROTTLE was not triggered")
diff --git a/tests/unittest/bindThrottle/main_test.go b/tests/unittest/bindThrottle/main_test.go
index 7cdede41..23717487 100644
--- a/tests/unittest/bindThrottle/main_test.go
+++ b/tests/unittest/bindThrottle/main_test.go
@@ -220,7 +220,7 @@ func TestBindThrottle(t *testing.T) {
 	}
 	// */
 	logger.GetLogger().Log(logger.Debug, "BindThrottle midpt +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n")
-	err = partialBadLoad(0.8)
+	err = partialBadLoad(0.7)
 	if err != nil {
 		// t.Fatalf("main step function returned err %s", err.Error()) // can be triggered since test only has one sql
 	}
diff --git a/tests/unittest/querybindblocker_timeout/main_test.go b/tests/unittest/querybindblocker_timeout/main_test.go
index 7d1b007a..1f0d19be 100644
--- a/tests/unittest/querybindblocker_timeout/main_test.go
+++ b/tests/unittest/querybindblocker_timeout/main_test.go
@@ -22,7 +22,7 @@ func cfg() (map[string]string, map[string]string, testutil.WorkerType) {
 	appcfg["log_file"] = "hera.log"
 	appcfg["rac_sql_interval"] = "0"
 	appcfg["enable_query_bind_blocker"] = "true"
-	appcfg["management_queries_timeout_us"] = "400"
+	appcfg["management_queries_timeout_us"] = "200"
 
 	opscfg := make(map[string]string)
 	opscfg["opscfg.default.server.max_connections"] = "3"
@@ -54,7 +54,7 @@ func TestQueryBindBlockerWithTimeout(t *testing.T) {
 	}
 	db.SetMaxIdleConns(0)
 	defer db.Close()
-	time.Sleep(2 * time.Second)
+	time.Sleep(3 * time.Second)
 	out := testutil.RegexCountFile("loading query bind blocker: context deadline exceeded", "hera.log")
 	if out < 1 {
 		err = nil
From 3097b91d9050c00b66d0d003080303d0c401f01d Mon Sep 17 00:00:00 2001
From: Rajesh S
Date: Thu, 23 May 2024 09:37:51 +0530
Subject: [PATCH 6/6] reduce timeout for context deadline

---
 tests/unittest/querybindblocker_timeout/main_test.go | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/unittest/querybindblocker_timeout/main_test.go b/tests/unittest/querybindblocker_timeout/main_test.go
index 1f0d19be..7177ccc3 100644
--- a/tests/unittest/querybindblocker_timeout/main_test.go
+++ b/tests/unittest/querybindblocker_timeout/main_test.go
@@ -22,7 +22,7 @@ func cfg() (map[string]string, map[string]string, testutil.WorkerType) {
 	appcfg["log_file"] = "hera.log"
 	appcfg["rac_sql_interval"] = "0"
 	appcfg["enable_query_bind_blocker"] = "true"
-	appcfg["management_queries_timeout_us"] = "200"
+	appcfg["management_queries_timeout_us"] = "100"
 
 	opscfg := make(map[string]string)
 	opscfg["opscfg.default.server.max_connections"] = "3"
@@ -54,7 +54,7 @@ func TestQueryBindBlockerWithTimeout(t *testing.T) {
 	}
 	db.SetMaxIdleConns(0)
 	defer db.Close()
-	time.Sleep(3 * time.Second)
+	time.Sleep(5 * time.Second)
 	out := testutil.RegexCountFile("loading query bind blocker: context deadline exceeded", "hera.log")
 	if out < 1 {
 		err = nil
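Note (illustrative, not part of the patches above): the changes in racmaint.go, shardingcfg.go and querybindblocker.go all apply the same pattern — derive a short-lived context from the caller's context using the microsecond-denominated management_queries_timeout_us setting, run Conn/PrepareContext/QueryContext against that context, and always defer cancel(). The sketch below only mirrors that shape under stated assumptions; the package, function and parameter names (mgmtexample, runManagementQuery, timeoutUs) are hypothetical and do not exist in the Hera code base.

package mgmtexample

import (
	"context"
	"database/sql"
	"time"
)

// runManagementQuery runs one management query under its own deadline.
// timeoutUs is assumed to carry the same meaning as the patches'
// GetConfig().ManagementQueriesTimeoutInUs value (microseconds).
func runManagementQuery(ctx context.Context, db *sql.DB, query string, timeoutUs int) (int, error) {
	// Per-call timeout derived from the parent context.
	queryCtx, cancel := context.WithTimeout(ctx, time.Duration(timeoutUs)*time.Microsecond)
	defer cancel() // always release the resources tied to the context

	conn, err := db.Conn(queryCtx)
	if err != nil {
		return 0, err // a slow database surfaces here as "context deadline exceeded"
	}
	defer conn.Close()

	stmt, err := conn.PrepareContext(queryCtx, query)
	if err != nil {
		return 0, err
	}
	defer stmt.Close()

	rows, err := stmt.QueryContext(queryCtx)
	if err != nil {
		return 0, err
	}
	defer rows.Close()

	// Consume the result set while the connection is still open.
	count := 0
	for rows.Next() {
		count++
	}
	return count, rows.Err()
}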