This repository has been archived by the owner on Dec 20, 2018. It is now read-only.

Reader writer improvements #283

Open
wants to merge 5 commits into master
13 changes: 8 additions & 5 deletions src/main/scala/com/databricks/spark/avro/DefaultSource.scala
@@ -23,6 +23,7 @@ import java.util.zip.Deflater
 import scala.util.control.NonFatal
 
 import com.databricks.spark.avro.DefaultSource.{AvroSchema, IgnoreFilesWithoutExtensionProperty, SerializableConfiguration}
+import com.databricks.spark.avro.generic.SparkGenericDatumReader
 import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
 import com.esotericsoftware.kryo.io.{Input, Output}
 import org.apache.avro.{Schema, SchemaBuilder}
@@ -178,10 +179,8 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
       val reader = {
         val in = new FsInput(new Path(new URI(file.filePath)), conf)
         try {
-          val datumReader = userProvidedSchema match {
-            case Some(userSchema) => new GenericDatumReader[GenericRecord](userSchema)
-            case _ => new GenericDatumReader[GenericRecord]()
-          }
+          val datumReader = new SparkGenericDatumReader()
+          userProvidedSchema.foreach(datumReader.setSchema)
           DataFileReader.openReader(in, datumReader)
         } catch {
           case NonFatal(e) =>
@@ -210,9 +209,11 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
       private val encoderForDataColumns = RowEncoder(requiredSchema)
 
       private[this] var completed = false
+      private var record: GenericRecord = _
 
       override def hasNext: Boolean = {
         if (completed) {
+          record = null
           false
         } else {
           val r = reader.hasNext && !reader.pastSync(stop)
@@ -228,7 +229,9 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
         if (reader.pastSync(stop)) {
           throw new NoSuchElementException("next on empty iterator")
         }
-        val record = reader.next()
+
+        // record is reused by avro, we copy its content with rowConverter
+        record = reader.next(record)
         val safeDataRow = rowConverter(record).asInstanceOf[GenericRow]
 
         // The safeDataRow is reused, we must do a copy
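For context on the hunk above: it relies on Avro's object-reuse API, where DataFileReader.next(reuse) fills the record passed in instead of allocating a new one per row, which is why the converted row has to be copied before it escapes the iterator. A minimal standalone sketch of that pattern outside Spark (the file path and object name are illustrative, not part of this PR):

import java.io.File

import org.apache.avro.file.DataFileReader
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}

object RecordReuseSketch {
  def main(args: Array[String]): Unit = {
    val datumReader = new GenericDatumReader[GenericRecord]()
    val reader = new DataFileReader[GenericRecord](new File("/tmp/example.avro"), datumReader)
    var record: GenericRecord = null
    try {
      while (reader.hasNext) {
        // Passing the previous record back in lets Avro reuse its buffers
        // instead of allocating a fresh GenericRecord for every row, so any
        // value kept across iterations must be copied first.
        record = reader.next(record)
        println(record)
      }
    } finally {
      reader.close()
    }
  }
}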
15 changes: 7 additions & 8 deletions src/main/scala/com/databricks/spark/avro/SchemaConverters.scala
@@ -143,14 +143,12 @@ object SchemaConverters {
       targetSqlType: DataType): AnyRef => AnyRef = {
 
     def createConverter(avroSchema: Schema,
-        sqlType: DataType, path: List[String]): AnyRef => AnyRef = {
+        sparkSqlType: DataType, path: List[String]): AnyRef => AnyRef = {
       val avroType = avroSchema.getType
-      (sqlType, avroType) match {
-        // Avro strings are in Utf8, so we have to call toString on them
-        case (StringType, STRING) | (StringType, ENUM) =>
-          (item: AnyRef) => item.toString
-        // Byte arrays are reused by avro, so we have to make a copy of them.
-        case (IntegerType, INT) | (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) |
+      (sparkSqlType, avroType) match {
+        case (StringType, ENUM) => (item: AnyRef) => item.toString
+        case (StringType, STRING) | (IntegerType, INT) |
+          (BooleanType, BOOLEAN) | (DoubleType, DOUBLE) |
           (FloatType, FLOAT) | (LongType, LONG) =>
           identity
         case (TimestampType, LONG) =>
@@ -160,7 +158,8 @@ object SchemaConverters {
       case (BinaryType, FIXED) =>
         (item: AnyRef) => item.asInstanceOf[GenericFixed].bytes().clone()
       case (BinaryType, BYTES) =>
-        (item: AnyRef) =>
+        // Byte arrays are reused by avro, so we have to make a copy of them.
+        (item: AnyRef) =>
           val byteBuffer = item.asInstanceOf[ByteBuffer]
           val bytes = new Array[Byte](byteBuffer.remaining)
           byteBuffer.get(bytes)
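For context on the comment above: Avro may hand back the same underlying buffer on subsequent reads, so the BYTES converter materializes the remaining bytes into a fresh array before the value reaches Spark. A small standalone sketch of that copy (the object and helper names are illustrative, not part of the PR):

import java.nio.ByteBuffer

object ByteBufferCopySketch {
  // Same copy the converter performs: materialize the buffer's remaining
  // bytes into a fresh array so the row no longer aliases Avro's buffer.
  def copyBytes(buffer: ByteBuffer): Array[Byte] = {
    val bytes = new Array[Byte](buffer.remaining)
    buffer.get(bytes)
    bytes
  }

  def main(args: Array[String]): Unit = {
    val reused = ByteBuffer.wrap("first".getBytes("UTF-8"))
    val copy = copyBytes(reused)
    // Overwriting the original buffer does not affect the copy.
    reused.rewind()
    reused.put("xxxxx".getBytes("UTF-8"))
    println(new String(copy, "UTF-8")) // prints "first"
  }
}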
src/main/scala/com/databricks/spark/avro/generic/SparkGenericDatumReader.scala (new file)
@@ -0,0 +1,15 @@
+package com.databricks.spark.avro.generic
+
+import org.apache.avro.Schema
+import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
+
+/**
+ * A GenericDatumReader that reads Avro strings as java.lang.String instead of Utf8.
+ */
+class SparkGenericDatumReader extends GenericDatumReader[GenericRecord] {
+
+  override def findStringClass(schema: Schema): Class[_] = classOf[String]
+
+}
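For context, a minimal sketch of how the new reader behaves outside Spark, mirroring the DefaultSource change above (the file path is a placeholder; only the wiring inside DefaultSource is part of this PR):

import java.io.File

import org.apache.avro.Schema
import org.apache.avro.file.DataFileReader
import org.apache.avro.generic.GenericRecord

import com.databricks.spark.avro.generic.SparkGenericDatumReader

object SparkGenericDatumReaderSketch {
  def main(args: Array[String]): Unit = {
    val datumReader = new SparkGenericDatumReader()
    // DefaultSource forwards a user-provided schema the same way.
    val userProvidedSchema: Option[Schema] = None
    userProvidedSchema.foreach(datumReader.setSchema)

    val reader = new DataFileReader[GenericRecord](new File("/tmp/example.avro"), datumReader)
    try {
      while (reader.hasNext) {
        val record = reader.next()
        // Because findStringClass returns classOf[String], string fields come
        // back as java.lang.String rather than org.apache.avro.util.Utf8.
        println(record)
      }
    } finally {
      reader.close()
    }
  }
}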