Refactor FlintJob with FlintStatement and StatementExecutionManager (#635)

* Refactor FlintJob with FlintStatement and StatementExecutionManager

Signed-off-by: Louis Chu <clingzhi@amazon.com>

* Resolve comments

Signed-off-by: Louis Chu <clingzhi@amazon.com>

---------

Signed-off-by: Louis Chu <clingzhi@amazon.com>
noCharger committed Sep 10, 2024
1 parent 9ac50a7 commit 7d8f2b0
Showing 11 changed files with 209 additions and 82 deletions.
@@ -38,5 +38,5 @@ trait StatementExecutionManager {
   /**
    * Terminates the statement lifecycle.
    */
-  def terminateStatementsExecution(): Unit
+  def terminateStatementExecution(): Unit
 }
@@ -211,6 +211,10 @@ object FlintSparkConf {
     FlintConfig("spark.flint.job.query")
       .doc("Flint query for batch and streaming job")
       .createOptional()
+  val QUERY_ID =
+    FlintConfig("spark.flint.job.queryId")
+      .doc("Flint query id for batch and streaming job")
+      .createOptional()
   val JOB_TYPE =
     FlintConfig(s"spark.flint.job.type")
       .doc("Flint job type. Including interactive and streaming")
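Note: the new QUERY_ID entry follows the same pattern as QUERY above. A minimal sketch of how the id travels from submission into the job, assuming the spark-submit flags shown in the comment (the submit command is illustrative and not part of this diff; the read-side call matches what FlintJob and FlintREPL do below):

    // Illustrative submission flags (not part of this diff):
    //   --conf spark.flint.job.query="SELECT 1"
    //   --conf spark.flint.job.queryId=testQueryId

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.flint.config.FlintSparkConf

    val conf = new SparkConf()
    // Fall back to an empty string when no id is supplied, as FlintJob and FlintREPL do.
    val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")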
@@ -37,6 +37,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
   val resultIndex = "query_results2"
   val appId = "00feq82b752mbt0p"
   val dataSourceName = "my_glue1"
+  val queryId = "testQueryId"
   var osClient: OSClient = _
   val threadLocalFuture = new ThreadLocal[Future[Unit]]()
 
@@ -91,7 +92,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
      * all Spark conf required by Flint code underlying manually.
      */
     spark.conf.set(DATA_SOURCE_NAME.key, dataSourceName)
-    spark.conf.set(JOB_TYPE.key, "streaming")
+    spark.conf.set(JOB_TYPE.key, FlintJobType.STREAMING)
 
     /**
      * FlintJob.main() is not called because we need to manually set these variables within a
@@ -103,9 +104,10 @@
       jobRunId,
       spark,
       query,
+      queryId,
       dataSourceName,
       resultIndex,
-      true,
+      FlintJobType.STREAMING,
       streamingRunningCount)
     job.terminateJVM = false
     job.start()
@@ -144,7 +146,6 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
 
       assert(result.status == "SUCCESS", s"expected status is SUCCESS, but got ${result.status}")
       assert(result.error.isEmpty, s"we don't expect error, but got ${result.error}")
-      assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}")
 
       commonAssert(result, jobRunId, query, queryStartTime)
       true
@@ -362,7 +363,9 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
       result.queryRunTime < System.currentTimeMillis() - queryStartTime,
       s"expected query run time ${result.queryRunTime} should be less than ${System
         .currentTimeMillis() - queryStartTime}, but it is not")
-    assert(result.queryId.isEmpty, s"we don't expect query id, but got ${result.queryId}")
+    assert(
+      result.queryId == queryId,
+      s"expected query id is ${queryId}, but got ${result.queryId}")
   }
 
   def pollForResultAndAssert(expected: REPLResult => Boolean, jobId: String): Unit = {
@@ -15,6 +15,7 @@ case class CommandContext(
     jobId: String,
     spark: SparkSession,
     dataSource: String,
+    jobType: String,
     sessionId: String,
     sessionManager: SessionManager,
     queryResultWriter: QueryResultWriter,
@@ -11,11 +11,9 @@ import java.util.concurrent.atomic.AtomicInteger
 import org.opensearch.flint.core.logging.CustomLogging
 import org.opensearch.flint.core.metrics.MetricConstants
 import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge
-import play.api.libs.json._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.sql.types._
 
 /**
  * Spark SQL Application entrypoint
@@ -32,7 +30,7 @@ object FlintJob extends Logging with FlintJobExecutor {
     val (queryOption, resultIndexOption) = parseArgs(args)
 
     val conf = createSparkConf()
-    val jobType = conf.get("spark.flint.job.type", "batch")
+    val jobType = conf.get("spark.flint.job.type", FlintJobType.BATCH)
     CustomLogging.logInfo(s"""Job type is: ${jobType}""")
     conf.set(FlintSparkConf.JOB_TYPE.key, jobType)
 
@@ -41,6 +39,8 @@
     if (query.isEmpty) {
       logAndThrow(s"Query undefined for the ${jobType} job.")
     }
+    val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")
+
     if (resultIndexOption.isEmpty) {
       logAndThrow("resultIndex is not set")
     }
@@ -66,9 +66,10 @@
       jobId,
       createSparkSession(conf),
       query,
+      queryId,
       dataSource,
       resultIndexOption.get,
-      jobType.equalsIgnoreCase("streaming"),
+      jobType,
       streamingRunningCount)
     registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
     jobOperator.start()
@@ -26,13 +26,20 @@ import org.apache.spark.sql.flint.config.FlintSparkConf
 import org.apache.spark.sql.flint.config.FlintSparkConf.REFRESH_POLICY
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util._
+import org.apache.spark.util.Utils
 
 object SparkConfConstants {
   val SQL_EXTENSIONS_KEY = "spark.sql.extensions"
   val DEFAULT_SQL_EXTENSIONS =
     "org.opensearch.flint.spark.FlintPPLSparkExtensions,org.opensearch.flint.spark.FlintSparkExtensions"
 }
 
+object FlintJobType {
+  val INTERACTIVE = "interactive"
+  val BATCH = "batch"
+  val STREAMING = "streaming"
+}
+
 trait FlintJobExecutor {
   this: Logging =>
 
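Note: the FlintJobType constants replace the bare "streaming"/"batch" string literals previously scattered across the entry points. A minimal sketch of the comparison pattern the rest of this commit converges on (the SparkConf here is illustrative; the two calls mirror FlintJob.main above and configDYNMaxExecutors below):

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
    // Default to a batch job when spark.flint.job.type is unset, as FlintJob.main does.
    val jobType = conf.get("spark.flint.job.type", FlintJobType.BATCH)
    if (jobType.equalsIgnoreCase(FlintJobType.STREAMING)) {
      // Streaming-only setup goes here, e.g. capping dynamic-allocation executors.
    }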
@@ -131,7 +138,7 @@
    * https://github.com/opensearch-project/opensearch-spark/issues/324
    */
   def configDYNMaxExecutors(conf: SparkConf, jobType: String): Unit = {
-    if (jobType.equalsIgnoreCase("streaming")) {
+    if (jobType.equalsIgnoreCase(FlintJobType.STREAMING)) {
       conf.set(
         "spark.dynamicAllocation.maxExecutors",
         conf
@@ -524,4 +531,25 @@
     CustomLogging.logError(t)
     throw t
   }
+
+  def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
+    if (className.isEmpty) {
+      defaultConstructor
+    } else {
+      try {
+        val classObject = Utils.classForName(className)
+        val ctor = if (args.isEmpty) {
+          classObject.getDeclaredConstructor()
+        } else {
+          classObject.getDeclaredConstructor(args.map(_.getClass.asInstanceOf[Class[_]]): _*)
+        }
+        ctor.setAccessible(true)
+        ctor.newInstance(args.map(_.asInstanceOf[Object]): _*).asInstanceOf[T]
+      } catch {
+        case e: Exception =>
+          throw new RuntimeException(s"Failed to instantiate provider: $className", e)
+      }
+    }
+  }
+
 }
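Note: instantiate moves here from FlintREPL (its old private copy is deleted further down) so any entry point mixing in FlintJobExecutor can load a configurable provider class reflectively, falling back to the by-name default constructor when no class name is set. A self-contained usage sketch under assumed names (the Plugin types are illustrative, not from this repository, and should be top-level classes so the reflective lookup can resolve them):

    // Illustrative provider types; not part of this repository.
    trait Plugin { def describe(): String }
    class DefaultPlugin extends Plugin { def describe(): String = "default" }
    class NamedPlugin(name: String) extends Plugin { def describe(): String = s"named: $name" }

    // From code that mixes in FlintJobExecutor:
    // an empty class name takes the default-constructor path,
    val fromDefault: Plugin = instantiate(new DefaultPlugin, "")
    // while a non-empty, fully qualified class name (typically read from a Spark conf) is
    // loaded reflectively; the single String argument selects the NamedPlugin(String) constructor.
    val fromConf: Plugin = instantiate(new DefaultPlugin, classOf[NamedPlugin].getName, "flint")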
@@ -26,7 +26,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.FlintREPLConfConstants._
 import org.apache.spark.sql.SessionUpdateMode._
 import org.apache.spark.sql.flint.config.FlintSparkConf
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.ThreadUtils
 
 object FlintREPLConfConstants {
   val HEARTBEAT_INTERVAL_MILLIS = 60000L
@@ -87,8 +87,9 @@ object FlintREPL extends Logging with FlintJobExecutor {
     conf.set(FlintSparkConf.JOB_TYPE.key, jobType)
 
     val query = getQuery(queryOption, jobType, conf)
+    val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")
 
-    if (jobType.equalsIgnoreCase("streaming")) {
+    if (jobType.equalsIgnoreCase(FlintJobType.STREAMING)) {
       if (resultIndexOption.isEmpty) {
         logAndThrow("resultIndex is not set")
       }
@@ -100,9 +101,10 @@
         jobId,
         createSparkSession(conf),
         query,
+        queryId,
         dataSource,
         resultIndexOption.get,
-        true,
+        jobType,
         streamingRunningCount)
       registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
       jobOperator.start()
@@ -174,6 +176,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
       jobId,
       spark,
       dataSource,
+      jobType,
      sessionId,
      sessionManager,
      queryResultWriter,
@@ -220,7 +223,7 @@
 
   def getQuery(queryOption: Option[String], jobType: String, conf: SparkConf): String = {
     queryOption.getOrElse {
-      if (jobType.equalsIgnoreCase("streaming")) {
+      if (jobType.equalsIgnoreCase(FlintJobType.STREAMING)) {
         val defaultQuery = conf.get(FlintSparkConf.QUERY.key, "")
         if (defaultQuery.isEmpty) {
           logAndThrow("Query undefined for the streaming job.")
@@ -352,7 +355,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
         canPickUpNextStatement = updatedCanPickUpNextStatement
         lastCanPickCheckTime = updatedLastCanPickCheckTime
       } finally {
-        statementsExecutionManager.terminateStatementsExecution()
+        statementsExecutionManager.terminateStatementExecution()
       }
 
       Thread.sleep(commandContext.queryLoopExecutionFrequency)
@@ -975,26 +978,6 @@
     }
   }
 
-  private def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
-    if (className.isEmpty) {
-      defaultConstructor
-    } else {
-      try {
-        val classObject = Utils.classForName(className)
-        val ctor = if (args.isEmpty) {
-          classObject.getDeclaredConstructor()
-        } else {
-          classObject.getDeclaredConstructor(args.map(_.getClass.asInstanceOf[Class[_]]): _*)
-        }
-        ctor.setAccessible(true)
-        ctor.newInstance(args.map(_.asInstanceOf[Object]): _*).asInstanceOf[T]
-      } catch {
-        case e: Exception =>
-          throw new RuntimeException(s"Failed to instantiate provider: $className", e)
-      }
-    }
-  }
-
   private def instantiateSessionManager(
       spark: SparkSession,
       resultIndexOption: Option[String]): SessionManager = {
(Diffs for the remaining changed files were not loaded in this view.)