Commit

Update method name and description
Signed-off-by: Rupal Mahajan <maharup@amazon.com>
rupal-bq committed Jun 14, 2023
1 parent f815709 commit ec08d0b
Showing 2 changed files with 26 additions and 4 deletions.
@@ -8,6 +8,18 @@ package org.opensearch.sql
import org.apache.spark.sql.{DataFrame, SparkSession, Row}
import org.apache.spark.sql.types._

+ /**
+  * Spark SQL Application entrypoint
+  *
+  * @param args(0)
+  *   SQL query
+  * @param args(1)
+  *   OpenSearch index name
+  * @param args(2-6)
+  *   OpenSearch connection values required by the flint-integration jar: host, port, scheme, auth, and region, respectively.
+  * @return
+  *   writes the SQL query result to the given OpenSearch index
+  */
object SQLJob {
def main(args: Array[String]) {
// Get the SQL query and Opensearch Config from the command line arguments
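(Not part of the diff: for readers of the new scaladoc above, a minimal sketch of how the documented positional arguments might be unpacked inside main. The variable names are assumptions for illustration only; most of main is collapsed in this diff.)

    // Illustrative only: assumed variable names for args(0) through args(6)
    // as described in the scaladoc; the committed code may name these differently.
    val query  = args(0) // SQL query passed to spark.sql(query)
    val index  = args(1) // OpenSearch index to write the result to
    val host   = args(2) // OpenSearch host
    val port   = args(3) // OpenSearch port
    val scheme = args(4) // http or https
    val auth   = args(5) // authentication setting
    val region = args(6) // region used by the flint integration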
@@ -27,7 +39,7 @@ object SQLJob {
val result: DataFrame = spark.sql(query)

// Get Data
- val data = getData(result, spark)
+ val data = getFormattedData(result, spark)

// Write data to OpenSearch index
val aos = Map(
@@ -49,7 +61,17 @@
}
}

- def getData(result: DataFrame, spark: SparkSession): DataFrame = {
+ /**
+  * Create a new formatted dataframe with the JSON result, the JSON schema, and the EMR_STEP_ID.
+  *
+  * @param result
+  *   SQL query result dataframe
+  * @param spark
+  *   Spark session
+  * @return
+  *   dataframe with result, schema, and EMR step ID
+  */
+ def getFormattedData(result: DataFrame, spark: SparkSession): DataFrame = {
// Create the schema dataframe
val schemaRows = result.schema.fields.map { field =>
Row(field.name, field.dataType.typeName)
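(Not part of the diff: the body of getFormattedData is collapsed beyond the schemaRows mapping shown above. The following is a rough sketch only, under stated assumptions; the use of toJSON, the EMR_STEP_ID environment lookup, and the "schema" and "stepId" column names are guesses, not taken from this commit. Only the "result" column is confirmed by the test's expected schema below.)

    // Rough sketch, not the committed implementation: builds one output row holding
    // the query result as JSON strings, the schema as JSON strings, and an EMR step id.
    import org.apache.spark.sql.{DataFrame, Row, SparkSession}
    import org.apache.spark.sql.types._

    def getFormattedDataSketch(result: DataFrame, spark: SparkSession): DataFrame = {
      // Same mapping as shown in the diff: one Row per column with its name and type name
      val schemaRows = result.schema.fields.map { field =>
        Row(field.name, field.dataType.typeName)
      }

      // Render result rows and schema rows as JSON strings (assumed representation)
      val resultJson: Seq[String] = result.toJSON.collect().toSeq
      val schemaJson: Seq[String] = schemaRows.toSeq.map { row =>
        s"""{"column_name":"${row.getString(0)}","data_type":"${row.getString(1)}"}"""
      }

      // Assumed source of the step id; the commit does not show where it comes from
      val stepId = sys.env.getOrElse("EMR_STEP_ID", "unknown")

      // "result" matches the test's expected schema; "schema" and "stepId" are assumed names
      val outputSchema = StructType(Seq(
        StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
        StructField("schema", ArrayType(StringType, containsNull = true), nullable = true),
        StructField("stepId", StringType, nullable = true)))

      spark.createDataFrame(
        spark.sparkContext.parallelize(Seq(Row(resultJson, schemaJson, stepId))),
        outputSchema)
    }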
@@ -26,7 +26,7 @@ class SQLJobTest extends AnyFunSuite{
)
val input: DataFrame = spark.createDataFrame(spark.sparkContext.parallelize(inputRows), inputSchema)

test("Test getData method") {
test("Test getFormattedData method") {
// Define expected dataframe
val expectedSchema = StructType(Seq(
StructField("result", ArrayType(StringType, containsNull = true), nullable = true),
@@ -43,7 +43,7 @@
val expected: DataFrame = spark.createDataFrame(spark.sparkContext.parallelize(expectedRows), expectedSchema)

// Compare the result
- val result = SQLJob.getData(input, spark)
+ val result = SQLJob.getFormattedData(input, spark)
assertEqualDataframe(expected, result)
}

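(Not part of the diff: the assertEqualDataframe helper called above is collapsed in this view. A minimal sketch of one common way to write such a helper inside SQLJobTest, assuming it compares schema and collected rows; the committed implementation may differ.)

    // Sketch only: the real assertEqualDataframe in SQLJobTest is not shown in this commit.
    def assertEqualDataframe(expected: DataFrame, result: DataFrame): Unit = {
      assert(expected.schema == result.schema)
      assert(expected.collect().sameElements(result.collect()))
    }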
