Add Spark extension example #272

Merged
merged 1 commit into from
Jan 7, 2023
2 changes: 1 addition & 1 deletion README.md
@@ -87,7 +87,7 @@ An example application that runs low cost/high performance inference with AWS In

## Big data integration

### [Spark Image Classification](apache-spark/image-classification/README.md)
### [Spark Image Classification](apache-spark/spark3.0/image-classification/README.md)

Contains Spark image classification demos.

140 changes: 26 additions & 114 deletions apache-spark/notebook/Image_Classification_Spark.ipynb
@@ -4,9 +4,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Image Classification with Apache Spark\n",
"# Image Classification with DJL Spark Support\n",
"\n",
"In this example, we will use Jupyter Notebook to run image classification with Apache Spark on Scala. To execute this Scala kernel successfully, you need to install [Almond](https://almond.sh/), a Scala kernel for Jupyter Notebook. Almond provides extensive functionality for Scala and Spark applications.\n",
"In this example, we will use Jupyter Notebook to run image classification with the DJL Spark extension on Scala. To execute this Scala kernel successfully, you need to install [Almond](https://almond.sh/), a Scala kernel for Jupyter Notebook. Almond provides extensive functionality for Scala and Spark applications.\n",
"\n",
"[Almond installation instructions](https://almond.sh/docs/quick-start-install) (Note: only Scala 2.12 is tested)\n",
"\n",
@@ -24,13 +24,13 @@
"metadata": {},
"outputs": [],
"source": [
"import $ivy.`org.apache.spark::spark-sql:3.0.1`\n",
"import $ivy.`org.apache.spark::spark-mllib:3.0.1`\n",
"import $ivy.`ai.djl:api:0.10.0`\n",
"import $ivy.`ai.djl.pytorch:pytorch-model-zoo:0.10.0`\n",
"import $ivy.`ai.djl.pytorch:pytorch-native-auto:1.7.1`\n",
"// import $ivy.`ai.djl.mxnet:mxnet-model-zoo:0.10.0`\n",
"// import $ivy.`ai.djl.mxnet:mxnet-native-auto:1.7.0-backport`"
"import $ivy.`org.apache.spark::spark-sql:3.2.2`\n",
"import $ivy.`ai.djl:api:0.20.0`\n",
"import $ivy.`ai.djl.spark:spark:0.20.0`\n",
"import $ivy.`ai.djl.pytorch:pytorch-model-zoo:0.20.0`\n",
"import $ivy.`ai.djl.pytorch:pytorch-native-cpu-precxx11:1.13.0`\n",
"// import $ivy.`ai.djl.mxnet:mxnet-engine:0.20.0`\n",
"// import $ivy.`ai.djl.mxnet:mxnet-native-mkl:1.9.1`"
]
},
{
@@ -46,106 +46,17 @@
"metadata": {},
"outputs": [],
"source": [
"import java.util\n",
"import ai.djl.Model\n",
"import org.apache.spark.sql.NotebookSparkSession\n",
"import ai.djl.modality.Classifications\n",
"import ai.djl.modality.cv.transform.{ Resize, ToTensor}\n",
"import ai.djl.ndarray.types.{DataType, Shape}\n",
"import ai.djl.ndarray.{NDList, NDManager}\n",
"import ai.djl.repository.zoo.{Criteria, ModelZoo, ZooModel}\n",
"import ai.djl.training.util.ProgressBar\n",
"import ai.djl.translate.{Batchifier, Pipeline, Translator, TranslatorContext}\n",
"import ai.djl.util.Utils\n",
"import org.apache.spark.ml.image.ImageSchema\n",
"import org.apache.spark.sql.functions.col\n",
"import org.apache.spark.sql.{Encoders, Row, NotebookSparkSession}\n",
"import ai.djl.spark.SparkTransformer\n",
"import ai.djl.spark.translator.SparkImageClassificationTranslator\n",
"import org.apache.spark.sql.SparkSession\n",
"\n",
"import org.apache.log4j.{Level, Logger}\n",
"Logger.getLogger(\"org\").setLevel(Level.OFF) // avoid too much message popping out\n",
"Logger.getLogger(\"ai\").setLevel(Level.OFF) // avoid too much message popping out"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create Translator\n",
"\n",
"A Translator in DJL is used to define the preprocessing and postprocessing logic. The following code is to \n",
"\n",
"- preprocess: convert a Spark DataFrame Row to DJL NDArray.\n",
"- postprocess: convert inference result to classifications"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
" // Translator: a class used to do preprocessing and post processing\n",
" class MyTranslator extends Translator[Row, Classifications] {\n",
"\n",
" private var classes: java.util.List[String] = new util.ArrayList[String]()\n",
" private val pipeline: Pipeline = new Pipeline()\n",
" .add(new Resize(224, 224))\n",
" .add(new ToTensor())\n",
"\n",
" override def prepare(manager: NDManager, model: Model): Unit = {\n",
" classes = Utils.readLines(model.getArtifact(\"synset.txt\").openStream())\n",
" }\n",
"\n",
" override def processInput(ctx: TranslatorContext, row: Row): NDList = {\n",
"\n",
" val height = ImageSchema.getHeight(row)\n",
" val width = ImageSchema.getWidth(row)\n",
" val channel = ImageSchema.getNChannels(row)\n",
" var image = ctx.getNDManager.create(ImageSchema.getData(row), new Shape(height, width, channel)).toType(DataType.UINT8, true)\n",
" // BGR to RGB\n",
" image = image.flip(2)\n",
" pipeline.transform(new NDList(image))\n",
" }\n",
"\n",
" // Deal with the output.,NDList contains output result, usually one or more NDArray(s).\n",
" override def processOutput(ctx: TranslatorContext, list: NDList): Classifications = {\n",
" var probabilitiesNd = list.singletonOrThrow\n",
" probabilitiesNd = probabilitiesNd.softmax(0)\n",
" new Classifications(classes, probabilitiesNd)\n",
" }\n",
"\n",
" override def getBatchifier: Batchifier = Batchifier.STACK\n",
" }"
]
},
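To make the two steps of the `MyTranslator` cell above more concrete, here is a minimal stand-alone Scala sketch of what the channel flip (`image.flip(2)`) and the `softmax(0)` call compute. It deliberately uses plain arrays rather than DJL `NDArray`s, and `TranslatorSketch`, `flipChannels`, and `softmax` are hypothetical names introduced only for illustration; they are not part of DJL or of this PR.

```scala
// Stand-alone sketch: assumes no Spark or DJL on the classpath.
object TranslatorSketch {

  // Reverse the channel order of an interleaved HWC buffer
  // (for 3 channels this turns BGR pixels into RGB pixels).
  def flipChannels(data: Array[Byte], channels: Int): Array[Byte] = {
    require(channels > 0 && data.length % channels == 0, "buffer must hold whole pixels")
    data.grouped(channels).flatMap(_.reverse).toArray
  }

  // Numerically stable softmax: subtract the max before exponentiating.
  def softmax(logits: Array[Double]): Array[Double] = {
    require(logits.nonEmpty, "logits must not be empty")
    val max = logits.max
    val exps = logits.map(x => math.exp(x - max))
    val sum = exps.sum
    exps.map(_ / sum)
  }
}
```

For example, `TranslatorSketch.flipChannels(Array[Byte](1, 2, 3), 3)` swaps the first and last channel of a single pixel, and the values returned by `TranslatorSketch.softmax` sum to 1, which is what lets them be read as class probabilities.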
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the model\n",
"\n",
"Now, we just need to fetch the model from a URL. The URL can use the HDFS (hdfs://), file (file://), or HTTP (https://) scheme. We use Criteria as a container to store the model and translator information. Then, all we need to do is load the model from it.\n",
"\n",
"Note: DJL Criteria and Model are not serializable, so we add a `lazy` declaration.\n",
"\n",
"If you are using MXNet as the backend engine, please uncomment the MXNet model URL."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18\"\n",
"// val modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/mxnet_resnet18.zip?model_name=resnet18_v1\"\n",
"lazy val criteria = Criteria.builder\n",
" .setTypes(classOf[Row], classOf[Classifications])\n",
" .optModelUrls(modelUrl)\n",
" .optTranslator(new MyTranslator())\n",
" .optProgress(new ProgressBar)\n",
" .build()\n",
"lazy val model = ModelZoo.loadModel(criteria)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -173,7 +84,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try to load the images from the local folder using Spark ML library:"
"Let's try to load the images from the local folder using Spark library:"
]
},
{
@@ -190,7 +101,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can run inference on these images. All we need to do is create a `Predictor` and run inference with DJL."
"Then we can run inference on these images. All we need to do is create a `SparkTransformer` and run inference with DJL."
]
},
{
@@ -199,14 +110,15 @@
"metadata": {},
"outputs": [],
"source": [
"val result = df.select(col(\"image.*\")).mapPartitions(partition => {\n",
" val predictor = model.newPredictor()\n",
" partition.map(row => {\n",
" // image data stored as HWC format\n",
" predictor.predict(row).toString\n",
" })\n",
"})(Encoders.STRING)\n",
"println(result.collect().mkString(\"\\n\"))"
"System.setProperty(\"PYTORCH_PRECXX11\", \"true\")\n",
"val transformer = new SparkTransformer[Classifications]()\n",
" .setInputCol(\"image.*\")\n",
" .setOutputCol(\"value\")\n",
" .setModelUrl(\"https://alpha-djl-demos.s3.amazonaws.com/model/djl-blockrunner/pytorch_resnet18.zip?model_name=traced_resnet18\")\n",
" .setOutputClass(classOf[Classifications])\n",
" .setTranslator(new SparkImageClassificationTranslator())\n",
"val outputDf = transformer.transform(df)\n",
"println(outputDf.collect().mkString(\"\\n\"))"
]
}
],
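The `mapPartitions` cell removed in this hunk created one `Predictor` per partition and reused it for every row in that partition. A rough plain-Scala sketch of that pattern follows. `FakePredictor` is a hypothetical stand-in for a DJL predictor and a `Seq[Seq[String]]` stands in for Spark partitions, so nothing here is the actual DJL or Spark API.

```scala
// Stand-alone sketch: assumes no Spark or DJL on the classpath.
object PartitionSketch {
  // Counts predictor instances so the reuse is observable.
  var predictorsCreated = 0

  // Hypothetical stand-in for a DJL Predictor.
  final class FakePredictor {
    predictorsCreated += 1
    def predict(input: String): String = s"label-for-$input"
  }

  // Same shape as the function passed to mapPartitions:
  // build the predictor once, then map it over the partition's rows.
  def predictPartition(rows: Iterator[String]): Iterator[String] = {
    val predictor = new FakePredictor
    rows.map(predictor.predict)
  }

  // Simulate running over several partitions and collecting the results.
  def run(partitions: Seq[Seq[String]]): Seq[String] =
    partitions.flatMap(p => predictPartition(p.iterator).toList)
}
```

Calling `PartitionSketch.run` with two partitions constructs two predictors regardless of how many rows each partition holds, which is the point of building the predictor outside the per-row `map`.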
@@ -222,7 +134,7 @@
"mimetype": "text/x-scala",
"name": "scala",
"nbconvert_exporter": "script",
"version": "2.12.12"
"version": "2.12.11"
}
},
"nbformat": 4,
16 changes: 8 additions & 8 deletions apache-spark/spark3.0/image-classification/README.md
@@ -16,14 +16,14 @@ We use PyTorch engine by default, you can switch to MXNet or TensorFlow Engine a

#### MXNet
```
libraryDependencies += "ai.djl.mxnet" % "mxnet-model-zoo" % "0.10.0"
libraryDependencies += "ai.djl.mxnet" % "mxnet-native-auto" % "1.7.0-backport"
libraryDependencies += "ai.djl.mxnet" % "mxnet-engine" % "0.20.0"
libraryDependencies += "ai.djl.mxnet" % "mxnet-native-mkl" % "1.9.1"
```

#### TensorFlow
```
libraryDependencies += "ai.djl.tensorflow" % "tensorflow-model-zoo" % "0.10.0"
libraryDependencies += "ai.djl.tensorflow" % "tensorflow-native-auto" % "2.3.0"
libraryDependencies += "ai.djl.tensorflow" % "tensorflow-engine" % "0.20.0"
libraryDependencies += "ai.djl.tensorflow" % "tensorflow-native-cpu" % "2.7.0"
```

### gradle
@@ -32,14 +32,14 @@ You should add these in `dependencies`

#### MXNet
```
runtimeOnly "ai.djl.mxnet:mxnet-model-zoo:0.10.0"
runtimeOnly "ai.djl.mxnet:mxnet-native-auto:1.7.0-backport"
runtimeOnly "ai.djl.mxnet:mxnet-engine:0.20.0"
runtimeOnly "ai.djl.mxnet:mxnet-native-mkl:1.9.1"
```

#### TensorFlow
```
runtimeOnly "ai.djl.tensorflow:tensorflow-model-zoo:0.10.0"
runtimeOnly "ai.djl.tensorflow:tensorflow-native-auto:2.3.0"
runtimeOnly "ai.djl.tensorflow:tensorflow-engine:0.20.0"
runtimeOnly "ai.djl.tensorflow:tensorflow-native-cpu:2.7.0"
```

Apart from that, you may also need to use `NaiveEngine` mode in MXNet for multi-threaded inference. You can also set the `OMP_NUM_THREADS` environment variable to tune performance.
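One way to pass these settings to Spark executors is through `spark.executorEnv.<NAME>` properties, which Spark forwards to executor processes as environment variables. The sketch below only builds the corresponding `spark-submit` arguments. It assumes MXNet's `MXNET_ENGINE_TYPE` variable for selecting `NaiveEngine` and the OpenMP thread-count variable (conventionally spelled `OMP_NUM_THREADS`). The value `1` is an illustrative assumption rather than a recommendation from this README, and `ExecutorEnvSketch` is a hypothetical helper name.

```scala
// Stand-alone sketch: assumes no Spark on the classpath.
object ExecutorEnvSketch {
  // Assumed values; tune OMP_NUM_THREADS for your own workload.
  val executorEnv: Seq[(String, String)] = Seq(
    "MXNET_ENGINE_TYPE" -> "NaiveEngine",
    "OMP_NUM_THREADS" -> "1"
  )

  // Render each variable as a `--conf spark.executorEnv.<NAME>=<value>` pair.
  def toSubmitArgs(env: Seq[(String, String)]): Seq[String] =
    env.flatMap { case (name, value) =>
      Seq("--conf", s"spark.executorEnv.$name=$value")
    }
}
```

The resulting sequence can be spliced into a `spark-submit` command line ahead of the application jar.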
25 changes: 18 additions & 7 deletions apache-spark/spark3.0/image-classification/build.gradle
@@ -1,32 +1,43 @@
plugins {
id 'scala'
id 'application'
id 'com.github.johnrengelman.shadow' version '7.1.2'
}

group "com.example"
version "1.0-SNAPSHOT"

repositories {
mavenCentral()
mavenLocal()
}

dependencies {
implementation platform("ai.djl:bom:0.12.0")
implementation "org.apache.spark:spark-sql_2.12:3.0.1"
implementation "org.apache.spark:spark-mllib_2.12:3.0.1"
implementation platform("ai.djl:bom:0.20.0")
implementation "ai.djl:api"


runtimeOnly "ai.djl.pytorch:pytorch-model-zoo"
runtimeOnly "ai.djl.pytorch:pytorch-native-auto"
implementation "ai.djl.spark:spark"
runtimeOnly "ai.djl.pytorch:pytorch-engine"
runtimeOnly "ai.djl.pytorch:pytorch-native-cpu-precxx11"
}

compileScala {
scalaCompileOptions.setAdditionalParameters(["-target:jvm-1.8"])
}

application {
sourceCompatibility = JavaVersion.VERSION_1_8

Review comment (Contributor): Why not use JDK 11?

Reply (@xyang16, Contributor Author, Dec 21, 2022): Because the java version on EMR is 8 now.

targetCompatibility = JavaVersion.VERSION_1_8

getMainClass().set(System.getProperty("main", "com.examples.ImageClassificationExample"))
}

shadowJar {
zip64 true
mergeServiceFiles()
exclude "META-INF/*.SF"
exclude 'META-INF/*.DSA'
exclude 'META-INF/*.RSA'
exclude "LICENSE*"
}

tasks.distTar.enabled = false
17 changes: 8 additions & 9 deletions apache-spark/spark3.0/image-classification/build.sbt
@@ -7,15 +7,14 @@ scalacOptions += "-target:jvm-1.8"

resolvers += Resolver.jcenterRepo

libraryDependencies += "org.apache.spark" %% "spark-sql" % "3.0.1"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "3.0.1"
libraryDependencies += "ai.djl" % "api" % "0.12.0"
libraryDependencies += "ai.djl" % "api" % "0.20.0"
libraryDependencies += "ai.djl.spark" % "spark" % "0.20.0"

// libraryDependencies += "ai.djl.mxnet" % "mxnet-model-zoo" % "0.12.0"
// libraryDependencies += "ai.djl.mxnet" % "mxnet-native-auto" % "1.8.0"
// libraryDependencies += "ai.djl.mxnet" % "mxnet-engine" % "0.20.0"
// libraryDependencies += "ai.djl.mxnet" % "mxnet-native-mkl" % "1.9.1"

libraryDependencies += "ai.djl.pytorch" % "pytorch-model-zoo" % "0.12.0"
libraryDependencies += "ai.djl.pytorch" % "pytorch-native-auto" % "1.8.1"
libraryDependencies += "ai.djl.pytorch" % "pytorch-engine" % "0.20.0"
libraryDependencies += "ai.djl.pytorch" % "pytorch-native-cpu-precxx11" % "1.13.0"

// libraryDependencies += "ai.djl.tensorflow" % "tensorflow-model-zoo" % "0.12.0"
// libraryDependencies += "ai.djl.tensorflow" % "tensorflow-native-auto" % "2.4.1"
// libraryDependencies += "ai.djl.tensorflow" % "tensorflow-engine" % "0.20.0"
// libraryDependencies += "ai.djl.tensorflow" % "tensorflow-native-cpu" % "2.7.0"
1 change: 1 addition & 0 deletions apache-spark/spark3.0/image-classification/gradle
1 change: 1 addition & 0 deletions apache-spark/spark3.0/image-classification/gradlew.bat