diff --git a/.travis.yml b/.travis.yml
index 42a12162..1df09667 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -11,6 +11,20 @@ before_cache:
   - find $HOME/.sbt -name "*.lock" -delete
 matrix:
   include:
+    # ---- Spark 2.0.x ----------------------------------------------------------------------------
+    # Spark 2.0.0, Scala 2.11, and Avro 1.7.x
+    - jdk: openjdk7
+      scala: 2.11.7
+      env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.7.6" TEST_AVRO_MAPRED_VERSION="1.7.7"
+    # Spark 2.0.0, Scala 2.10, and Avro 1.7.x
+    - jdk: openjdk7
+      scala: 2.10.4
+      env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.7.6" TEST_AVRO_MAPRED_VERSION="1.7.7"
+    # Spark 2.0.0, Scala 2.10, and Avro 1.8.x
+    - jdk: openjdk7
+      scala: 2.10.4
+      env: TEST_HADOOP_VERSION="2.2.0" TEST_SPARK_VERSION="2.0.0" TEST_AVRO_VERSION="1.8.0" TEST_AVRO_MAPRED_VERSION="1.8.0"
+    # ---- Spark 2.1.x ----------------------------------------------------------------------------
     # Spark 2.1.0, Scala 2.11, and Avro 1.7.x
     - jdk: openjdk7
       scala: 2.11.8
diff --git a/build.sbt b/build.sbt
index d8d933f3..1c9b212c 100644
--- a/build.sbt
+++ b/build.sbt
@@ -8,7 +8,7 @@ crossScalaVersions := Seq("2.10.6", "2.11.8")
 
 spName := "databricks/spark-avro"
 
-sparkVersion := "2.1.0"
+sparkVersion := "2.0.0"
 
 val testSparkVersion = settingKey[String]("The version of Spark to test against.")
 
diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
index bc71564e..cf515206 100644
--- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
+++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
@@ -32,6 +32,7 @@ import org.apache.avro.mapreduce.AvroKeyOutputFormat
 import org.apache.hadoop.io.NullWritable
 import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, TaskAttemptID}
 
+import org.apache.spark.SPARK_VERSION
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.execution.datasources.OutputWriter
 import org.apache.spark.sql.types._
@@ -54,7 +55,14 @@ private[avro] class AvroOutputWriter(
     new AvroKeyOutputFormat[GenericRecord]() {
 
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        new Path(path)
+        if (SPARK_VERSION.startsWith("2.0")) {
+          val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID")
+          val taskAttemptId: TaskAttemptID = context.getTaskAttemptID
+          val split = taskAttemptId.getTaskID.getId
+          new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
+        } else {
+          new Path(path)
+        }
       }
 
       @throws(classOf[IOException])
diff --git a/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala b/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala
index 3f3cbf07..84fce45d 100644
--- a/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala
+++ b/src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala
@@ -26,11 +26,19 @@ private[avro] class AvroOutputWriterFactory(
     recordName: String,
     recordNamespace: String) extends OutputWriterFactory {
 
-  override def getFileExtension(context: TaskAttemptContext): String = {
+  def getFileExtension(context: TaskAttemptContext): String = {
     ".avro"
   }
 
-  override def newInstance(
+  def newInstance(
+      path: String,
+      bucketId: Option[Int],
+      dataSchema: StructType,
+      context: TaskAttemptContext): OutputWriter = {
+    newInstance(path, dataSchema, context)
+  }
+
+  def newInstance(
       path: String,
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter = {
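
For reference, a minimal sketch (not part of the patch) of the file name the Spark 2.0.x branch of getDefaultWorkFile produces. The split id and write-job UUID below are hypothetical stand-ins; at runtime Spark supplies them through the TaskAttemptContext and the "spark.sql.sources.writeJobUUID" configuration key.

// Sketch of the Spark 2.0.x work-file naming used in getDefaultWorkFile above.
// All inputs here are hypothetical example values.
object WorkFileNamingSketch {
  // Same f-interpolator pattern as the patch: zero-padded 5-digit split id,
  // followed by the write job's UUID and the file extension.
  def workFileName(split: Int, uniqueWriteJobId: String, extension: String): String =
    f"part-r-$split%05d-$uniqueWriteJobId$extension"

  def main(args: Array[String]): Unit = {
    // Prints: part-r-00003-0f3a9c2e-1b4d-4f6a-9c7e-2d8b1a5c4e3f.avro
    println(workFileName(3, "0f3a9c2e-1b4d-4f6a-9c7e-2d8b1a5c4e3f", ".avro"))
  }
}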