Commit: Add Spark extension example

xyang16 committed Dec 21, 2022
1 parent 9467bf0 commit 3e0fd3f

Showing 8 changed files with 79 additions and 211 deletions.
README.md (2 changes: 1 addition & 1 deletion)

@@ -87,7 +87,7 @@ An example application that runs low cost/high performance inference with AWS In

## Big data integration

-### [Spark Image Classification](apache-spark/image-classification/README.md)
+### [Spark Image Classification](apache-spark/spark3.0/image-classification/README.md)

Contains Spark image classification demos.

apache-spark/notebook/Image_Classification_Spark.ipynb (140 changes: 26 additions & 114 deletions)

@@ -4,9 +4,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Image Classification with Apache Spark\n",
"# Image Classification with DJL Spark Support\n",
"\n",
"In this example, we will use Jupyter Notebook to run image Classification with Apache Spark on Scala. To execute this Scala kernel successfully, you need to install [Almond](https://almond.sh/), a Scala kernel for Jupyter Notebook. Almond provide extensive functionalities for Scala and Spark applications.\n",
"In this example, we will use Jupyter Notebook to run image Classification with DJL Spark extension on Scala. To execute this Scala kernel successfully, you need to install [Almond](https://almond.sh/), a Scala kernel for Jupyter Notebook. Almond provide extensive functionalities for Scala and Spark applications.\n",
"\n",
"[Almond installation instruction](https://almond.sh/docs/quick-start-install) (Note: only Scala 2.12 are tested)\n",
"\n",
@@ -24,13 +24,13 @@
"metadata": {},
"outputs": [],
"source": [
"import $ivy.`org.apache.spark::spark-sql:3.0.1`\n",
"import $ivy.`org.apache.spark::spark-mllib:3.0.1`\n",
"import $ivy.`ai.djl:api:0.10.0`\n",
"import $ivy.`ai.djl.pytorch:pytorch-model-zoo:0.10.0`\n",
"import $ivy.`ai.djl.pytorch:pytorch-native-auto:1.7.1`\n",
"// import $ivy.`ai.djl.mxnet:mxnet-model-zoo:0.10.0`\n",
"// import $ivy.`ai.djl.mxnet:mxnet-native-auto:1.7.0-backport`"
"import $ivy.`org.apache.spark::spark-sql:3.2.2`\n",
"import $ivy.`ai.djl:api:0.20.0`\n",
"import $ivy.`ai.djl.spark:spark:0.20.0`\n",
"import $ivy.`ai.djl.pytorch:pytorch-model-zoo:0.20.0`\n",
"import $ivy.`ai.djl.pytorch:pytorch-native-cpu-precxx11:1.13.0`\n",
"// import $ivy.`ai.djl.mxnet:mxnet-engine:0.20.0`\n",
"// import $ivy.`ai.djl.mxnet:mxnet-native-mkl:1.9.1`"
]
},
{
@@ -46,106 +46,17 @@
"metadata": {},
"outputs": [],
"source": [
"import java.util\n",
"import ai.djl.Model\n",
"import org.apache.spark.sql.NotebookSparkSession\n",
"import ai.djl.modality.Classifications\n",
"import ai.djl.modality.cv.transform.{ Resize, ToTensor}\n",
"import ai.djl.ndarray.types.{DataType, Shape}\n",
"import ai.djl.ndarray.{NDList, NDManager}\n",
"import ai.djl.repository.zoo.{Criteria, ModelZoo, ZooModel}\n",
"import ai.djl.training.util.ProgressBar\n",
"import ai.djl.translate.{Batchifier, Pipeline, Translator, TranslatorContext}\n",
"import ai.djl.util.Utils\n",
"import org.apache.spark.ml.image.ImageSchema\n",
"import org.apache.spark.sql.functions.col\n",
"import org.apache.spark.sql.{Encoders, Row, NotebookSparkSession}\n",
"import ai.djl.spark.SparkTransformer\n",
"import ai.djl.spark.translator.SparkImageClassificationTranslator\n",
"import org.apache.spark.sql.SparkSession\n",
"\n",
"import org.apache.log4j.{Level, Logger}\n",
"Logger.getLogger(\"org\").setLevel(Level.OFF) // avoid too much message popping out\n",
"Logger.getLogger(\"ai\").setLevel(Level.OFF) // avoid too much message popping out"
]
},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Create Translator\n",
-"\n",
-"A Translator in DJL is used to define the preprocessing and postprocessing logic. The following code is to \n",
-"\n",
-"- preprocess: convert a Spark DataFrame Row to DJL NDArray.\n",
-"- postprocess: convert inference result to classifications"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-" // Translator: a class used to do preprocessing and post processing\n",
-" class MyTranslator extends Translator[Row, Classifications] {\n",
-"\n",
-" private var classes: java.util.List[String] = new util.ArrayList[String]()\n",
-" private val pipeline: Pipeline = new Pipeline()\n",
-" .add(new Resize(224, 224))\n",
-" .add(new ToTensor())\n",
-"\n",
-" override def prepare(manager: NDManager, model: Model): Unit = {\n",
-" classes = Utils.readLines(model.getArtifact(\"synset.txt\").openStream())\n",
-" }\n",
-"\n",
-" override def processInput(ctx: TranslatorContext, row: Row): NDList = {\n",
-"\n",
-" val height = ImageSchema.getHeight(row)\n",
-" val width = ImageSchema.getWidth(row)\n",
-" val channel = ImageSchema.getNChannels(row)\n",
-" var image = ctx.getNDManager.create(ImageSchema.getData(row), new Shape(height, width, channel)).toType(DataType.UINT8, true)\n",
-" // BGR to RGB\n",
-" image = image.flip(2)\n",
-" pipeline.transform(new NDList(image))\n",
-" }\n",
-"\n",
-" // Deal with the output. NDList contains the output result, usually one or more NDArray(s).\n",
-" override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {\n",
-" var probabilitiesNd = list.singletonOrThrow\n",
-" probabilitiesNd = probabilitiesNd.softmax(0)\n",
-" new Classifications(classes, probabilitiesNd)\n",
-" }\n",
-"\n",
-" override def getBatchifier: Batchifier = Batchifier.STACK\n",
-" }"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Load the model\n",
-"\n",
-"Now, we just need to fetch the model from a URL. The url can be a hdfs (hdfs://), file (file://) or http (https://) format. We use Criteria as a container to store the model and translator information. Then, all we need to do is to load the model from it.\n",
-"\n",
-"Note: DJL Criteria and Model are not serializable, so we add `lazy` declaration.\n",
-"\n",
-"If you are using MXNet as the backend engine, please uncomment the mxnet model url."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"val modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18\"\n",
-"// val modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/mxnet_resnet18.zip?model_name=resnet18_v1\"\n",
-"lazy val criteria = Criteria.builder\n",
-" .setTypes(classOf[Row], classOf[Classifications])\n",
-" .optModelUrls(modelUrl)\n",
-" .optTranslator(new MyTranslator())\n",
-" .optProgress(new ProgressBar)\n",
-" .build()\n",
-"lazy val model = ModelZoo.loadModel(criteria)"
-]
-},
{
"cell_type": "markdown",
"metadata": {},
@@ -173,7 +84,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try to load the images from the local folder using Spark ML library:"
"Let's try to load the images from the local folder using Spark library:"
]
},
{
@@ -190,7 +101,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Then We can run inference on these images. All we need to do is to create a `Predictor` and run inference with DJL."
"Then We can run inference on these images. All we need to do is to create a `SparkTransformer` and run inference with DJL."
]
},
{
@@ -199,14 +110,15 @@
"metadata": {},
"outputs": [],
"source": [
"val result = df.select(col(\"image.*\")).mapPartitions(partition => {\n",
" val predictor = model.newPredictor()\n",
" partition.map(row => {\n",
" // image data stored as HWC format\n",
" predictor.predict(row).toString\n",
" })\n",
"})(Encoders.STRING)\n",
"println(result.collect().mkString(\"\\n\"))"
"System.setProperty(\"PYTORCH_PRECXX11\", \"true\")\n",
"val transformer = new SparkTransformer[Classifications]()\n",
" .setInputCol(\"image.*\")\n",
" .setOutputCol(\"value\")\n",
" .setModelUrl(\"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18\")\n",
" .setOutputClass(classOf[Classifications])\n",
" .setTranslator(new SparkImageClassificationTranslator())\n",
"val outputDf = transformer.transform(df)\n",
"println(outputDf.collect().mkString(\"\\n\"))"
]
}
],
@@ -222,7 +134,7 @@
"mimetype": "text/x-scala",
"name": "scala",
"nbconvert_exporter": "script",
"version": "2.12.12"
"version": "2.12.11"
}
},
"nbformat": 4,
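For readers assembling the pieces: the added cells above replace the hand-written `MyTranslator`, `Criteria`, and `mapPartitions` plumbing that this commit deletes with a single `SparkTransformer`. Below is a minimal end-to-end sketch of the new flow as a plain Scala program. The `SparkSession` setup and the image-loading call are assumptions (those notebook cells are collapsed in this diff); the transformer configuration is copied from the added cell.

```scala
import ai.djl.modality.Classifications
import ai.djl.spark.SparkTransformer
import ai.djl.spark.translator.SparkImageClassificationTranslator
import org.apache.spark.sql.SparkSession

object ImageClassificationSketch {
  def main(args: Array[String]): Unit = {
    // Assumed setup: the notebook uses NotebookSparkSession instead.
    val spark = SparkSession.builder()
      .appName("Image Classification")
      .master("local[*]")
      .getOrCreate()

    // Assumed: load images from a local folder with Spark's built-in image data source.
    val df = spark.read.format("image").load("images/")

    // Copied from the added cell: select the precxx11 PyTorch native binary.
    System.setProperty("PYTORCH_PRECXX11", "true")

    // Copied from the added cell: configure inference declaratively.
    val transformer = new SparkTransformer[Classifications]()
      .setInputCol("image.*")
      .setOutputCol("value")
      .setModelUrl("https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18")
      .setOutputClass(classOf[Classifications])
      .setTranslator(new SparkImageClassificationTranslator())

    val outputDf = transformer.transform(df)
    outputDf.collect().foreach(println)
    spark.stop()
  }
}
```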
apache-spark/spark3.0/image-classification/README.md (16 changes: 8 additions & 8 deletions)

@@ -16,14 +16,14 @@ We use PyTorch engine by default, you can switch to MXNet or TensorFlow Engine a

#### MXNet
```
libraryDependencies += "ai.djl.mxnet" % "mxnet-model-zoo" % "0.10.0"
libraryDependencies += "ai.djl.mxnet" % "mxnet-native-auto" % "1.7.0-backport"
libraryDependencies += "ai.djl.mxnet" % "mxnet-engine" % "0.20.0"
libraryDependencies += "ai.djl.mxnet" % "mxnet-native-mkl" % "1.9.1"
```

#### TensorFlow
```
libraryDependencies += "ai.djl.tensorflow" % "tensorflow-model-zoo" % "0.10.0"
libraryDependencies += "ai.djl.tensorflow" % "tensorflow-native-auto" % "2.3.0"
libraryDependencies += "ai.djl.tensorflow" % "tensorflow-engine" % "0.20.0"
libraryDependencies += "ai.djl.tensorflow" % "tensorflow-native-cpu" % "2.7.0"
```

### gradle
@@ -32,14 +32,14 @@ You should add these in `dependencies`

#### MXNet
```
runtimeOnly "ai.djl.mxnet:mxnet-model-zoo:0.10.0"
runtimeOnly "ai.djl.mxnet:mxnet-native-auto:1.7.0-backport"
runtimeOnly "ai.djl.mxnet:mxnet-engine:0.20.0"
runtimeOnly "ai.djl.mxnet:mxnet-native-mkl:1.9.1"
```

#### TensorFlow
```
runtimeOnly "ai.djl.tensorflow:tensorflow-model-zoo:0.10.0"
runtimeOnly "ai.djl.tensorflow:tensorflow-native-auto:2.3.0"
runtimeOnly "ai.djl.tensorflow:tensorflow-engine:0.20.0"
runtimeOnly "ai.djl.tensorflow:tensorflow-native-cpu:2.7.0"
```

Apart from that, you may also need to use `NaiveEngine` mode in MXNet for multi-threaded inference. You can also set the `OMP_NUM_THREADS` environment variable to get the best performance.
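As a concrete illustration (not part of this commit), one way to pass both settings to Spark executors is through Spark's standard `spark.executorEnv.*` configuration properties; the values below are placeholders to tune per deployment:

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: forward the MXNet tuning variables to every executor.
val spark = SparkSession.builder()
  .appName("Image Classification")
  .config("spark.executorEnv.MXNET_ENGINE_TYPE", "NaiveEngine") // NaiveEngine mode
  .config("spark.executorEnv.OMP_NUM_THREADS", "1")             // limit OpenMP threads
  .getOrCreate()
```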
apache-spark/spark3.0/image-classification/build.gradle (25 changes: 18 additions & 7 deletions)

@@ -1,32 +1,43 @@
plugins {
    id 'scala'
    id 'application'
+    id 'com.github.johnrengelman.shadow' version '7.1.2'
}

group "com.example"
version "1.0-SNAPSHOT"

repositories {
    mavenCentral()
+    mavenLocal()
}

dependencies {
-    implementation platform("ai.djl:bom:0.12.0")
-    implementation "org.apache.spark:spark-sql_2.12:3.0.1"
-    implementation "org.apache.spark:spark-mllib_2.12:3.0.1"
+    implementation platform("ai.djl:bom:0.20.0")
    implementation "ai.djl:api"
-    runtimeOnly "ai.djl.pytorch:pytorch-model-zoo"
-    runtimeOnly "ai.djl.pytorch:pytorch-native-auto"
+    implementation "ai.djl.spark:spark"
+    runtimeOnly "ai.djl.pytorch:pytorch-engine"
+    runtimeOnly "ai.djl.pytorch:pytorch-native-cpu-precxx11"
}

compileScala {
    scalaCompileOptions.setAdditionalParameters(["-target:jvm-1.8"])
}

application {
    sourceCompatibility = JavaVersion.VERSION_1_8
    targetCompatibility = JavaVersion.VERSION_1_8

    getMainClass().set(System.getProperty("main", "com.examples.ImageClassificationExample"))
}

+shadowJar {
+    zip64 true
+    mergeServiceFiles()
+    exclude "META-INF/*.SF"
+    exclude 'META-INF/*.DSA'
+    exclude 'META-INF/*.RSA'
+    exclude "LICENSE*"
+}
+
+tasks.distTar.enabled = false
apache-spark/spark3.0/image-classification/build.sbt (17 changes: 8 additions & 9 deletions)

@@ -7,15 +7,14 @@ scalacOptions += "-target:jvm-1.8"

resolvers += Resolver.jcenterRepo

libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.0.1"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "3.0.1"
libraryDependencies += "ai.djl" % "api" % "0.12.0"
libraryDependencies += "ai.djl" % "api" % "0.20.0"
libraryDependencies += "ai.djl.spark" % "spark" % "0.20.0"

// libraryDependencies += "ai.djl.mxnet" % "mxnet-model-zoo" % "0.12.0"
// libraryDependencies += "ai.djl.mxnet" % "mxnet-native-auto" % "1.8.0"
// libraryDependencies += "ai.djl.mxnet" % "mxnet-engine" % "0.20.0"
// libraryDependencies += "ai.djl.mxnet" % "mxnet-native-mkl" % "1.9.1"

libraryDependencies += "ai.djl.pytorch" % "pytorch-model-zoo" % "0.12.0"
libraryDependencies += "ai.djl.pytorch" % "pytorch-native-auto" % "1.8.1"
libraryDependencies += "ai.djl.pytorch" % "pytorch-engine" % "0.20.0"
libraryDependencies += "ai.djl.pytorch" % "pytorch-native-cpu-precxx11" % "1.13.0"

// libraryDependencies += "ai.djl.tensorflow" % "tensorflow-model-zoo" % "0.12.0"
// libraryDependencies += "ai.djl.tensorflow" % "tensorflow-native-auto" % "2.4.1"
// libraryDependencies += "ai.djl.tensorflow" % "tensorflow-engine" % "0.20.0"
// libraryDependencies += "ai.djl.tensorflow" % "tensorflow-native-cpu" % "2.7.0"
apache-spark/spark3.0/image-classification/gradle (1 change: 1 addition & 0 deletions)

apache-spark/spark3.0/image-classification/gradlew.bat (1 change: 1 addition & 0 deletions)