Skip to content

Commit

Permalink
[spark] Use batch predict API (#3242)
Browse files (browse the repository at this point in the history)
  • Loading branch information
xyang16 authored Jun 6, 2024
1 parent 57d44f2 commit d75ff83
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,12 +19,14 @@ import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}

+ import scala.jdk.CollectionConverters.{collectionAsScalaIterableConverter, seqAsJavaListConverter}

/**
* QuestionAnswerer performs question answering on text.
*
* @param uid An immutable unique ID for the object and its derivatives.
*/
- class QuestionAnswerer(override val uid: String) extends BaseTextPredictor[Array[QAInput], Array[String]]
+ class QuestionAnswerer(override val uid: String) extends BaseTextPredictor[QAInput, String]
with HasInputCols with HasOutputCol {

def this() = this(Identifiable.randomUID("QuestionAnswerer"))
Expand All @@ -46,8 +48,8 @@ class QuestionAnswerer(override val uid: String) extends BaseTextPredictor[Array
*/
def setOutputCol(value: String): this.type = set(outputCol, value)

- setDefault(inputClass, classOf[Array[QAInput]])
- setDefault(outputClass, classOf[Array[String]])
+ setDefault(inputClass, classOf[QAInput])
+ setDefault(outputClass, classOf[String])
setDefault(translatorFactory, new QuestionAnsweringTranslatorFactory())

/**
Expand All @@ -72,8 +74,8 @@ class QuestionAnswerer(override val uid: String) extends BaseTextPredictor[Array
val predictor = model.newPredictor()
iter.grouped($(batchSize)).flatMap { batch =>
val inputs = batch.map(row => new QAInput(row.getString(inputColIndices(0)),
- row.getString(inputColIndices(1)))).toArray
- val output = predictor.predict(inputs)
+ row.getString(inputColIndices(1)))).asJava
+ val output = predictor.batchPredict(inputs).asScala
batch.zip(output).map { case (row, out) =>
Row.fromSeq(row.toSeq :+ out)
}
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,14 +20,14 @@ import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}

- import scala.jdk.CollectionConverters.collectionAsScalaIterableConverter
+ import scala.jdk.CollectionConverters.{collectionAsScalaIterableConverter, seqAsJavaListConverter}

/**
* TextClassifier performs text classification on text.
*
* @param uid An immutable unique ID for the object and its derivatives.
*/
- class TextClassifier(override val uid: String) extends BaseTextPredictor[Array[String], Array[Classifications]]
+ class TextClassifier(override val uid: String) extends BaseTextPredictor[String, Classifications]
with HasInputCol with HasOutputCol {

def this() = this(Identifiable.randomUID("TextClassifier"))
Expand Down Expand Up @@ -57,8 +57,8 @@ class TextClassifier(override val uid: String) extends BaseTextPredictor[Array[S
*/
def setTopK(value: Int): this.type = set(topK, value)

- setDefault(inputClass, classOf[Array[String]])
- setDefault(outputClass, classOf[Array[Classifications]])
+ setDefault(inputClass, classOf[String])
+ setDefault(outputClass, classOf[Classifications])
setDefault(translatorFactory, new TextClassificationTranslatorFactory())

/**
Expand All @@ -84,8 +84,8 @@ class TextClassifier(override val uid: String) extends BaseTextPredictor[Array[S
override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = {
val predictor = model.newPredictor()
iter.grouped($(batchSize)).flatMap { batch =>
- val inputs = batch.map(_.getString(inputColIndex)).toArray
- val output = predictor.predict(inputs)
+ val inputs = batch.map(_.getString(inputColIndex)).asJava
+ val output = predictor.batchPredict(inputs).asScala
batch.zip(output).map { case (row, out) =>
Row.fromSeq(row.toSeq :+ Row(out.getClassNames.toArray(), out.getProbabilities.toArray(),
out.topK[Classifications.Classification]().asScala.map(_.toString)))
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,12 +18,14 @@ import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types.{ArrayType, FloatType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}

+ import scala.jdk.CollectionConverters.{collectionAsScalaIterableConverter, seqAsJavaListConverter}

/**
* TextEmbedder performs text embedding on text.
*
* @param uid An immutable unique ID for the object and its derivatives.
*/
- class TextEmbedder(override val uid: String) extends BaseTextPredictor[Array[String], Array[Array[Float]]]
+ class TextEmbedder(override val uid: String) extends BaseTextPredictor[String, Array[Float]]
with HasInputCol with HasOutputCol {

def this() = this(Identifiable.randomUID("TextEmbedder"))
Expand All @@ -44,8 +46,8 @@ class TextEmbedder(override val uid: String) extends BaseTextPredictor[Array[Str
*/
def setOutputCol(value: String): this.type = set(outputCol, value)

- setDefault(inputClass, classOf[Array[String]])
- setDefault(outputClass, classOf[Array[Array[Float]]])
+ setDefault(inputClass, classOf[String])
+ setDefault(outputClass, classOf[Array[Float]])
setDefault(translatorFactory, new TextEmbeddingTranslatorFactory())

/**
Expand All @@ -68,8 +70,8 @@ class TextEmbedder(override val uid: String) extends BaseTextPredictor[Array[Str
override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = {
val predictor = model.newPredictor()
iter.grouped($(batchSize)).flatMap { batch =>
- val inputs = batch.map(_.getString(inputColIndex)).toArray
- val output = predictor.predict(inputs)
+ val inputs = batch.map(_.getString(inputColIndex)).asJava
+ val output = predictor.batchPredict(inputs).asScala
batch.zip(output).map { case (row, out) =>
Row.fromSeq(row.toSeq :+ out)
}
Expand Down

0 comments on commit d75ff83

Please sign in to comment.