From e2f360164b772772387d0841cccb50159b478068 Mon Sep 17 00:00:00 2001 From: Rajesh S <105205300+rasamala83@users.noreply.github.com> Date: Thu, 16 May 2024 10:17:34 +0530 Subject: [PATCH] adding changes to disable automatic bind throttle (#392) * adding changes to disable automatic bind throttle * updating values bind throttle decrese per sec and removed unused code * updating bind eviction test * fixing review comment * fixing review comment * adding test for if rate limit table not exist or empty * move tests to different package to avoid running them in parallel * updating sleep time in tests * added changes for increase throttling recovery speed * changes for updating text check condition in test code * reverting changes for bind throttle * reverted partial changes for local copy of bindEvict object it is going taken care separate change request * changes for simplifying test code for qbb --------- Co-authored-by: Rajesh S --- lib/bindevict.go | 16 +- lib/config.go | 26 +- lib/querybindblocker.go | 66 ++--- tests/unittest/bindThrottle/main_test.go | 242 ++++++++++++++++++ tests/unittest/querybindblocker/main_test.go | 188 +++++++++----- .../main_test.go | 68 +++++ 6 files changed, 491 insertions(+), 115 deletions(-) create mode 100644 tests/unittest/bindThrottle/main_test.go create mode 100644 tests/unittest/querybindblocker_ratelimit_table_empty/main_test.go diff --git a/lib/bindevict.go b/lib/bindevict.go index a31585b6..67d5bb22 100644 --- a/lib/bindevict.go +++ b/lib/bindevict.go @@ -42,21 +42,21 @@ type BindEvict struct { // evicted binds get throttled to have overall steady state during bad bind queries // nested map uses sqlhash "bindName|bindValue" BindThrottle map[uint32]map[string]*BindThrottle - lock sync.Mutex + lock sync.Mutex } func GetBindEvict() *BindEvict { cfg := gBindEvict.Load() if cfg == nil { - out := BindEvict{BindThrottle:make(map[uint32]map[string]*BindThrottle)} + out := BindEvict{BindThrottle: make(map[uint32]map[string]*BindThrottle)} gBindEvict.Store(&out) return &out } return cfg.(*BindEvict) } func (this *BindEvict) Copy() *BindEvict { - out := BindEvict{BindThrottle:make(map[uint32]map[string]*BindThrottle)} - for k,v := range this.BindThrottle { + out := BindEvict{BindThrottle: make(map[uint32]map[string]*BindThrottle)} + for k, v := range this.BindThrottle { out.BindThrottle[k] = v } return &out @@ -77,7 +77,7 @@ func NormalizeBindName(bindName0 string) string { func (entry *BindThrottle) decrAllowEveryX(y int) { if y >= 2 && logger.GetLogger().V(logger.Warning) { - info := fmt.Sprintf("hash:%d bindName:%s val:%s allowEveryX:%d-%d",entry.Sqlhash, entry.Name, entry.Value, entry.AllowEveryX, y) + info := fmt.Sprintf("hash:%d bindName:%s val:%s allowEveryX:%d-%d", entry.Sqlhash, entry.Name, entry.Value, entry.AllowEveryX, y) logger.GetLogger().Log(logger.Warning, "bind throttle decr", info) } entry.AllowEveryX -= y @@ -96,7 +96,7 @@ func (entry *BindThrottle) decrAllowEveryX(y int) { // copy everything except bindKV (skipping it is deleting it) bindKV := fmt.Sprintf("%s|%s", entry.Name, entry.Value) updateCopy := make(map[string]*BindThrottle) - for k,v := range GetBindEvict().BindThrottle[entry.Sqlhash] { + for k, v := range GetBindEvict().BindThrottle[entry.Sqlhash] { if k == bindKV { continue } @@ -107,7 +107,7 @@ func (entry *BindThrottle) decrAllowEveryX(y int) { } func (entry *BindThrottle) incrAllowEveryX() { if logger.GetLogger().V(logger.Warning) { - info := fmt.Sprintf("hash:%d bindName:%s val:%s prev:%d",entry.Sqlhash, entry.Name, entry.Value, entry.AllowEveryX) + info := fmt.Sprintf("hash:%d bindName:%s val:%s prev:%d", entry.Sqlhash, entry.Name, entry.Value, entry.AllowEveryX) logger.GetLogger().Log(logger.Warning, "bind throttle incr", info) } entry.AllowEveryX = 3*entry.AllowEveryX + 1 @@ -149,7 +149,7 @@ func (be *BindEvict) ShouldBlock(sqlhash uint32, bindKV map[string]string, heavy entry.RecentAttempt.Store(&now) entry.AllowEveryXCount++ if entry.AllowEveryXCount < entry.AllowEveryX { - return true/*block*/, entry + return true /*block*/, entry } entry.AllowEveryXCount = 0 diff --git a/lib/config.go b/lib/config.go index f0283cfa..4b8bdc66 100644 --- a/lib/config.go +++ b/lib/config.go @@ -80,11 +80,11 @@ type Config struct { // time_skew_threshold_error(15) TimeSkewThresholdErrorSec int // max_stranded_time_interval(2000) - StrandedWorkerTimeoutMs int + StrandedWorkerTimeoutMs int HighLoadStrandedWorkerTimeoutMs int - HighLoadSkipInitiateRecoverPct int - HighLoadPct int - InitLimitPct int + HighLoadSkipInitiateRecoverPct int + HighLoadPct int + InitLimitPct int // the worker scheduler policy LifoScheduler bool @@ -110,7 +110,7 @@ type Config struct { HostnamePrefix map[string]string ShardingCrossKeysErr bool - CfgFromTns bool + CfgFromTns bool CfgFromTnsOverrideNumShards int // -1 no-override CfgFromTnsOverrideTaf int // -1 no-override, 0 override-false, 1 override-true CfgFromTnsOverrideRWSplit int // -1 no-override, readChildPct @@ -156,8 +156,8 @@ type Config struct { // when numWorkers changes, it will write to this channel, for worker manager to update numWorkersCh chan int - EnableConnLimitCheck bool - EnableQueryBindBlocker bool + EnableConnLimitCheck bool + EnableQueryBindBlocker bool QueryBindBlockerMinSqlPrefix int // taf testing @@ -169,7 +169,7 @@ type Config struct { EnableDanglingWorkerRecovery bool GoStatsInterval int - RandomStartMs int + RandomStartMs int // The max number of database connections to be established per second MaxDbConnectsPerSec int @@ -274,10 +274,9 @@ func InitConfig() error { gAppConfig.StrandedWorkerTimeoutMs = cdb.GetOrDefaultInt("max_stranded_time_interval", 2000) gAppConfig.HighLoadStrandedWorkerTimeoutMs = cdb.GetOrDefaultInt("high_load_max_stranded_time_interval", 600111) gAppConfig.HighLoadSkipInitiateRecoverPct = cdb.GetOrDefaultInt("high_load_skip_initiate_recover_pct", 80) - gAppConfig.HighLoadPct = cdb.GetOrDefaultInt("high_load_pct", 130) // >100 disabled + gAppConfig.HighLoadPct = cdb.GetOrDefaultInt("high_load_pct", 130) // >100 disabled gAppConfig.InitLimitPct = cdb.GetOrDefaultInt("init_limit_pct", 125) // >100 disabled - gAppConfig.StateLogInterval = cdb.GetOrDefaultInt("state_log_interval", 1) if gAppConfig.StateLogInterval <= 0 { gAppConfig.StateLogInterval = 1 @@ -300,7 +299,7 @@ func InitConfig() error { gAppConfig.ChildExecutable = "postgresworker" } } else { - // db type is not supported + // db type is not supported return errors.New("database type must be either Oracle or MySQL") } @@ -425,9 +424,8 @@ func InitConfig() error { fmt.Sscanf(cdb.GetOrDefaultString("bind_eviction_decr_per_sec", "10.0"), "%f", &gAppConfig.BindEvictionDecrPerSec) - gAppConfig.SkipEvictRegex= cdb.GetOrDefaultString("skip_eviction_host_prefix","") - gAppConfig.EvictRegex= cdb.GetOrDefaultString("eviction_host_prefix", "") - + gAppConfig.SkipEvictRegex = cdb.GetOrDefaultString("skip_eviction_host_prefix", "") + gAppConfig.EvictRegex = cdb.GetOrDefaultString("eviction_host_prefix", "") gAppConfig.BouncerEnabled = cdb.GetOrDefaultBool("bouncer_enabled", true) gAppConfig.BouncerStartupDelay = cdb.GetOrDefaultInt("bouncer_startup_delay", 10) diff --git a/lib/querybindblocker.go b/lib/querybindblocker.go index 001e9749..8ad2ad0f 100644 --- a/lib/querybindblocker.go +++ b/lib/querybindblocker.go @@ -31,14 +31,13 @@ import ( "github.com/paypal/hera/utility/logger" ) - type QueryBindBlockerEntry struct { - Herasqlhash uint32 - Herasqltext string // prefix since some sql is too long - Bindvarname string // prefix for in clause + Herasqlhash uint32 + Herasqltext string // prefix since some sql is too long + Bindvarname string // prefix for in clause Bindvarvalue string // when set to "BLOCKALLVALUES" should block all sqltext queries - Blockperc int - Heramodule string + Blockperc int + Heramodule string } type QueryBindBlockerCfg struct { @@ -48,7 +47,10 @@ type QueryBindBlockerCfg struct { // check by sqltext prefix (delay to end) } -func (cfg * QueryBindBlockerCfg) IsBlocked(sqltext string, bindPairs []string) (bool,string) { +var lastLoggingTime time.Time +var defaultQBBTableMissingErrorLoggingInterval = 2 * time.Hour + +func (cfg *QueryBindBlockerCfg) IsBlocked(sqltext string, bindPairs []string) (bool, string) { sqlhash := uint32(utility.GetSQLHash(sqltext)) if logger.GetLogger().V(logger.Verbose) { logger.GetLogger().Log(logger.Verbose, fmt.Sprintf("query bind blocker sqlhash and text %d %s", sqlhash, sqltext)) @@ -70,7 +72,7 @@ func (cfg * QueryBindBlockerCfg) IsBlocked(sqltext string, bindPairs []string) ( byBindValue, ok := byBindName[bindPairs[i]] if !ok { // strip numeric suffix to try to match - withoutNumSuffix := regexp.MustCompile("[_0-9]*$").ReplaceAllString(bindPairs[i],"") + withoutNumSuffix := regexp.MustCompile("[_0-9]*$").ReplaceAllString(bindPairs[i], "") byBindValue, ok = byBindName[withoutNumSuffix] if !ok { continue @@ -118,28 +120,27 @@ func (cfg * QueryBindBlockerCfg) IsBlocked(sqltext string, bindPairs []string) ( var g_module string var gQueryBindBlockerCfg atomic.Value -func GetQueryBindBlockerCfg() (*QueryBindBlockerCfg) { - cfg := gQueryBindBlockerCfg.Load() - if cfg == nil { - return nil - } - return cfg.(*QueryBindBlockerCfg) +func GetQueryBindBlockerCfg() *QueryBindBlockerCfg { + cfg := gQueryBindBlockerCfg.Load() + if cfg == nil { + return nil + } + return cfg.(*QueryBindBlockerCfg) } - func InitQueryBindBlocker(modName string) { g_module = modName - db, err := sql.Open("heraloop", fmt.Sprintf("0:0:0")) - if err != nil { + db, err := sql.Open("heraloop", fmt.Sprintf("0:0:0")) + if err != nil { logger.GetLogger().Log(logger.Alert, "Loading query bind blocker - conn err ", err) - return - } - db.SetMaxIdleConns(0) - + return + } + db.SetMaxIdleConns(0) go func() { - time.Sleep(4*time.Second) + time.Sleep(4 * time.Second) logger.GetLogger().Log(logger.Info, "Loading query bind blocker - initial") + loadBlockQueryBind(db) c := time.Tick(11 * time.Second) for now := range c { @@ -152,11 +153,12 @@ func InitQueryBindBlocker(modName string) { func loadBlockQueryBind(db *sql.DB) { ctx, cancel := context.WithTimeout(context.Background(), 5000*time.Millisecond) defer cancel() - conn, err := db.Conn(ctx); + conn, err := db.Conn(ctx) if err != nil { logger.GetLogger().Log(logger.Alert, "Error (conn) loading query bind blocker:", err) return } + defer conn.Close() q := fmt.Sprintf("SELECT /*queryBindBlocker*/ %ssqlhash, %ssqltext, bindvarname, bindvarvalue, blockperc, %smodule FROM %s_rate_limiter where %smodule='%s'", GetConfig().StateLogPrefix, GetConfig().StateLogPrefix, GetConfig().StateLogPrefix, GetConfig().ManagementTablePrefix, GetConfig().StateLogPrefix, g_module) logger.GetLogger().Log(logger.Info, "Loading query bind blocker meta-sql "+q) @@ -167,12 +169,18 @@ func loadBlockQueryBind(db *sql.DB) { } rows, err := stmt.QueryContext(ctx) if err != nil { - logger.GetLogger().Log(logger.Alert, "Error (query) loading query bind blocker:", err) - return + if lastLoggingTime.IsZero() || time.Since(lastLoggingTime) > defaultQBBTableMissingErrorLoggingInterval { + //In case table missing log alert event for every 2 hour + logger.GetLogger().Log(logger.Alert, "Error (query) loading query bind blocker:", err) + lastLoggingTime = time.Now() + return + } else { + return + } } defer rows.Close() - cfgLoad := QueryBindBlockerCfg{BySqlHash:make(map[uint32]map[string]map[string][]QueryBindBlockerEntry)} + cfgLoad := QueryBindBlockerCfg{BySqlHash: make(map[uint32]map[string]map[string][]QueryBindBlockerEntry)} rowCount := 0 for rows.Next() { @@ -182,9 +190,9 @@ func loadBlockQueryBind(db *sql.DB) { logger.GetLogger().Log(logger.Alert, "Error (row scan) loading query bind blocker:", err) continue } - + if len(entry.Herasqltext) < GetConfig().QueryBindBlockerMinSqlPrefix { - logger.GetLogger().Log(logger.Alert, "Error (row scan) loading query bind blocker - sqltext must be ", GetConfig().QueryBindBlockerMinSqlPrefix," bytes or more - sqlhash:", entry.Herasqlhash) + logger.GetLogger().Log(logger.Alert, "Error (row scan) loading query bind blocker - sqltext must be ", GetConfig().QueryBindBlockerMinSqlPrefix, " bytes or more - sqlhash:", entry.Herasqlhash) continue } rowCount++ @@ -200,7 +208,7 @@ func loadBlockQueryBind(db *sql.DB) { } bindVal, ok := bindName[entry.Bindvarvalue] if !ok { - bindVal = make([]QueryBindBlockerEntry,0) + bindVal = make([]QueryBindBlockerEntry, 0) bindName[entry.Bindvarvalue] = bindVal } bindName[entry.Bindvarvalue] = append(bindName[entry.Bindvarvalue], entry) diff --git a/tests/unittest/bindThrottle/main_test.go b/tests/unittest/bindThrottle/main_test.go new file mode 100644 index 00000000..7cdede41 --- /dev/null +++ b/tests/unittest/bindThrottle/main_test.go @@ -0,0 +1,242 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "os" + "testing" + "time" + + //"github.com/paypal/hera/client/gosqldriver" + _ "github.com/paypal/hera/client/gosqldriver/tcp" /*to register the driver*/ + + "github.com/paypal/hera/tests/unittest/testutil" + "github.com/paypal/hera/utility/logger" +) + +var mx testutil.Mux +var tableName string +var max_conn float64 + +func cfg() (map[string]string, map[string]string, testutil.WorkerType) { + + appcfg := make(map[string]string) + // best to chose an "unique" port in case golang runs tests in paralel + appcfg["bind_port"] = "31002" + appcfg["log_level"] = "5" + appcfg["log_file"] = "hera.log" + appcfg["sharding_cfg_reload_interval"] = "0" + appcfg["rac_sql_interval"] = "0" + appcfg["child.executable"] = "mysqlworker" + appcfg["bind_eviction_names"] = "p" + appcfg["bind_eviction_threshold_pct"] = "50" + + appcfg["request_backlog_timeout"] = "1000" + appcfg["soft_eviction_probability"] = "100" + + opscfg := make(map[string]string) + max_conn = 25 + opscfg["opscfg.default.server.max_connections"] = fmt.Sprintf("%d", int(max_conn)) + opscfg["opscfg.default.server.log_level"] = "5" + + opscfg["opscfg.default.server.saturation_recover_threshold"] = "10" + //opscfg["opscfg.default.server.saturation_recover_throttle_rate"]= "100" + opscfg["opscfg.hera.server.saturation_recover_throttle_rate"] = "100" + // saturation_recover_throttle_rate + + return appcfg, opscfg, testutil.MySQLWorker +} + +func before() error { + fmt.Printf("before run mysql") + testutil.RunMysql("create table sleep_info (id bigint, seconds float);") + testutil.RunMysql("insert into sleep_info (id,seconds) values(10, 0.01);") + testutil.RunMysql("insert into sleep_info (id,seconds) values(100, 0.1);") + testutil.RunMysql("insert into sleep_info (id,seconds) values(1600, 2.6);") + testutil.RunMysql("insert into sleep_info (id,seconds) values(21001111, 0.1);") + testutil.RunMysql("insert into sleep_info (id,seconds) values(22001111, 0.1);") + testutil.RunMysql("insert into sleep_info (id,seconds) values(29001111, 3.9);") + out, err := testutil.RunMysql(`DELIMITER $$ +CREATE FUNCTION sleep_option (id bigint) +RETURNS float +DETERMINISTIC +BEGIN + declare dur float; + declare rv bigint; + select max(seconds) into dur from sleep_info where sleep_info.id=id; + select sleep(dur) into rv; + RETURN dur; +END$$ +DELIMITER ;`) + if err != nil { + fmt.Printf("err after run mysql " + err.Error()) + return nil + } + fmt.Printf("after run mysql " + out) // */ + return nil +} + +func TestMain(m *testing.M) { + logger.GetLogger().Log(logger.Debug, "begin 20230918kkang TestMain") + fmt.Printf("TestMain 20230918kkang\n") + os.Exit(testutil.UtilMain(m, cfg, before)) +} + +func sleepyQ(conn *sql.Conn, delayRow int) error { + stmt, err := conn.PrepareContext(context.Background(), fmt.Sprintf("select * from sleep_info where ( seconds > sleep_option(?) or seconds > 0.0 ) and id=%d", delayRow)) + if err != nil { + fmt.Printf("Error preparing sleepyQ %s\n", err.Error()) + return err + } + defer stmt.Close() + rows, err := stmt.Query(delayRow) + if err != nil { + fmt.Printf("Error query sleepyQ %s\n", err.Error()) + return err + } + defer rows.Close() + return nil +} + +var normCliErr error + +func NormCliErr() error { + if normCliErr == nil { + normCliErr = fmt.Errorf("normal client got error") + } + return normCliErr +} + +func partialBadLoad(fracBad float64) error { + db, err := sql.Open("hera", "127.0.0.1:31002") + if err != nil { + fmt.Printf("Error db %s\n", err.Error()) + return err + } + db.SetConnMaxLifetime(111 * time.Second) + db.SetMaxIdleConns(0) + db.SetMaxOpenConns(22111) + defer db.Close() + + // client threads of slow queries + var stop2 int + var stop3 int + var badCliErr string + var cliErr string + numBad := int(max_conn * fracBad) + numNorm := int(max_conn*2.1) + 1 - numBad + fmt.Printf("spawning clients bad%d norm%d\n", numBad, numNorm) + mkClients(numBad, &stop2, 29001111, "badClient", &badCliErr, db) + mkClients(numNorm, &stop3, 100, "normClient", &cliErr, db) // bind value is short, so bindevict won't trigger + time.Sleep(3000 * time.Millisecond) + + // start normal clients after initial backlog timeouts + var stop int + var normCliErrStr string + mkClients(1, &stop, 21001111, "n client", &normCliErrStr, db) + time.Sleep(1000 * time.Millisecond) + + // if we throttle down or stop, it restores + stop2 = 1 // stop bad clients + stop3 = 1 + time.Sleep(3 * time.Second) //Make sure that clear throttle + conn, err := db.Conn(context.Background()) + if err != nil { + fmt.Printf("Error conn %s\n", err.Error()) + return err + } + defer conn.Close() + err = sleepyQ(conn, 29001111) + if err != nil { + msg := fmt.Sprintf("test failed, throttle down didn't restore") + fmt.Printf("%s", msg) + return fmt.Errorf("%s", msg) + } + + stop = 1 + // tolerate soft eviction on normal client when we did not use bind eviction + if len(normCliErrStr) != 0 { + return NormCliErr() + } // */ + return nil +} + +func mkClients(num int, stop *int, bindV int, grpName string, outErr *string, db *sql.DB) { + for i := 0; i < num; i++ { + go func(clientId int) { + count := 0 + var conn *sql.Conn + var err error + var curErr string + for *stop == 0 { + nowStr := time.Now().Format("15:04:05.000000 ") + if conn == nil { + conn, err = db.Conn(context.Background()) + fmt.Printf("%s connected %d\n", grpName, clientId) + if err != nil { + fmt.Printf("%s %s Error %d conn %s\n", nowStr, grpName, clientId, err.Error()) + time.Sleep(7 * time.Millisecond) + continue + } + } + + fmt.Printf("%s %s %d loop%d %s\n", nowStr, grpName, clientId, count, time.Now().Format("20060102j150405.000000")) + err := sleepyQ(conn, bindV) + if err != nil { + if err.Error() == curErr { + fmt.Printf("%s %s %d same err twice\n", nowStr, grpName, clientId) + conn.Close() + conn = nil + } else { + curErr = err.Error() + *outErr = curErr + fmt.Printf("%s %s %d err %s\n", nowStr, grpName, clientId, curErr) + } + } + count++ + time.Sleep(10 * time.Millisecond) + } + fmt.Printf("%s %s %d END loop%d\n", time.Now().Format("15:04:05.000000 "), grpName, clientId, count) + }(i) + } +} + +func TestBindThrottle(t *testing.T) { + // we would like to clear hera.log, but even if we try, lots of messages still go there + logger.GetLogger().Log(logger.Debug, "BindThrottle +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") + err := partialBadLoad(0.10) + if err != nil && err != NormCliErr() { + t.Fatalf("main step function returned err %s", err.Error()) + } + if testutil.RegexCountFile("BIND_THROTTLE", "cal.log") > 0 { + t.Fatalf("BIND_THROTTLE should not trigger") + } + if testutil.RegexCountFile("BIND_EVICT", "cal.log") > 0 { + t.Fatalf("BIND_EVICT should not trigger") + } + if testutil.RegexCountFile("HERA-10", "hera.log") == 0 { + t.Fatal("backlog timeout or saturation was not triggered") + } // */ + + logger.GetLogger().Log(logger.Debug, "BindThrottle midpt +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") + err = partialBadLoad(0.8) + if err != nil { + // t.Fatalf("main step function returned err %s", err.Error()) // can be triggered since test only has one sql + } + if testutil.RegexCountFile("BIND_THROTTLE", "cal.log") < 0 { + t.Fatalf("BIND_THROTTLE should trigger") + } + if testutil.RegexCountFile("BIND_EVICT", "cal.log") == 0 { + t.Fatalf("BIND_EVICT should trigger") + } + + if testutil.RegexCountFile(".*BIND_EVICT\t1354401077\t1.*", "cal.log") < 1 { + t.Fatalf("BIND_EVICT should trigger for SQL HASH 1354401077") + } + + if testutil.RegexCountFile(".*BIND_THROTTLE\t1354401077\t1.*", "cal.log") < 1 { + t.Fatalf("BIND_THROTTLE should trigger for SQL HASH 1354401077") + } + logger.GetLogger().Log(logger.Debug, "BindThrottle done +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") +} // */ diff --git a/tests/unittest/querybindblocker/main_test.go b/tests/unittest/querybindblocker/main_test.go index a4defd1b..595a7eab 100644 --- a/tests/unittest/querybindblocker/main_test.go +++ b/tests/unittest/querybindblocker/main_test.go @@ -15,7 +15,7 @@ import ( var mx testutil.Mux func cfg() (map[string]string, map[string]string, testutil.WorkerType) { - fmt.Println ("setup() begin") + fmt.Println("setup() begin") appcfg := make(map[string]string) // best to chose an "unique" port in case golang runs tests in paralel appcfg["bind_port"] = "31002" @@ -29,7 +29,7 @@ func cfg() (map[string]string, map[string]string, testutil.WorkerType) { opscfg["opscfg.default.server.log_level"] = "5" if os.Getenv("WORKER") == "postgres" { return appcfg, opscfg, testutil.PostgresWorker - } + } return appcfg, opscfg, testutil.MySQLWorker } @@ -59,31 +59,45 @@ func TestQueryBindBlocker(t *testing.T) { ctx := context.Background() // cleanup and insert one row in the table - conn, err := db.Conn(ctx); + conn, err := db.Conn(ctx) if err != nil { t.Fatalf("Error getting connection %s\n", err.Error()) } if true { - tx0,err := conn.BeginTx(ctx, nil) - if err != nil { t.Fatalf("tx0 %s", err.Error()) } - stmtD,err := tx0.PrepareContext(ctx, "delete from hera_rate_limiter") - if err != nil { t.Fatalf("stmtD %s", err.Error()) } + tx0, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("tx0 %s", err.Error()) + } + stmtD, err := tx0.PrepareContext(ctx, "delete from hera_rate_limiter") + if err != nil { + t.Fatalf("stmtD %s", err.Error()) + } _, err = stmtD.Exec() - if err != nil { t.Fatalf("stmtD exec %s", err.Error()) } + if err != nil { + t.Fatalf("stmtD exec %s", err.Error()) + } err = tx0.Commit() - if err != nil { t.Fatalf("commit0 %s", err.Error()) } + if err != nil { + t.Fatalf("commit0 %s", err.Error()) + } - tx,err := conn.BeginTx(ctx, nil) - if err != nil { t.Fatalf("tx %s", err.Error()) } - stmt,err := tx.PrepareContext(ctx, "/*setup qbb t*/delete from qbb_test") - if err != nil { t.Fatalf("prep %s", err.Error()) } + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("tx %s", err.Error()) + } + stmt, err := tx.PrepareContext(ctx, "/*setup qbb t*/delete from qbb_test") + if err != nil { + t.Fatalf("prep %s", err.Error()) + } _, err = stmt.Exec() if err != nil { t.Fatalf("Error preparing test (delete table) %s\n", err.Error()) } - stmt,err = tx.PrepareContext(ctx, "/*setup qbb t*/insert into qbb_test(id, note) VALUES(?, ?)") - if err != nil { t.Fatalf("prep ins %s", err.Error()) } + stmt, err = tx.PrepareContext(ctx, "/*setup qbb t*/insert into qbb_test(id, note) VALUES(?, ?)") + if err != nil { + t.Fatalf("prep ins %s", err.Error()) + } _, err = stmt.Exec(11, "eleven") if err != nil { t.Fatalf("Error preparing test (create row in table) %s\n", err.Error()) @@ -95,13 +109,15 @@ func TestQueryBindBlocker(t *testing.T) { } if true { - tx,err := conn.BeginTx(ctx, nil) - if err != nil { t.Fatalf("tx findQ %s", err.Error()) } - stmt,err := tx.PrepareContext(ctx, "/*qbb_test.find*/select id, note from qbb_test where id=? for update") + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("tx findQ %s", err.Error()) + } + stmt, err := tx.PrepareContext(ctx, "/*qbb_test.find*/select id, note from qbb_test where id=? for update") if err != nil { t.Fatalf("Error prep sel %s\n", err.Error()) } - rows,err := stmt.Query(11) + rows, err := stmt.Query(11) if err != nil { t.Fatalf("Error query sel %s\n", err.Error()) } @@ -109,20 +125,30 @@ func TestQueryBindBlocker(t *testing.T) { t.Fatalf("Expected 1 row") } err = tx.Rollback() - if err != nil { t.Fatalf("rollback error %s", err.Error()) } + if err != nil { + t.Fatalf("rollback error %s", err.Error()) + } } // above baseline checks fmt.Printf("DONE DONE baseline check\n") if true { - tx0,err := conn.BeginTx(ctx, nil) - if err != nil { t.Fatalf("tx0 %s", err.Error()) } - stmtD,err := tx0.PrepareContext(ctx, "delete from hera_rate_limiter") - if err != nil { t.Fatalf("prep stmtD %s", err.Error()) } - _,err = stmtD.Exec() - if err != nil { t.Fatalf("stmtD %s", err.Error()) } - stmt,err := tx0.PrepareContext(ctx, "insert into hera_rate_limiter (herasqlhash, herasqltext, bindvarname, bindvarvalue, blockperc, heramodule, end_time, remarks) values ( ?, ?, ?, ?, ?, ?, ?, ?)") - if err != nil { t.Fatalf("ins prep %s", err.Error()) } + tx0, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("tx0 %s", err.Error()) + } + stmtD, err := tx0.PrepareContext(ctx, "delete from hera_rate_limiter") + if err != nil { + t.Fatalf("prep stmtD %s", err.Error()) + } + _, err = stmtD.Exec() + if err != nil { + t.Fatalf("stmtD %s", err.Error()) + } + stmt, err := tx0.PrepareContext(ctx, "insert into hera_rate_limiter (herasqlhash, herasqltext, bindvarname, bindvarvalue, blockperc, heramodule, end_time, remarks) values ( ?, ?, ?, ?, ?, ?, ?, ?)") + if err != nil { + t.Fatalf("ins prep %s", err.Error()) + } _, err = stmt.Exec(51938198, "/*qbb_test.find*/selec", "p1", @@ -131,74 +157,106 @@ func TestQueryBindBlocker(t *testing.T) { "hera-test", 2000111222, "block100") - if err != nil { t.Fatalf("ins exec %s", err.Error()) } + if err != nil { + t.Fatalf("ins exec %s", err.Error()) + } err = tx0.Commit() - if err != nil { t.Fatalf("commit tx0 %s", err.Error()) } + if err != nil { + t.Fatalf("commit tx0 %s", err.Error()) + } fmt.Printf("wait wait: loading basic block\n") time.Sleep(12 * time.Second) - tx,err := conn.BeginTx(ctx, nil) - if err != nil { t.Fatalf("tx %s", err.Error()) } - stmt,err = tx.PrepareContext(ctx, "/*qbb_test.find*/select id, note from qbb_test where id=? for update") + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("tx %s", err.Error()) + } + stmt, err = tx.PrepareContext(ctx, "/*qbb_test.find*/select id, note from qbb_test where id=? for update") if err != nil { t.Fatalf("Error prep sel %s\n", err.Error()) } - _,err = stmt.Query(11) + _, err = stmt.Query(11) if err == nil { t.Fatalf("Error query should have been blocked") } tx.Rollback() // can have error because connection could be closed - conn, err = db.Conn(ctx); - if err != nil { t.Fatalf("conn %s", err.Error()) } + conn, err = db.Conn(ctx) + if err != nil { + t.Fatalf("conn %s", err.Error()) + } } if true { - tx0,err := conn.BeginTx(ctx, nil) - if err != nil { t.Fatalf("tx0 %s", err.Error()) } - stmtD,err := tx0.PrepareContext(ctx, "delete from hera_rate_limiter") - if err != nil { t.Fatalf("prep err %s", err.Error()) } + tx0, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("tx0 %s", err.Error()) + } + stmtD, err := tx0.PrepareContext(ctx, "delete from hera_rate_limiter") + if err != nil { + t.Fatalf("prep err %s", err.Error()) + } _, err = stmtD.Exec() - if err != nil { t.Fatalf("stmtD %s", err.Error()) } - stmt,err := tx0.PrepareContext(ctx, "insert into hera_rate_limiter (herasqlhash, herasqltext, bindvarname, bindvarvalue, blockperc, heramodule, end_time, remarks) values ( ?, ?, ?, ?, ?, ?, 2000111222, ?)") - if err != nil { t.Fatalf("prep ins %s", err.Error()) } + if err != nil { + t.Fatalf("stmtD %s", err.Error()) + } + stmt, err := tx0.PrepareContext(ctx, "insert into hera_rate_limiter (herasqlhash, herasqltext, bindvarname, bindvarvalue, blockperc, heramodule, end_time, remarks) values ( ?, ?, ?, ?, ?, ?, 2000111222, ?)") + if err != nil { + t.Fatalf("prep ins %s", err.Error()) + } _, err = stmt.Exec(51938197, "/*qbb_test.find*/select id, note from qbb_test where id=:p1 for upd", "p1", "11", 100, "hera-test", "WrongHash") - if err != nil { t.Fatalf("exec1 %s", err.Error()) } + if err != nil { + t.Fatalf("exec1 %s", err.Error()) + } _, err = stmt.Exec(51938198, "/*bb_test.find*/select id, note from qbb_test where id=:p1 for upd", "p1", "11", 100, "hera-test", "WrongSqlText") - if err != nil { t.Fatalf("exec2 %s", err.Error()) } + if err != nil { + t.Fatalf("exec2 %s", err.Error()) + } _, err = stmt.Exec(51938198, "/*bb_test.find*/select id, note from qbb_test where id=:p1 for upd", "notId", "11", 100, "hera-test", "WrongBindName") - if err != nil { t.Fatalf("exec3 %s", err.Error()) } + if err != nil { + t.Fatalf("exec3 %s", err.Error()) + } _, err = stmt.Exec(51938198, "/*bb_test.find*/select id, note from qbb_test where id=:p1 for upd", "p1", "333", 100, "hera-test", "WrongBindVal") - if err != nil { t.Fatalf("exec4 %s", err.Error()) } + if err != nil { + t.Fatalf("exec4 %s", err.Error()) + } _, err = stmt.Exec(51938198, "/*bb_test.find*/select id, note from qbb_test where id=:p1 for upd", "p1", "11", 100, "nothera-test", "WrongBindModule") - if err != nil { t.Fatalf("exec5 %s", err.Error()) } + if err != nil { + t.Fatalf("exec5 %s", err.Error()) + } err = tx0.Commit() - if err != nil { t.Fatalf("tx0 commit %s", err.Error()) } + if err != nil { + t.Fatalf("tx0 commit %s", err.Error()) + } fmt.Printf("wait wait: loading close to block, but ultimately not\n") time.Sleep(12 * time.Second) - tx,err := conn.BeginTx(ctx, nil) - if err != nil { t.Fatalf("begin tx %s", err.Error()) } - stmt,err = tx.PrepareContext(ctx, "/*qbb_test.find*/select id, note from qbb_test where id=? for update") + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("begin tx %s", err.Error()) + } + stmt, err = tx.PrepareContext(ctx, "/*qbb_test.find*/select id, note from qbb_test where id=? for update") if err != nil { t.Fatalf("Error prep sel %s\n", err.Error()) } - _,err = stmt.Query(11) + _, err = stmt.Query(11) if err != nil { - t.Fatalf("Error query might have been erroneously blocked %s",err.Error()) + t.Fatalf("Error query might have been erroneously blocked %s", err.Error()) } err = tx.Rollback() - if err != nil { t.Fatalf("rollback %s", err.Error()) } + if err != nil { + t.Fatalf("rollback %s", err.Error()) + } } if true { - tx0,_ := conn.BeginTx(ctx, nil) - stmtD,_ := tx0.PrepareContext(ctx, "delete from hera_rate_limiter") + tx0, _ := conn.BeginTx(ctx, nil) + stmtD, _ := tx0.PrepareContext(ctx, "delete from hera_rate_limiter") stmtD.Exec() - stmt,_ := tx0.PrepareContext(ctx, "insert into hera_rate_limiter (herasqlhash, herasqltext, bindvarname, bindvarvalue, blockperc, heramodule, end_time, remarks) values ( ?, ?, ?, ?, ?, ?, ?, ?)") + stmt, _ := tx0.PrepareContext(ctx, "insert into hera_rate_limiter (herasqlhash, herasqltext, bindvarname, bindvarvalue, blockperc, heramodule, end_time, remarks) values ( ?, ?, ?, ?, ?, ?, ?, ?)") stmt.Exec(51938198, "/*qbb_test.find*/selec", "p1", @@ -213,16 +271,18 @@ func TestQueryBindBlocker(t *testing.T) { time.Sleep(12 * time.Second) countBlock := 0 - for i:=0; i<100; i++ { - conn, err = db.Conn(ctx); - if err != nil { t.Fatalf("conn %s", err.Error()) } + for i := 0; i < 100; i++ { + conn, err = db.Conn(ctx) + if err != nil { + t.Fatalf("conn %s", err.Error()) + } - tx,_ := conn.BeginTx(ctx, nil) - stmt,err := tx.PrepareContext(ctx, "/*qbb_test.find*/select id, note from qbb_test where id=? for update") + tx, _ := conn.BeginTx(ctx, nil) + stmt, err := tx.PrepareContext(ctx, "/*qbb_test.find*/select id, note from qbb_test where id=? for update") if err != nil { t.Fatalf("Error prep sel %s\n", err.Error()) } - _,err = stmt.Query(11) + _, err = stmt.Query(11) if err != nil { countBlock++ } diff --git a/tests/unittest/querybindblocker_ratelimit_table_empty/main_test.go b/tests/unittest/querybindblocker_ratelimit_table_empty/main_test.go new file mode 100644 index 00000000..c07118f2 --- /dev/null +++ b/tests/unittest/querybindblocker_ratelimit_table_empty/main_test.go @@ -0,0 +1,68 @@ +package main + +import ( + "database/sql" + "fmt" + "os" + "testing" + "time" + + "github.com/paypal/hera/tests/unittest/testutil" + "github.com/paypal/hera/utility/logger" +) + +var mx testutil.Mux + +func cfg() (map[string]string, map[string]string, testutil.WorkerType) { + fmt.Println("setup() begin") + appcfg := make(map[string]string) + // best to chose an "unique" port in case golang runs tests in paralel + appcfg["bind_port"] = "31002" + appcfg["log_level"] = "5" + appcfg["log_file"] = "hera.log" + appcfg["rac_sql_interval"] = "0" + appcfg["enable_query_bind_blocker"] = "true" + + opscfg := make(map[string]string) + opscfg["opscfg.default.server.max_connections"] = "3" + opscfg["opscfg.default.server.log_level"] = "5" + if os.Getenv("WORKER") == "postgres" { + return appcfg, opscfg, testutil.PostgresWorker + } + return appcfg, opscfg, testutil.MySQLWorker +} + +func teardown() { + mx.StopServer() +} + +func TestMain(m *testing.M) { + os.Exit(testutil.UtilMain(m, cfg, nil)) +} + +func TestQueryBindBlockerTableNotExistOrEmpty(t *testing.T) { + testutil.RunDML("DROP TABLE IF EXISTS hera_rate_limiter") + + logger.GetLogger().Log(logger.Debug, "TestQueryBindBlockerTableNotExistOrEmpty begin +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") + + shard := 0 + db, err := sql.Open("heraloop", fmt.Sprintf("%d:0:0", shard)) + if err != nil { + t.Fatal("Error starting Mux:", err) + return + } + db.SetMaxIdleConns(0) + defer db.Close() + + time.Sleep(6 * time.Second) + if testutil.RegexCountFile("loading query bind blocker: SQL error: Error 1146", "hera.log") == 0 { + t.Fatalf("expected to see table 'hera_rate_limiter' doesn't exist error") + } + + testutil.RunDML("create table hera_rate_limiter (herasqlhash numeric not null, herasqltext varchar(4000) not null, bindvarname varchar(200) not null, bindvarvalue varchar(200) not null, blockperc numeric not null, heramodule varchar(100) not null, end_time numeric not null, remarks varchar(200) not null)") + time.Sleep(15 * time.Second) + if testutil.RegexCountFile("Loaded 0 sqlhashes, 0 entries, query bind blocker entries", "hera.log") == 0 { + t.Fatalf("expected to 0 entries from hera_rate_limiter table") + } + logger.GetLogger().Log(logger.Debug, "TestQueryBindBlockerTableNotExistOrEmpty ends +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n") +}