diff --git a/docker/spark/Dockerfile b/docker/spark/Dockerfile
index 0d9ffb321d9..d2cee4adac4 100644
--- a/docker/spark/Dockerfile
+++ b/docker/spark/Dockerfile
@@ -22,34 +22,33 @@ ARG FFMPEG_VERSION=5.1.2-1.5.8
 ARG TENSORFLOW_CORE_VERSION=0.4.2
 ARG PROTOBUF_VERSION=3.21.9
 
+RUN yum -y install python3-devel gcc
 COPY extensions/spark/setup/dist/ dist/
 RUN pip3 install --no-cache-dir dist/djl_spark-*-py3-none-any.whl && \
     rm -rf dist
 RUN pip3 install --no-cache-dir torch==1.13.1 --index-url https://download.pytorch.org/whl/cpu
-RUN pip3 install --no-cache-dir pillow pandas numpy pyarrow 'transformers[torch]'
-
-ADD https://repo1.maven.org/maven2/ai/djl/api/${DJL_VERSION}/api-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-api-${DJL_VERSION}.jar
-ADD https://repo1.maven.org/maven2/ai/djl/spark/spark/${DJL_VERSION}/spark-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-spark-${DJL_VERSION}.jar
-ADD https://repo1.maven.org/maven2/ai/djl/huggingface/tokenizers/${DJL_VERSION}/tokenizers-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-tokenizers-${DJL_VERSION}.jar
-
-ADD https://repo1.maven.org/maven2/ai/djl/audio/audio/${DJL_VERSION}/audio-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-audio-${DJL_VERSION}.jar
-ADD https://repo1.maven.org/maven2/org/bytedeco/javacv/${JAVACV_VERSION}/javacv-${JAVACV_VERSION}.jar /usr/lib/spark/jars/
-ADD https://repo1.maven.org/maven2/org/bytedeco/javacpp/${JAVACPP_VERSION}/javacpp-${JAVACPP_VERSION}.jar /usr/lib/spark/jars/
-ADD https://repo1.maven.org/maven2/org/bytedeco/ffmpeg/${FFMPEG_VERSION}/ffmpeg-${FFMPEG_VERSION}.jar /usr/lib/spark/jars/
-ADD https://repo1.maven.org/maven2/org/bytedeco/ffmpeg/${FFMPEG_VERSION}/ffmpeg-${FFMPEG_VERSION}-linux-x86_64.jar /usr/lib/spark/jars/
-ADD https://repo1.maven.org/maven2/org/bytedeco/ffmpeg-platform/${FFMPEG_VERSION}/ffmpeg-platform-${FFMPEG_VERSION}.jar /usr/lib/spark/jars/
-
-ADD https://repo1.maven.org/maven2/ai/djl/pytorch/pytorch-engine/${DJL_VERSION}/pytorch-engine-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-pytorch-engine-${DJL_VERSION}.jar
-ADD https://repo1.maven.org/maven2/ai/djl/pytorch/pytorch-model-zoo/${DJL_VERSION}/pytorch-model-zoo-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-model-zoo-${DJL_VERSION}.jar
-ADD https://repo1.maven.org/maven2/net/java/dev/jna/jna/${JNA_VERSION}/jna-${JNA_VERSION}.jar /usr/lib/spark/jars/
-
-ADD https://repo1.maven.org/maven2/ai/djl/tensorflow/tensorflow-engine/${DJL_VERSION}/tensorflow-engine-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-tensorflow-engine-${DJL_VERSION}.jar
-ADD https://repo1.maven.org/maven2/ai/djl/tensorflow/tensorflow-model-zoo/${DJL_VERSION}/tensorflow-model-zoo-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-tensorflow-model-zoo-${DJL_VERSION}.jar
-ADD https://repo1.maven.org/maven2/org/tensorflow/tensorflow-core-api/${TENSORFLOW_CORE_VERSION}/tensorflow-core-api-${TENSORFLOW_CORE_VERSION}.jar /usr/lib/spark/jars/
+RUN pip3 install --no-cache-dir pillow pandas numpy pyarrow librosa 'transformers[torch]'
+
+ADD --chmod=644 https://repo1.maven.org/maven2/ai/djl/api/${DJL_VERSION}/api-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-api-${DJL_VERSION}.jar
+ADD --chmod=644 https://repo1.maven.org/maven2/ai/djl/spark/spark/${DJL_VERSION}/spark-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-spark-${DJL_VERSION}.jar
+ADD --chmod=644 https://repo1.maven.org/maven2/ai/djl/huggingface/tokenizers/${DJL_VERSION}/tokenizers-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-tokenizers-${DJL_VERSION}.jar
+
+ADD --chmod=644 https://repo1.maven.org/maven2/ai/djl/audio/audio/${DJL_VERSION}/audio-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-audio-${DJL_VERSION}.jar
+ADD --chmod=644 https://repo1.maven.org/maven2/org/bytedeco/javacv/${JAVACV_VERSION}/javacv-${JAVACV_VERSION}.jar /usr/lib/spark/jars/
+ADD --chmod=644 https://repo1.maven.org/maven2/org/bytedeco/javacpp/${JAVACPP_VERSION}/javacpp-${JAVACPP_VERSION}.jar /usr/lib/spark/jars/
+ADD --chmod=644 https://repo1.maven.org/maven2/org/bytedeco/ffmpeg/${FFMPEG_VERSION}/ffmpeg-${FFMPEG_VERSION}.jar /usr/lib/spark/jars/
+ADD --chmod=644 https://repo1.maven.org/maven2/org/bytedeco/ffmpeg/${FFMPEG_VERSION}/ffmpeg-${FFMPEG_VERSION}-linux-x86_64.jar /usr/lib/spark/jars/
+ADD --chmod=644 https://repo1.maven.org/maven2/org/bytedeco/ffmpeg-platform/${FFMPEG_VERSION}/ffmpeg-platform-${FFMPEG_VERSION}.jar /usr/lib/spark/jars/
+
+ADD --chmod=644 https://repo1.maven.org/maven2/ai/djl/pytorch/pytorch-engine/${DJL_VERSION}/pytorch-engine-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-pytorch-engine-${DJL_VERSION}.jar
+ADD --chmod=644 https://repo1.maven.org/maven2/ai/djl/pytorch/pytorch-model-zoo/${DJL_VERSION}/pytorch-model-zoo-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-model-zoo-${DJL_VERSION}.jar
+ADD --chmod=644 https://repo1.maven.org/maven2/net/java/dev/jna/jna/${JNA_VERSION}/jna-${JNA_VERSION}.jar /usr/lib/spark/jars/
+
+ADD --chmod=644 https://repo1.maven.org/maven2/ai/djl/tensorflow/tensorflow-engine/${DJL_VERSION}/tensorflow-engine-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-tensorflow-engine-${DJL_VERSION}.jar
+ADD --chmod=644 https://repo1.maven.org/maven2/ai/djl/tensorflow/tensorflow-model-zoo/${DJL_VERSION}/tensorflow-model-zoo-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-tensorflow-model-zoo-${DJL_VERSION}.jar
+ADD --chmod=644 https://repo1.maven.org/maven2/org/tensorflow/tensorflow-core-api/${TENSORFLOW_CORE_VERSION}/tensorflow-core-api-${TENSORFLOW_CORE_VERSION}.jar /usr/lib/spark/jars/
 
 RUN rm /usr/lib/spark/jars/protobuf-java-*.jar
-ADD https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java/${PROTOBUF_VERSION}/protobuf-java-${PROTOBUF_VERSION}.jar /usr/lib/spark/jars/
-
-RUN chmod -R +r /usr/lib/spark/jars/
+ADD --chmod=644 https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java/${PROTOBUF_VERSION}/protobuf-java-${PROTOBUF_VERSION}.jar /usr/lib/spark/jars/
 
 RUN echo 'export SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS -Dai.djl.pytorch.graph_optimizer=false"' >> /usr/lib/spark/conf/spark-env.sh
diff --git a/extensions/spark/setup/djl_spark/task/audio/__init__.py b/extensions/spark/setup/djl_spark/task/audio/__init__.py
new file mode 100644
index 00000000000..bed43fbd52f
--- /dev/null
+++ b/extensions/spark/setup/djl_spark/task/audio/__init__.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python
+#
+# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+
+"""DJL Spark Tasks Audio API."""
+
+from . import whisper_speech_recognizer
+
+WhisperSpeechRecognizer = whisper_speech_recognizer.WhisperSpeechRecognizer
+
+# Remove unnecessary modules to avoid duplication in API.
+del whisper_speech_recognizer
\ No newline at end of file
diff --git a/extensions/spark/setup/djl_spark/task/audio/whisper_speech_recognizer.py b/extensions/spark/setup/djl_spark/task/audio/whisper_speech_recognizer.py
new file mode 100644
index 00000000000..1b2c0c4a788
--- /dev/null
+++ b/extensions/spark/setup/djl_spark/task/audio/whisper_speech_recognizer.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+#
+# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+
+from pyspark import SparkContext
+from pyspark.sql.functions import pandas_udf
+from pyspark.sql.types import StringType
+import io
+import librosa
+import pandas as pd
+from typing import Iterator
+from transformers import pipeline
+
+
+class WhisperSpeechRecognizer:
+
+    def __init__(self, input_col, output_col, engine, model_url=None, model_name=None):
+        """
+        Initializes the WhisperSpeechRecognizer.
+
+        :param input_col: The input column
+        :param output_col: The output column
+        :param engine: The engine. Currently only PyTorch is supported.
+        :param model_url: The model URL
+        :param model_name: The model name
+        """
+        self.input_col = input_col
+        self.output_col = output_col
+        self.engine = engine
+        self.model_url = model_url
+        self.model_name = model_name
+
+    def recognize(self, dataset, generate_kwargs=None, **kwargs):
+        """
+        Performs speech recognition on the provided dataset.
+
+        :param dataset: input dataset
+        :return: output dataset
+        """
+        sc = SparkContext._active_spark_context
+        if not self.model_url and not self.model_name:
+            raise ValueError("Either model_url or model_name must be provided.")
+        model_name_or_url = self.model_url if self.model_url else self.model_name
+        pipe = pipeline("automatic-speech-recognition", generate_kwargs=generate_kwargs,
+                        model=model_name_or_url, chunk_length_s=30, **kwargs)
+        bc_pipe = sc.broadcast(pipe)
+
+        @pandas_udf(StringType())
+        def predict_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
+            for s in iterator:
+                batch = []
+                for d in s:
+                    # Model expects single channel, 16000 sample rate audio
+                    data, sample_rate = librosa.load(io.BytesIO(d), mono=True, sr=16000)
+                    batch.append(data)
+                output = bc_pipe.value(batch)
+                text = [o["text"] for o in output]
+                yield pd.Series(text)
+
+        return dataset.withColumn(self.output_col, predict_udf(self.input_col))
\ No newline at end of file
diff --git a/extensions/spark/src/main/scala/ai/djl/spark/task/audio/WhisperSpeechRecognizer.scala b/extensions/spark/src/main/scala/ai/djl/spark/task/audio/WhisperSpeechRecognizer.scala
deleted file mode 100644
index cfaf3b3254e..00000000000
--- a/extensions/spark/src/main/scala/ai/djl/spark/task/audio/WhisperSpeechRecognizer.scala
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
- * with the License. A copy of the License is located at
- *
- *     http://aws.amazon.com/apache2.0/
- *
- * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
- * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
- * and limitations under the License.
- */
-package ai.djl.spark.task.audio
-
-import ai.djl.audio.translator.WhisperTranslatorFactory
-import ai.djl.modality.audio.Audio
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.types.{StringType, StructField, StructType}
-import org.bytedeco.javacv.{FFmpegFrameGrabber, Frame, FrameGrabber}
-
-import java.io.{ByteArrayInputStream, IOException}
-import java.nio.ShortBuffer
-import java.util
-import scala.jdk.CollectionConverters.asScalaBufferConverter
-
-/**
- * WhisperSpeechRecognizer is very similar to the SpeechRecognizer that performs speech recognition on audio,
- * except that this API is specially tailored for OpenAI Whisper related models.
- *
- * @param uid An immutable unique ID for the object and its derivatives.
- */
-class WhisperSpeechRecognizer(override val uid: String) extends SpeechRecognizer {
-
-  def this() = this(Identifiable.randomUID("WhisperSpeechRecognizer"))
-
-  private final val CHUNK_LENGTH: Int = 200
-
-  setDefault(translatorFactory, new WhisperTranslatorFactory())
-
-  /**
-   * Transforms the rows.
-   *
-   * @param iter the rows to transform
-   * @return the transformed rows
-   */
-  override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = {
-    val predictor = model.newPredictor()
-    iter.map(row => {
-      val data = row.getAs[Array[Byte]](inputColIndex)
-      val audioList = splitAudio(data)
-      val result = predictor.batchPredict(audioList)
-      Row.fromSeq(row.toSeq :+ String.join("", result))
-    })
-  }
-
-  /** @inheritdoc */
-  override def transformSchema(schema: StructType): StructType = {
-    val outputSchema = StructType(schema.fields :+ StructField($(outputCol), StringType))
-    outputSchema
-  }
-
-  /**
-   * Split the audio data into chunks.
-   *
-   * @param data the audio data
-   * @return the list of audio chunks
-   */
-  private def splitAudio(data: Array[Byte]): util.ArrayList[Audio] = {
-    val audioList = new util.ArrayList[Audio]()
-    val grabber = new FFmpegFrameGrabber(new ByteArrayInputStream(data))
-    try {
-      if (isDefined(channels)) {
-        grabber.setAudioChannels($(channels))
-      }
-      if (isDefined(sampleRate)) {
-        grabber.setSampleRate($(sampleRate))
-      }
-      if (isDefined(sampleFormat)) {
-        grabber.setSampleFormat($(sampleFormat))
-      }
-      grabber.start()
-      val list = new util.ArrayList[Float]()
-      var frame: Frame = null
-      var i = 0
-      while ( {
-        frame = grabber.grabFrame(true, false, true, false, false)
-        frame != null
-      }) {
-        if (i > CHUNK_LENGTH) {
-          audioList.add(new Audio(list.asScala.toArray, grabber.getSampleRate, grabber.getAudioChannels))
-          list.clear()
-          i = 0 // Reset the counter
-        }
-        val buf = frame.samples(0).asInstanceOf[ShortBuffer]
-        for (j <- 0 until buf.limit) {
-          list.add(buf.get / Short.MaxValue.toFloat)
-        }
-        i += 1
-      }
-      if (!list.isEmpty) {
-        audioList.add(new Audio(list.asScala.toArray, grabber.getSampleRate, grabber.getAudioChannels))
-      }
-      audioList
-    } catch {
-      case e: FrameGrabber.Exception =>
-        throw new IOException("Unsupported Audio file", e)
-    } finally {
-      if (grabber != null) {
-        grabber.close()
-      }
-    }
-  }
-}
\ No newline at end of file
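Usage note (not part of the diff above): a minimal sketch of how the new Python WhisperSpeechRecognizer could be driven from PySpark, assuming audio files are read through Spark's binaryFile data source; the column names, file path, and the openai/whisper-tiny model choice are illustrative assumptions, not part of this change.

# Hypothetical usage sketch; paths, column names, and model choice are assumptions.
from pyspark.sql import SparkSession
from djl_spark.task.audio import WhisperSpeechRecognizer

spark = SparkSession.builder.appName("whisper-example").getOrCreate()

# The binaryFile source exposes the raw audio bytes in a "content" column.
df = spark.read.format("binaryFile").load("/path/to/audio/*.wav")

recognizer = WhisperSpeechRecognizer(input_col="content",
                                     output_col="text",
                                     engine="PyTorch",
                                     model_name="openai/whisper-tiny")

# generate_kwargs and extra kwargs, if given, are forwarded to the transformers pipeline.
result = recognizer.recognize(df)
result.select("path", "text").show(truncate=False)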