From df79ad1457f22346cc195962f2bed481f526a587 Mon Sep 17 00:00:00 2001 From: Owen O'Malley Date: Mon, 17 Jun 2019 16:14:03 +0200 Subject: [PATCH 1/4] ORC-484: Update PhysicalWriter to support encryption. Fixes #402 Signed-off-by: Owen O'Malley --- .../java/org/apache/orc/PhysicalWriter.java | 35 +- .../org/apache/orc/impl/PhysicalFsWriter.java | 486 +++++++++++++++--- .../java/org/apache/orc/impl/WriterImpl.java | 35 +- .../orc/impl/writer/ListTreeWriter.java | 13 +- .../apache/orc/impl/writer/MapTreeWriter.java | 16 +- .../orc/impl/writer/StructTreeWriter.java | 14 +- .../apache/orc/impl/writer/TreeWriter.java | 11 +- .../orc/impl/writer/TreeWriterBase.java | 17 +- .../orc/impl/writer/UnionTreeWriter.java | 15 +- .../apache/orc/impl/writer/WriterContext.java | 9 + .../orc/impl/writer/WriterEncryptionKey.java | 105 ++++ .../impl/writer/WriterEncryptionVariant.java | 128 +++++ .../org/apache/orc/TestStringDictionary.java | 5 + .../org/apache/orc/TestVectorOrcFile.java | 2 +- 14 files changed, 772 insertions(+), 119 deletions(-) create mode 100644 java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionKey.java create mode 100644 java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionVariant.java diff --git a/java/core/src/java/org/apache/orc/PhysicalWriter.java b/java/core/src/java/org/apache/orc/PhysicalWriter.java index e25e81c046..872d5dcde7 100644 --- a/java/core/src/java/org/apache/orc/PhysicalWriter.java +++ b/java/core/src/java/org/apache/orc/PhysicalWriter.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -22,6 +22,8 @@ import java.nio.ByteBuffer; import org.apache.orc.impl.StreamName; +import org.apache.orc.impl.writer.StreamOptions; +import org.apache.orc.impl.writer.WriterEncryptionVariant; /** * This interface separates the physical layout of ORC files from the higher @@ -39,7 +41,6 @@ interface OutputReceiver { * Output the given buffer to the final destination * * @param buffer the buffer to output - * @throws IOException */ void output(ByteBuffer buffer) throws IOException; @@ -48,16 +49,15 @@ interface OutputReceiver { */ void suppress(); } + /** * Writes the header of the file, which consists of the magic "ORC" bytes. - * @throws IOException */ void writeHeader() throws IOException; /** * Create an OutputReceiver for the given name. * @param name the name of the stream - * @throws IOException */ OutputReceiver createDataStream(StreamName name) throws IOException; @@ -65,7 +65,6 @@ interface OutputReceiver { * Write an index in the given stream name. * @param name the name of the stream * @param index the bloom filter to write - * @param codec the compression codec to use */ void writeIndex(StreamName name, OrcProto.RowIndex.Builder index) throws IOException; @@ -74,7 +73,6 @@ void writeIndex(StreamName name, * Write a bloom filter index in the given stream name. * @param name the name of the stream * @param bloom the bloom filter to write - * @param codec the compression codec to use */ void writeBloomFilter(StreamName name, OrcProto.BloomFilterIndex.Builder bloom) throws IOException; @@ -89,6 +87,16 @@ void writeBloomFilter(StreamName name, void finalizeStripe(OrcProto.StripeFooter.Builder footer, OrcProto.StripeInformation.Builder dirEntry) throws IOException; + /** + * Write a stripe or file statistics to the file. 
+ * @param name the name of the stream + * @param statistics the statistics to write + * @throws IOException + */ + void writeStatistics(StreamName name, + OrcProto.ColumnStatistics.Builder statistics + ) throws IOException; + /** * Writes out the file metadata. * @param builder Metadata builder to finalize and write. @@ -122,19 +130,24 @@ void finalizeStripe(OrcProto.StripeFooter.Builder footer, * @param stripe Stripe data buffer. * @param dirEntry File metadata entry for the stripe, to be updated with * relevant data. - * @throws IOException */ void appendRawStripe(ByteBuffer stripe, OrcProto.StripeInformation.Builder dirEntry ) throws IOException; - /** Gets a compression codec used by this writer. */ - CompressionCodec getCompressionCodec(); - /** * Get the number of bytes for a file in a givem column. * @param column column from which to get file size + * @param variant the encryption variant to check * @return number of bytes for the given column */ - long getFileBytes(int column); + long getFileBytes(int column, WriterEncryptionVariant variant); + + /** + * Get the unencrypted stream options for this file. This class needs the + * stream options to write the indexes and footers. + * + * Additionally, the LLAP CacheWriter wants to disable the generic compression. + */ + StreamOptions getStreamOptions(); } diff --git a/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java b/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java index aee16b16db..4736a63052 100644 --- a/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java +++ b/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java @@ -22,19 +22,25 @@ import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.TreeMap; +import com.google.protobuf.ByteString; import com.google.protobuf.CodedOutputStream; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.orc.CompressionCodec; +import org.apache.orc.EncryptionVariant; import org.apache.orc.OrcFile; import org.apache.orc.OrcProto; import org.apache.orc.PhysicalWriter; +import org.apache.orc.TypeDescription; +import org.apache.orc.impl.writer.WriterEncryptionKey; +import org.apache.orc.impl.writer.WriterEncryptionVariant; import org.apache.orc.impl.writer.StreamOptions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,10 +52,12 @@ public class PhysicalFsWriter implements PhysicalWriter { private static final int HDFS_BUFFER_SIZE = 256 * 1024; private FSDataOutputStream rawWriter; + private final DirectStream rawStream; + // the compressed metadata information outStream - private OutStream writer; + private OutStream compressStream; // a protobuf outStream around streamFactory - private CodedOutputStream protobufWriter; + private CodedOutputStream codedCompressStream; private final Path path; private final HadoopShims shims; @@ -59,10 +67,7 @@ public class PhysicalFsWriter implements PhysicalWriter { private final OrcFile.CompressionStrategy compressionStrategy; private final boolean addBlockPadding; private final boolean writeVariableLengthBlocks; - - // the streams that make up the current stripe - private final Map streams = - new TreeMap<>(); + private final VariantTracker unencrypted; private long headerLength; private long stripeStart; @@ -70,11 +75,24 @@ public class PhysicalFsWriter implements PhysicalWriter { // natural blocks private 
long blockOffset; private int metadataLength; + private int stripeStatisticsLength = 0; private int footerLength; + private int stripeNumber = 0; + + private final Map variants = new TreeMap<>(); + + public PhysicalFsWriter(FileSystem fs, + Path path, + OrcFile.WriterOptions opts + ) throws IOException { + this(fs, path, opts, new WriterEncryptionVariant[0]); + } public PhysicalFsWriter(FileSystem fs, Path path, - OrcFile.WriterOptions opts) throws IOException { + OrcFile.WriterOptions opts, + WriterEncryptionVariant[] encryption + ) throws IOException { this.path = path; long defaultStripeSize = opts.getStripeSize(); this.addBlockPadding = opts.getBlockPadding(); @@ -98,16 +116,124 @@ public PhysicalFsWriter(FileSystem fs, rawWriter = fs.create(path, opts.getOverwrite(), HDFS_BUFFER_SIZE, fs.getDefaultReplication(path), blockSize); blockOffset = 0; - writer = new OutStream("metadata", compress, - new DirectStream(rawWriter)); - protobufWriter = CodedOutputStream.newInstance(writer); + unencrypted = new VariantTracker(opts.getSchema(), compress); writeVariableLengthBlocks = opts.getWriteVariableLengthBlocks(); shims = opts.getHadoopShims(); + rawStream = new DirectStream(rawWriter); + compressStream = new OutStream("stripe footer", compress, rawStream); + codedCompressStream = CodedOutputStream.newInstance(compressStream); + for(WriterEncryptionVariant variant: encryption) { + WriterEncryptionKey key = variant.getKeyDescription(); + StreamOptions encryptOptions = + new StreamOptions(unencrypted.options) + .withEncryption(key.getAlgorithm(), variant.getFileFooterKey()); + variants.put(variant, new VariantTracker(variant.getRoot(), encryptOptions)); + } } - @Override - public CompressionCodec getCompressionCodec() { - return compress.getCodec(); + /** + * Record the information about each column encryption variant. + * The unencrypted data and each encrypted column root are variants. + */ + protected static class VariantTracker { + // the streams that make up the current stripe + protected final Map streams = new TreeMap<>(); + private final int rootColumn; + private final int lastColumn; + protected final StreamOptions options; + // a list for each column covered by this variant + // the elements in the list correspond to each stripe in the file + protected final List[] stripeStats; + protected final List stripeStatsStreams = new ArrayList<>(); + protected final OrcProto.ColumnStatistics[] fileStats; + + VariantTracker(TypeDescription schema, StreamOptions options) { + rootColumn = schema.getId(); + lastColumn = schema.getMaximumId(); + this.options = options; + stripeStats = new List[schema.getMaximumId() - schema.getId() + 1]; + for(int i=0; i < stripeStats.length; ++i) { + stripeStats[i] = new ArrayList<>(); + } + fileStats = new OrcProto.ColumnStatistics[stripeStats.length]; + } + + public BufferedStream createStream(StreamName name) { + BufferedStream result = new BufferedStream(); + streams.put(name, result); + return result; + } + + /** + * Place the streams in the appropriate area while updating the sizes + * with the number of bytes in the area. 
+     * @param area the area to write
+     * @param sizes the sizes of the areas
+     * @return the list of stream descriptions to add
+     */
+    public List<OrcProto.Stream> placeStreams(StreamName.Area area,
+                                              SizeCounters sizes) {
+      List<OrcProto.Stream> result = new ArrayList<>(streams.size());
+      for(Map.Entry<StreamName, BufferedStream> stream: streams.entrySet()) {
+        StreamName name = stream.getKey();
+        BufferedStream bytes = stream.getValue();
+        if (name.getArea() == area && !bytes.isSuppressed) {
+          OrcProto.Stream.Builder builder = OrcProto.Stream.newBuilder();
+          long size = bytes.getOutputSize();
+          if (area == StreamName.Area.INDEX) {
+            sizes.index += size;
+          } else {
+            sizes.data += size;
+          }
+          builder.setColumn(name.getColumn())
+              .setKind(name.getKind())
+              .setLength(size);
+          result.add(builder.build());
+        }
+      }
+      return result;
+    }
+
+    /**
+     * Write the streams in the appropriate area.
+     * @param area the area to write
+     * @param raw the raw stream to write to
+     */
+    public void writeStreams(StreamName.Area area,
+                             FSDataOutputStream raw) throws IOException {
+      for(Map.Entry<StreamName, BufferedStream> stream: streams.entrySet()) {
+        if (stream.getKey().getArea() == area) {
+          stream.getValue().spillToDiskAndClear(raw);
+        }
+      }
+    }
+
+    /**
+     * Compute the size of the given column on disk for this stripe.
+     * It excludes the index streams.
+     * @param column a column id
+     * @return the total number of bytes
+     */
+    public long getFileBytes(int column) {
+      long result = 0;
+      if (column >= rootColumn && column <= lastColumn) {
+        for(Map.Entry<StreamName, BufferedStream> entry: streams.entrySet()) {
+          StreamName name = entry.getKey();
+          if (name.getColumn() == column &&
+              name.getArea() != StreamName.Area.INDEX) {
+            result += entry.getValue().getOutputSize();
+          }
+        }
+      }
+      return result;
+    }
+  }
+
+  VariantTracker getVariant(EncryptionVariant column) {
+    if (column == null) {
+      return unencrypted;
+    }
+    return variants.get(column);
   }
 
   /**
@@ -120,20 +246,13 @@ public CompressionCodec getCompressionCodec() {
    * @return number of bytes for the given column
    */
   @Override
-  public long getFileBytes(final int column) {
-    long size = 0;
-    for (final Map.Entry<StreamName, BufferedStream> pair: streams.entrySet()) {
-      final BufferedStream receiver = pair.getValue();
-      if(!receiver.isSuppressed) {
-
-        final StreamName name = pair.getKey();
-        if(name.getColumn() == column && name.getArea() != StreamName.Area.INDEX ) {
-          size += receiver.getOutputSize();
-        }
-      }
+  public long getFileBytes(int column, WriterEncryptionVariant variant) {
+    return getVariant(variant).getFileBytes(column);
+  }
 
-    }
-    return size;
+  @Override
+  public StreamOptions getStreamOptions() {
+    return unencrypted.options;
   }
 
   private static final byte[] ZEROS = new byte[64*1024];
@@ -198,36 +317,139 @@ public void suppress() {
   }
 
   private void writeStripeFooter(OrcProto.StripeFooter footer,
-                                 long dataSize,
-                                 long indexSize,
+                                 SizeCounters sizes,
                                  OrcProto.StripeInformation.Builder dirEntry
                                  ) throws IOException {
-    footer.writeTo(protobufWriter);
-    protobufWriter.flush();
-    writer.flush();
+    footer.writeTo(codedCompressStream);
+    codedCompressStream.flush();
+    compressStream.flush();
     dirEntry.setOffset(stripeStart);
-    dirEntry.setFooterLength(rawWriter.getPos() - stripeStart - dataSize - indexSize);
+    dirEntry.setFooterLength(rawWriter.getPos() - stripeStart - sizes.total());
+  }
+
+  /**
+   * Write the saved encrypted stripe statistics for a variant out to the
+   * file. The streams that are written are added to the tracker.stripeStatsStreams.
+ * @param output the file we are writing to + * @param stripeNumber the number of stripes in the file + * @param tracker the variant to write out + */ + static void writeEncryptedStripeStatistics(DirectStream output, + int stripeNumber, + VariantTracker tracker + ) throws IOException { + StreamOptions options = new StreamOptions(tracker.options); + tracker.stripeStatsStreams.clear(); + for(int col = tracker.rootColumn; + col < tracker.rootColumn + tracker.stripeStats.length; ++col) { + options.modifyIv(CryptoUtils.modifyIvForStream(col, + OrcProto.Stream.Kind.STRIPE_STATISTICS, stripeNumber)); + OutStream stream = new OutStream("stripe stats for " + col, + options, output); + OrcProto.ColumnarStripeStatistics stats = + OrcProto.ColumnarStripeStatistics.newBuilder() + .addAllColStats(tracker.stripeStats[col - tracker.rootColumn]) + .build(); + long start = output.output.getPos(); + stats.writeTo(stream); + stream.flush(); + OrcProto.Stream description = OrcProto.Stream.newBuilder() + .setColumn(col) + .setKind(OrcProto.Stream.Kind.STRIPE_STATISTICS) + .setLength(output.output.getPos() - start) + .build(); + tracker.stripeStatsStreams.add(description); + } + } + + /** + * Merge the saved unencrypted stripe statistics into the Metadata section + * of the footer. + * @param builder the Metadata section of the file + * @param stripeCount the number of stripes in the file + * @param stats the stripe statistics + */ + static void setUnencryptedStripeStatistics(OrcProto.Metadata.Builder builder, + int stripeCount, + List[] stats) { + // Make the unencrypted stripe stats into lists of StripeStatistics. + builder.clearStripeStats(); + for(int s=0; s < stripeCount; ++s) { + OrcProto.StripeStatistics.Builder stripeStats = + OrcProto.StripeStatistics.newBuilder(); + for(List col: stats) { + stripeStats.addColStats(col.get(s)); + } + builder.addStripeStats(stripeStats.build()); + } + } + + static void setEncryptionStatistics(OrcProto.Encryption.Builder encryption, + int stripeNumber, + Collection variants + ) throws IOException { + int v = 0; + for(VariantTracker variant: variants) { + OrcProto.EncryptionVariant.Builder variantBuilder = + encryption.getVariantsBuilder(v++); + + // Add the stripe statistics streams to the variant description. + variantBuilder.clearStripeStatistics(); + variantBuilder.addAllStripeStatistics(variant.stripeStatsStreams); + + // Serialize and encrypt the file statistics. 
+ OrcProto.FileStatistics.Builder file = OrcProto.FileStatistics.newBuilder(); + for(OrcProto.ColumnStatistics col: variant.fileStats) { + file.addColumn(col); + } + StreamOptions options = new StreamOptions(variant.options); + options.modifyIv(CryptoUtils.modifyIvForStream(variant.rootColumn, + OrcProto.Stream.Kind.FILE_STATISTICS, stripeNumber)); + BufferedStream buffer = new BufferedStream(); + OutStream stream = new OutStream("stats for " + variant, options, buffer); + file.build().writeTo(stream); + stream.flush(); + variantBuilder.setFileStatistics(buffer.getBytes()); + } } @Override public void writeFileMetadata(OrcProto.Metadata.Builder builder) throws IOException { - long startPosn = rawWriter.getPos(); - OrcProto.Metadata metadata = builder.build(); - metadata.writeTo(protobufWriter); - protobufWriter.flush(); - writer.flush(); - this.metadataLength = (int) (rawWriter.getPos() - startPosn); + long stripeStatisticsStart = rawWriter.getPos(); + for(VariantTracker variant: variants.values()) { + writeEncryptedStripeStatistics(rawStream, stripeNumber, variant); + } + setUnencryptedStripeStatistics(builder, stripeNumber, unencrypted.stripeStats); + long metadataStart = rawWriter.getPos(); + builder.build().writeTo(codedCompressStream); + codedCompressStream.flush(); + compressStream.flush(); + this.stripeStatisticsLength = (int) (metadataStart - stripeStatisticsStart); + this.metadataLength = (int) (rawWriter.getPos() - metadataStart); + } + + static void addUnencryptedStatistics(OrcProto.Footer.Builder builder, + OrcProto.ColumnStatistics[] stats) { + for(OrcProto.ColumnStatistics stat: stats) { + builder.addStatistics(stat); + } } @Override public void writeFileFooter(OrcProto.Footer.Builder builder) throws IOException { - long bodyLength = rawWriter.getPos() - metadataLength; + if (variants.size() > 0) { + OrcProto.Encryption.Builder encryption = builder.getEncryptionBuilder(); + setEncryptionStatistics(encryption, stripeNumber, variants.values()); + builder.setStripeStatisticsLength(stripeStatisticsLength); + } + addUnencryptedStatistics(builder, unencrypted.fileStats); + long bodyLength = rawWriter.getPos() - metadataLength - stripeStatisticsLength; builder.setContentLength(bodyLength); builder.setHeaderLength(headerLength); long startPosn = rawWriter.getPos(); OrcProto.Footer footer = builder.build(); - footer.writeTo(protobufWriter); - protobufWriter.flush(); - writer.flush(); + footer.writeTo(codedCompressStream); + codedCompressStream.flush(); + compressStream.flush(); this.footerLength = (int) (rawWriter.getPos() - startPosn); } @@ -300,7 +522,7 @@ public void appendRawStripe(ByteBuffer buffer, * data as buffers fill up and stores them in the output list. When the * stripe is being written, the whole stream is written to the file. */ - private static final class BufferedStream implements OutputReceiver { + static final class BufferedStream implements OutputReceiver { private boolean isSuppressed = false; private final List output = new ArrayList<>(); @@ -319,17 +541,56 @@ public void suppress() { /** * Write any saved buffers to the OutputStream if needed, and clears all the * buffers. 
+ * @return true if the stream was written */ - void spillToDiskAndClear(FSDataOutputStream raw - ) throws IOException { + boolean spillToDiskAndClear(FSDataOutputStream raw) throws IOException { if (!isSuppressed) { for (ByteBuffer buffer: output) { raw.write(buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining()); } output.clear(); + return true; } isSuppressed = false; + return false; + } + + /** + * Get the buffer as a protobuf ByteString and clears the BufferedStream. + * @return the bytes + */ + ByteString getBytes() { + int len = output.size(); + if (len == 0) { + return ByteString.EMPTY; + } else { + ByteString result = ByteString.copyFrom(output.get(0)); + for (int i=1; i < output.size(); ++i) { + result = result.concat(ByteString.copyFrom(output.get(i))); + } + output.clear(); + return result; + } + } + + /** + * Get the stream as a ByteBuffer and clear it. + * @return a single ByteBuffer with the contents of the stream + */ + ByteBuffer getByteBuffer() { + ByteBuffer result; + if (output.size() == 1) { + result = output.get(0); + } else { + result = ByteBuffer.allocate((int) getOutputSize()); + for (ByteBuffer buffer : output) { + result.put(buffer); + } + output.clear(); + result.flip(); + } + return result; } /** @@ -347,38 +608,86 @@ public long getOutputSize() { } } + static class SizeCounters { + long index = 0; + long data = 0; + + long total() { + return index + data; + } + } + + void buildStreamList(OrcProto.StripeFooter.Builder footerBuilder, + SizeCounters sizes + ) throws IOException { + footerBuilder.addAllStreams( + unencrypted.placeStreams(StreamName.Area.INDEX, sizes)); + final long unencryptedIndexSize = sizes.index; + int v = 0; + for (VariantTracker variant: variants.values()) { + OrcProto.StripeEncryptionVariant.Builder builder = + footerBuilder.getEncryptionBuilder(v++); + builder.addAllStreams( + variant.placeStreams(StreamName.Area.INDEX, sizes)); + } + if (sizes.index != unencryptedIndexSize) { + // add a placeholder that covers the hole where the encrypted indexes are + footerBuilder.addStreams(OrcProto.Stream.newBuilder() + .setKind(OrcProto.Stream.Kind.ENCRYPTED_INDEX) + .setLength(sizes.index - unencryptedIndexSize)); + } + footerBuilder.addAllStreams( + unencrypted.placeStreams(StreamName.Area.DATA, sizes)); + final long unencryptedDataSize = sizes.data; + v = 0; + for (VariantTracker variant: variants.values()) { + OrcProto.StripeEncryptionVariant.Builder builder = + footerBuilder.getEncryptionBuilder(v++); + builder.addAllStreams( + variant.placeStreams(StreamName.Area.DATA, sizes)); + } + if (sizes.data != unencryptedDataSize) { + // add a placeholder that covers the hole where the encrypted indexes are + footerBuilder.addStreams(OrcProto.Stream.newBuilder() + .setKind(OrcProto.Stream.Kind.ENCRYPTED_DATA) + .setLength(sizes.data - unencryptedDataSize)); + } + } + @Override public void finalizeStripe(OrcProto.StripeFooter.Builder footerBuilder, OrcProto.StripeInformation.Builder dirEntry ) throws IOException { - long indexSize = 0; - long dataSize = 0; - for (Map.Entry pair: streams.entrySet()) { - BufferedStream receiver = pair.getValue(); - if (!receiver.isSuppressed) { - long streamSize = receiver.getOutputSize(); - StreamName name = pair.getKey(); - footerBuilder.addStreams(OrcProto.Stream.newBuilder().setColumn(name.getColumn()) - .setKind(name.getKind()).setLength(streamSize)); - if (StreamName.Area.INDEX == name.getArea()) { - indexSize += streamSize; - } else { - dataSize += streamSize; - } - } - } - 
dirEntry.setIndexLength(indexSize).setDataLength(dataSize);
+    SizeCounters sizes = new SizeCounters();
+    buildStreamList(footerBuilder, sizes);
 
     OrcProto.StripeFooter footer = footerBuilder.build();
+
     // Do we need to pad the file so the stripe doesn't straddle a block boundary?
-    padStripe(indexSize + dataSize + footer.getSerializedSize());
+    padStripe(sizes.total() + footer.getSerializedSize());
+
+    // write the unencrypted index streams
+    unencrypted.writeStreams(StreamName.Area.INDEX, rawWriter);
+    // write the encrypted index streams
+    for (VariantTracker variant: variants.values()) {
+      variant.writeStreams(StreamName.Area.INDEX, rawWriter);
+    }
 
-    // write out the data streams
-    for (Map.Entry<StreamName, BufferedStream> pair : streams.entrySet()) {
-      pair.getValue().spillToDiskAndClear(rawWriter);
+    // write the unencrypted data streams
+    unencrypted.writeStreams(StreamName.Area.DATA, rawWriter);
+    // write the encrypted data streams
+    for (VariantTracker variant: variants.values()) {
+      variant.writeStreams(StreamName.Area.DATA, rawWriter);
     }
+
     // Write out the footer.
-    writeStripeFooter(footer, dataSize, indexSize, dirEntry);
+    writeStripeFooter(footer, sizes, dirEntry);
+
+    // fill in the data sizes
+    dirEntry.setDataLength(sizes.data);
+    dirEntry.setIndexLength(sizes.index);
+
+    stripeNumber += 1;
   }
 
   @Override
@@ -389,10 +698,11 @@ public void writeHeader() throws IOException {
 
   @Override
   public BufferedStream createDataStream(StreamName name) {
-    BufferedStream result = streams.get(name);
+    VariantTracker variant = getVariant(name.getEncryption());
+    BufferedStream result = variant.streams.get(name);
     if (result == null) {
       result = new BufferedStream();
-      streams.put(name, result);
+      variant.streams.put(name, result);
     }
     return result;
   }
@@ -402,11 +712,26 @@ private StreamOptions getOptions(OrcProto.Stream.Kind kind) {
                                              kind);
   }
 
+  protected OutputStream createIndexStream(StreamName name) {
+    BufferedStream buffer = createDataStream(name);
+    VariantTracker tracker = getVariant(name.getEncryption());
+    StreamOptions options =
+        SerializationUtils.getCustomizedCodec(tracker.options,
+            compressionStrategy, name.getKind());
+    if (options.isEncrypted()) {
+      if (options == tracker.options) {
+        options = new StreamOptions(options);
+      }
+      options.modifyIv(CryptoUtils.modifyIvForStream(name, stripeNumber));
+    }
+    return new OutStream(name.toString(), options, buffer);
+  }
+
   @Override
   public void writeIndex(StreamName name,
-                         OrcProto.RowIndex.Builder index) throws IOException {
-    OutputStream stream = new OutStream(path.toString(),
-        getOptions(name.getKind()), createDataStream(name));
+                         OrcProto.RowIndex.Builder index
+                         ) throws IOException {
+    OutputStream stream = createIndexStream(name);
     index.build().writeTo(stream);
     stream.flush();
   }
@@ -415,12 +740,25 @@ public void writeIndex(StreamName name,
   public void writeBloomFilter(StreamName name,
                                OrcProto.BloomFilterIndex.Builder bloom
                                ) throws IOException {
-    OutputStream stream = new OutStream(path.toString(),
-        getOptions(name.getKind()), createDataStream(name));
+    OutputStream stream = createIndexStream(name);
     bloom.build().writeTo(stream);
     stream.flush();
   }
 
+  @Override
+  public void writeStatistics(StreamName name,
+                              OrcProto.ColumnStatistics.Builder statistics
+                              ) {
+    VariantTracker tracker = getVariant(name.getEncryption());
+    if (name.getKind() == OrcProto.Stream.Kind.FILE_STATISTICS) {
+      tracker.fileStats[name.getColumn() - tracker.rootColumn] =
+          statistics.build();
+    } else {
+      tracker.stripeStats[name.getColumn() - tracker.rootColumn]
.add(statistics.build()); + } + } + @Override public String toString() { return path.toString(); diff --git a/java/core/src/java/org/apache/orc/impl/WriterImpl.java b/java/core/src/java/org/apache/orc/impl/WriterImpl.java index 639f963398..4add5d3298 100644 --- a/java/core/src/java/org/apache/orc/impl/WriterImpl.java +++ b/java/core/src/java/org/apache/orc/impl/WriterImpl.java @@ -161,7 +161,7 @@ public Writer getWriter() { ? new PhysicalFsWriter(fs, path, opts) : opts.getPhysicalWriter(); physicalWriter.writeHeader(); - CompressionCodec codec = physicalWriter.getCompressionCodec(); + CompressionCodec codec = physicalWriter.getStreamOptions().getCodec(); if (codec != null) { compress.withCodec(codec, codec.getDefaultOptions()); } @@ -312,7 +312,7 @@ public boolean buildIndex() { * @return are the streams compressed */ public boolean isCompressed() { - return physicalWriter.getCompressionCodec() != null; + return physicalWriter.getStreamOptions().getCodec() != null; } /** @@ -379,6 +379,13 @@ public void writeBloomFilter(StreamName name, physicalWriter.writeBloomFilter(name, bloom); } + @Override + public void writeStatistics(StreamName name, + OrcProto.ColumnStatistics.Builder stats + ) throws IOException { + physicalWriter.writeStatistics(name, stats); + } + public boolean getUseUTCTimestamp() { return useUTCTimeZone; } @@ -453,9 +460,8 @@ private OrcProto.CompressionKind writeCompressionKind(CompressionKind kind) { } } - private void writeFileStatistics(OrcProto.Footer.Builder builder, - TreeWriter writer) throws IOException { - writer.writeFileStatistics(builder); + private void writeFileStatistics(TreeWriter writer) throws IOException { + writer.writeFileStatistics(); } private void writeMetadata() throws IOException { @@ -492,7 +498,7 @@ private long writeFooter() throws IOException { builder.addStripes(stripe); } // add the column statistics - writeFileStatistics(builder, treeWriter); + writeFileStatistics(treeWriter); // add all of the user metadata for(Map.Entry entry: userMetadata.entrySet()) { builder.addMetadata(OrcProto.UserMetadataItem.newBuilder() @@ -633,18 +639,17 @@ public void appendUserMetadata(List userMetadata) { } @Override - public ColumnStatistics[] getStatistics() - throws IOException { - // Generate the stats - OrcProto.Footer.Builder builder = OrcProto.Footer.newBuilder(); - - // add the column statistics - writeFileStatistics(builder, treeWriter); - return ReaderImpl.deserializeStats(schema, builder.getStatisticsList()); + public ColumnStatistics[] getStatistics() { + // get the column statistics + final ColumnStatistics[] result = + new ColumnStatistics[schema.getMaximumId() + 1]; + // Get the file statistics, preferring the encrypted one. 
+ treeWriter.getCurrentStatistics(result); + return result; } public CompressionCodec getCompressionCodec() { - return physicalWriter.getCompressionCodec(); + return physicalWriter.getStreamOptions().getCodec(); } private static boolean hasTimestamp(TypeDescription schema) { diff --git a/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java index c6068cdec0..a433e83cf7 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java @@ -20,6 +20,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; import org.apache.orc.impl.IntegerWriter; @@ -153,9 +154,9 @@ public long getRawDataSize() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - super.writeFileStatistics(footer); - childWriter.writeFileStatistics(footer); + public void writeFileStatistics() throws IOException { + super.writeFileStatistics(); + childWriter.writeFileStatistics(); } @Override @@ -164,4 +165,10 @@ public void flushStreams() throws IOException { lengths.flush(); childWriter.flushStreams(); } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + super.getCurrentStatistics(output); + childWriter.getCurrentStatistics(output); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java index 91e56578ae..5c09b365af 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java @@ -19,6 +19,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; import org.apache.orc.impl.IntegerWriter; @@ -164,10 +165,10 @@ public long getRawDataSize() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - super.writeFileStatistics(footer); - keyWriter.writeFileStatistics(footer); - valueWriter.writeFileStatistics(footer); + public void writeFileStatistics() throws IOException { + super.writeFileStatistics(); + keyWriter.writeFileStatistics(); + valueWriter.writeFileStatistics(); } @Override @@ -177,4 +178,11 @@ public void flushStreams() throws IOException { keyWriter.flushStreams(); valueWriter.flushStreams(); } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + super.getCurrentStatistics(output); + keyWriter.getCurrentStatistics(output); + valueWriter.getCurrentStatistics(output); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java index ee0b0c041a..d66f8468ec 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java @@ -21,6 +21,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import 
org.apache.orc.TypeDescription; @@ -147,10 +148,10 @@ public long getRawDataSize() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - super.writeFileStatistics(footer); + public void writeFileStatistics() throws IOException { + super.writeFileStatistics(); for (TreeWriter child : childrenWriters) { - child.writeFileStatistics(footer); + child.writeFileStatistics(); } } @@ -160,6 +161,13 @@ public void flushStreams() throws IOException { for (TreeWriter child : childrenWriters) { child.flushStreams(); } + } + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + super.getCurrentStatistics(output); + for (TreeWriter child: childrenWriters) { + child.getCurrentStatistics(output); + } } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java index bfa403eeff..662c57db48 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java @@ -20,6 +20,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcFile; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; @@ -98,9 +99,15 @@ void writeStripe(OrcProto.StripeFooter.Builder stripeFooter, /** * Add the file statistics to the file footer. - * @param footer the file footer builder */ - void writeFileStatistics(OrcProto.Footer.Builder footer); + void writeFileStatistics() throws IOException; + + /** + * Get the current file statistics for each column. If a column is encrypted, + * the encrypted variant statistics are used. + * @param output an array that is filled in with the results + */ + void getCurrentStatistics(ColumnStatistics[] output); class Factory { public static TreeWriter create(TypeDescription schema, diff --git a/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java b/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java index 7934b21755..4c36d09956 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java +++ b/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java @@ -23,6 +23,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcFile; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; @@ -256,12 +257,15 @@ public void writeStripe(OrcProto.StripeFooter.Builder builder, } /* Update byte count */ - final long byteCount = streamFactory.getPhysicalWriter().getFileBytes(id); + final long byteCount = streamFactory.getPhysicalWriter().getFileBytes(id, null); stripeColStatistics.updateByteCount(byteCount); // merge stripe-level column statistics to file statistics and write it to // stripe statistics fileStatistics.merge(stripeColStatistics); + streamFactory.writeStatistics( + new StreamName(id, OrcProto.Stream.Kind.STRIPE_STATISTICS, null), + stripeColStatistics.serialize()); stats.addColStats(stripeColStatistics.serialize()); stripeColStatistics.reset(); @@ -369,8 +373,10 @@ public long estimateMemory() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - footer.addStatistics(fileStatistics.serialize()); + public void writeFileStatistics() throws IOException { + streamFactory.writeStatistics(new StreamName(id, + 
OrcProto.Stream.Kind.FILE_STATISTICS, null), + fileStatistics.serialize()); } static class RowIndexPositionRecorder implements PositionRecorder { @@ -385,4 +391,9 @@ public void addPosition(long position) { builder.addPositions(position); } } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + output[id] = fileStatistics; + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java index 54a9a3a6dc..4456b62f55 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java @@ -20,6 +20,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.UnionColumnVector; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; import org.apache.orc.impl.PositionRecorder; @@ -165,10 +166,10 @@ public long getRawDataSize() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - super.writeFileStatistics(footer); + public void writeFileStatistics() throws IOException { + super.writeFileStatistics(); for (TreeWriter child : childrenWriters) { - child.writeFileStatistics(footer); + child.writeFileStatistics(); } } @@ -180,4 +181,12 @@ public void flushStreams() throws IOException { child.flushStreams(); } } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + super.getCurrentStatistics(output); + for (TreeWriter child: childrenWriters) { + child.getCurrentStatistics(output); + } + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java b/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java index 9ef3ddaf2e..cfd2ec7ab0 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java +++ b/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java @@ -102,6 +102,15 @@ void writeBloomFilter(StreamName name, OrcProto.BloomFilterIndex.Builder bloom ) throws IOException; + /** + * Set the column statistics for the stripe or file. + * @param name the name of the statistics stream + * @param stats the statistics for this column in this stripe + */ + void writeStatistics(StreamName name, + OrcProto.ColumnStatistics.Builder stats + ) throws IOException; + boolean getUseUTCTimestamp(); double getDictionaryKeySizeThreshold(int column); diff --git a/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionKey.java b/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionKey.java new file mode 100644 index 0000000000..40606aa98c --- /dev/null +++ b/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionKey.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.orc.impl.writer; + +import org.apache.orc.EncryptionAlgorithm; +import org.apache.orc.EncryptionKey; +import org.apache.orc.EncryptionVariant; +import org.apache.orc.impl.HadoopShims; +import org.jetbrains.annotations.NotNull; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class WriterEncryptionKey implements EncryptionKey { + private final HadoopShims.KeyMetadata metadata; + private final List roots = new ArrayList<>(); + private int id; + + public WriterEncryptionKey(HadoopShims.KeyMetadata key) { + this.metadata = key; + } + + public void addRoot(WriterEncryptionVariant root) { + roots.add(root); + } + + public HadoopShims.KeyMetadata getMetadata() { + return metadata; + } + + public void setId(int id) { + this.id = id; + } + + @Override + public String getKeyName() { + return metadata.getKeyName(); + } + + @Override + public int getKeyVersion() { + return metadata.getVersion(); + } + + public EncryptionAlgorithm getAlgorithm() { + return metadata.getAlgorithm(); + } + + @Override + public WriterEncryptionVariant[] getEncryptionRoots() { + return roots.toArray(new WriterEncryptionVariant[roots.size()]); + } + + public int getId() { + return id; + } + + public void sortRoots() { + Collections.sort(roots); + } + + @Override + public int hashCode() { + return id; + } + + @Override + public boolean equals(Object other) { + if (other == null || getClass() != other.getClass()) { + return false; + } + return compareTo((EncryptionKey) other) == 0; + } + + @Override + public int compareTo(@NotNull EncryptionKey other) { + int result = getKeyName().compareTo(other.getKeyName()); + if (result == 0) { + result = Integer.compare(getKeyVersion(), other.getKeyVersion()); + } + return result; + } + + @Override + public String toString() { + return metadata.toString(); + } +} diff --git a/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionVariant.java b/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionVariant.java new file mode 100644 index 0000000000..ed026f7a90 --- /dev/null +++ b/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionVariant.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.orc.impl.writer; + +import org.apache.orc.EncryptionVariant; +import org.apache.orc.OrcProto; +import org.apache.orc.TypeDescription; +import org.apache.orc.impl.LocalKey; +import org.jetbrains.annotations.NotNull; + +import java.security.Key; +import java.util.ArrayList; +import java.util.List; + +public class WriterEncryptionVariant implements EncryptionVariant { + private int id; + private final WriterEncryptionKey key; + private final TypeDescription root; + private final LocalKey material; + private final OrcProto.FileStatistics.Builder fileStats = + OrcProto.FileStatistics.newBuilder(); + private final List encodings = new ArrayList<>(); + + public WriterEncryptionVariant(WriterEncryptionKey key, + TypeDescription root, + LocalKey columnKey) { + this.key = key; + this.root = root; + this.material = columnKey; + } + + @Override + public WriterEncryptionKey getKeyDescription() { + return key; + } + + public TypeDescription getRoot() { + return root; + } + + public void setId(int id) { + this.id = id; + } + + @Override + public int getVariantId() { + return id; + } + + @Override + public Key getFileFooterKey() { + return material.getDecryptedKey(); + } + + @Override + public Key getStripeKey(long stripe) { + return material.getDecryptedKey(); + } + + public LocalKey getMaterial() { + return material; + } + + public void clearFileStatistics() { + fileStats.clearColumn(); + } + + public void addFileStatistics(OrcProto.ColumnStatistics column) { + fileStats.addColumn(column); + } + + public OrcProto.FileStatistics getFileStatistics() { + return fileStats.build(); + } + + public void addEncoding(OrcProto.ColumnEncoding encoding) { + encodings.add(encoding); + } + + public List getEncodings() { + return encodings; + } + + public void clearEncodings() { + encodings.clear(); + } + + @Override + public int hashCode() { + return key.hashCode() << 16 ^ root.getId(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other == null || other.getClass() != getClass()) { + return false; + } + return compareTo((WriterEncryptionVariant) other) == 0; + } + + @Override + public int compareTo(@NotNull EncryptionVariant other) { + int result = key.compareTo(other.getKeyDescription()); + if (result == 0) { + result = Integer.compare(root.getId(), other.getRoot().getId()); + } + return result; + } +} + diff --git a/java/core/src/test/org/apache/orc/TestStringDictionary.java b/java/core/src/test/org/apache/orc/TestStringDictionary.java index 27965fe375..2ab15064c0 100644 --- a/java/core/src/test/org/apache/orc/TestStringDictionary.java +++ b/java/core/src/test/org/apache/orc/TestStringDictionary.java @@ -239,6 +239,11 @@ public void writeBloomFilter(StreamName name, } + @Override + public void writeStatistics(StreamName name, OrcProto.ColumnStatistics.Builder stats) throws IOException { + + } + @Override public boolean getUseUTCTimestamp() { return true; diff --git a/java/core/src/test/org/apache/orc/TestVectorOrcFile.java b/java/core/src/test/org/apache/orc/TestVectorOrcFile.java index 658c1cea71..95f6458926 100644 --- a/java/core/src/test/org/apache/orc/TestVectorOrcFile.java +++ b/java/core/src/test/org/apache/orc/TestVectorOrcFile.java @@ -2099,7 +2099,7 @@ private CompressionCodec writeBatchesAndGetCodec(int count, ) throws IOException { fs.delete(testFilePath, false); PhysicalWriter physical = new PhysicalFsWriter(fs, testFilePath, opts); - CompressionCodec codec = physical.getCompressionCodec(); + CompressionCodec codec 
= physical.getStreamOptions().getCodec();
     Writer writer = OrcFile.createWriter(testFilePath,
                                          opts.physicalWriter(physical));
     writeRandomIntBytesBatches(writer, batch, count, size);

From 73ba385f9534e1d919402ae0ad4ce229b33dc777 Mon Sep 17 00:00:00 2001
From: Owen O'Malley
Date: Mon, 17 Jun 2019 19:31:58 +0200
Subject: [PATCH 2/4] ORC-486: Add column encryption to the TreeWriters.

Fixes #403

Signed-off-by: Owen O'Malley
---
 .../org/apache/orc/impl/BitFieldWriter.java   |   5 +
 .../org/apache/orc/impl/IntegerWriter.java    |   3 +
 .../orc/impl/PositionedOutputStream.java      |   8 +
 .../apache/orc/impl/RunLengthByteWriter.java  |   5 +
 .../orc/impl/RunLengthIntegerWriter.java      |   6 +
 .../orc/impl/RunLengthIntegerWriterV2.java    |   6 +
 .../java/org/apache/orc/impl/WriterImpl.java  | 336 ++++++++++++++----
 .../orc/impl/writer/BinaryTreeWriter.java     |  34 +-
 .../orc/impl/writer/BooleanTreeWriter.java    |  28 +-
 .../orc/impl/writer/ByteTreeWriter.java       |  27 +-
 .../orc/impl/writer/CharTreeWriter.java       |   9 +-
 .../orc/impl/writer/DateTreeWriter.java       |  26 +-
 .../orc/impl/writer/Decimal64TreeWriter.java  |  27 +-
 .../orc/impl/writer/DecimalTreeWriter.java    |  34 +-
 .../orc/impl/writer/DoubleTreeWriter.java     |  27 +-
 .../orc/impl/writer/EncryptionTreeWriter.java | 151 ++++++++
 .../orc/impl/writer/FloatTreeWriter.java      |  27 +-
 .../orc/impl/writer/IntegerTreeWriter.java    |  27 +-
 .../orc/impl/writer/ListTreeWriter.java       |  33 +-
 .../apache/orc/impl/writer/MapTreeWriter.java |  38 +-
 .../orc/impl/writer/StringBaseTreeWriter.java |  45 ++-
 .../orc/impl/writer/StringTreeWriter.java     |   9 +-
 .../orc/impl/writer/StructTreeWriter.java     |  29 +-
 .../orc/impl/writer/TimestampTreeWriter.java  |  36 +-
 .../apache/orc/impl/writer/TreeWriter.java    | 148 ++++----
 .../orc/impl/writer/TreeWriterBase.java       |  84 +++--
 .../orc/impl/writer/UnionTreeWriter.java      |  36 +-
 .../orc/impl/writer/VarcharTreeWriter.java    |   9 +-
 .../apache/orc/impl/writer/WriterContext.java | 198 ++++++-----
 .../org/apache/orc/TestStringDictionary.java  |  24 +-
 30 files changed, 1011 insertions(+), 464 deletions(-)
 create mode 100644 java/core/src/java/org/apache/orc/impl/writer/EncryptionTreeWriter.java

diff --git a/java/core/src/java/org/apache/orc/impl/BitFieldWriter.java b/java/core/src/java/org/apache/orc/impl/BitFieldWriter.java
index 1b1cd10474..044271cad5 100644
--- a/java/core/src/java/org/apache/orc/impl/BitFieldWriter.java
+++ b/java/core/src/java/org/apache/orc/impl/BitFieldWriter.java
@@ -18,6 +18,7 @@
 package org.apache.orc.impl;
 
 import java.io.IOException;
+import java.util.function.Consumer;
 
 public class BitFieldWriter {
   private RunLengthByteWriter output;
@@ -70,4 +71,8 @@ public void getPosition(PositionRecorder recorder) throws IOException {
   public long estimateMemory() {
     return output.estimateMemory();
   }
+
+  public void changeIv(Consumer<byte[]> modifier) {
+    output.changeIv(modifier);
+  }
 }
diff --git a/java/core/src/java/org/apache/orc/impl/IntegerWriter.java b/java/core/src/java/org/apache/orc/impl/IntegerWriter.java
index 70b16d3e9f..19e843b09d 100644
--- a/java/core/src/java/org/apache/orc/impl/IntegerWriter.java
+++ b/java/core/src/java/org/apache/orc/impl/IntegerWriter.java
@@ -19,6 +19,7 @@
 package org.apache.orc.impl;
 
 import java.io.IOException;
+import java.util.function.Consumer;
 
 /**
  * Interface for writing integers.
@@ -50,4 +51,6 @@ public interface IntegerWriter { * @return number of bytes */ long estimateMemory(); + + void changeIv(Consumer modifier); } diff --git a/java/core/src/java/org/apache/orc/impl/PositionedOutputStream.java b/java/core/src/java/org/apache/orc/impl/PositionedOutputStream.java index d412939dba..fd6561fa66 100644 --- a/java/core/src/java/org/apache/orc/impl/PositionedOutputStream.java +++ b/java/core/src/java/org/apache/orc/impl/PositionedOutputStream.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.io.OutputStream; +import java.util.function.Consumer; public abstract class PositionedOutputStream extends OutputStream { @@ -36,4 +37,11 @@ public abstract void getPosition(PositionRecorder recorder * @return the number of bytes used by buffers. */ public abstract long getBufferSize(); + + /** + * Change the current Initialization Vector (IV) for the encryption. + * Has no effect if the stream is not encrypted. + * @param modifier a function to modify the IV in place + */ + public abstract void changeIv(Consumer modifier); } diff --git a/java/core/src/java/org/apache/orc/impl/RunLengthByteWriter.java b/java/core/src/java/org/apache/orc/impl/RunLengthByteWriter.java index c2f1fa74da..bfa1d7a048 100644 --- a/java/core/src/java/org/apache/orc/impl/RunLengthByteWriter.java +++ b/java/core/src/java/org/apache/orc/impl/RunLengthByteWriter.java @@ -18,6 +18,7 @@ package org.apache.orc.impl; import java.io.IOException; +import java.util.function.Consumer; /** * A streamFactory that writes a sequence of bytes. A control byte is written before @@ -107,4 +108,8 @@ public void getPosition(PositionRecorder recorder) throws IOException { public long estimateMemory() { return output.getBufferSize() + MAX_LITERAL_SIZE; } + + public void changeIv(Consumer modifier) { + output.changeIv(modifier); + } } diff --git a/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriter.java b/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriter.java index 88b47e6a6d..710f493f8d 100644 --- a/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriter.java +++ b/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriter.java @@ -18,6 +18,7 @@ package org.apache.orc.impl; import java.io.IOException; +import java.util.function.Consumer; /** * A streamFactory that writes a sequence of integers. A control byte is written before @@ -144,4 +145,9 @@ public void getPosition(PositionRecorder recorder) throws IOException { public long estimateMemory() { return output.getBufferSize(); } + + @Override + public void changeIv(Consumer modifier) { + output.changeIv(modifier); + } } diff --git a/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriterV2.java b/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriterV2.java index e4c2a051e6..9107774ee8 100644 --- a/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriterV2.java +++ b/java/core/src/java/org/apache/orc/impl/RunLengthIntegerWriterV2.java @@ -18,6 +18,7 @@ package org.apache.orc.impl; import java.io.IOException; +import java.util.function.Consumer; /** *

A writer that performs light weight compression over sequence of integers. @@ -823,4 +824,9 @@ public void getPosition(PositionRecorder recorder) throws IOException { public long estimateMemory() { return output.getBufferSize(); } + + @Override + public void changeIv(Consumer modifier) { + output.changeIv(modifier); + } } diff --git a/java/core/src/java/org/apache/orc/impl/WriterImpl.java b/java/core/src/java/org/apache/orc/impl/WriterImpl.java index 4add5d3298..7f9cb63254 100644 --- a/java/core/src/java/org/apache/orc/impl/WriterImpl.java +++ b/java/core/src/java/org/apache/orc/impl/WriterImpl.java @@ -20,11 +20,15 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.security.SecureRandom; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.SortedMap; +import java.util.SortedSet; import java.util.TimeZone; import java.util.TreeMap; +import java.util.TreeSet; import io.airlift.compress.lz4.Lz4Compressor; import io.airlift.compress.lz4.Lz4Decompressor; @@ -33,6 +37,7 @@ import org.apache.orc.ColumnStatistics; import org.apache.orc.CompressionCodec; import org.apache.orc.CompressionKind; +import org.apache.orc.DataMask; import org.apache.orc.MemoryManager; import org.apache.orc.OrcConf; import org.apache.orc.OrcFile; @@ -41,7 +46,8 @@ import org.apache.orc.PhysicalWriter; import org.apache.orc.StripeInformation; import org.apache.orc.TypeDescription; -import org.apache.orc.Writer; +import org.apache.orc.impl.writer.WriterEncryptionKey; +import org.apache.orc.impl.writer.WriterEncryptionVariant; import org.apache.orc.impl.writer.StreamOptions; import org.apache.orc.impl.writer.TreeWriter; import org.apache.orc.impl.writer.WriterContext; @@ -76,16 +82,17 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback { private static final Logger LOG = LoggerFactory.getLogger(WriterImpl.class); + private static final HadoopShims SHIMS = HadoopShimsFactory.get(); private static final int MIN_ROW_INDEX_STRIDE = 1000; private final Path path; private long adjustedStripeSize; private final int rowIndexStride; - private final StreamOptions compress; private final TypeDescription schema; private final PhysicalWriter physicalWriter; private final OrcFile.WriterVersion writerVersion; + private final StreamOptions unencryptedOptions; private long rowCount = 0; private long rowsInStripe = 0; @@ -95,8 +102,6 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback { private int stripesAtLastFlush = -1; private final List stripes = new ArrayList<>(); - private final OrcProto.Metadata.Builder fileMetadata = - OrcProto.Metadata.newBuilder(); private final Map userMetadata = new TreeMap<>(); private final TreeWriter treeWriter; @@ -115,14 +120,55 @@ public class WriterImpl implements WriterInternal, MemoryManager.Callback { private final boolean useUTCTimeZone; private final double dictionaryKeySizeThreshold; private final boolean[] directEncodingColumns; + private final List unencryptedEncodings = + new ArrayList<>(); + + // the list of maskDescriptions, keys, and variants + private SortedSet maskDescriptions = new TreeSet<>(); + private SortedMap keys = new TreeMap<>(); + private final WriterEncryptionVariant[] encryption; + // the mapping of columns to maskDescriptions + private final MaskDescriptionImpl[] columnMaskDescriptions; + // the mapping of columns to EncryptionVariants + private final WriterEncryptionVariant[] columnEncryption; + private HadoopShims.KeyProvider keyProvider; + // do we need to include the 
current encryption keys in the next stripe + // information + private boolean needKeyFlush; public WriterImpl(FileSystem fs, Path path, OrcFile.WriterOptions opts) throws IOException { + this.schema = opts.getSchema(); + int numColumns = schema.getMaximumId() + 1; + if (!opts.isEnforceBufferSize()) { + opts.bufferSize(getEstimatedBufferSize(opts.getStripeSize(), numColumns, + opts.getBufferSize())); + } + + // Do we have column encryption? + List encryptionOptions = opts.getEncryption(); + columnEncryption = new WriterEncryptionVariant[numColumns]; + if (encryptionOptions.isEmpty()) { + columnMaskDescriptions = null; + encryption = new WriterEncryptionVariant[0]; + needKeyFlush = false; + } else { + columnMaskDescriptions = new MaskDescriptionImpl[numColumns]; + encryption = setupEncryption(opts.getKeyProvider(), encryptionOptions); + needKeyFlush = true; + } + + // Set up the physical writer + this.physicalWriter = opts.getPhysicalWriter() == null ? + new PhysicalFsWriter(fs, path, opts, encryption) : + opts.getPhysicalWriter(); + unencryptedOptions = physicalWriter.getStreamOptions(); + OutStream.assertBufferSizeValid(unencryptedOptions.getBufferSize()); + this.path = path; this.conf = opts.getConfiguration(); this.callback = opts.getCallback(); - this.schema = opts.getSchema(); this.writerVersion = opts.getWriterVersion(); bloomFilterVersion = opts.getBloomFilterVersion(); this.directEncodingColumns = OrcUtils.includeColumns( @@ -130,13 +176,7 @@ public WriterImpl(FileSystem fs, dictionaryKeySizeThreshold = OrcConf.DICTIONARY_KEY_SIZE_THRESHOLD.getDouble(conf); if (callback != null) { - callbackContext = new OrcFile.WriterContext(){ - - @Override - public Writer getWriter() { - return WriterImpl.this; - } - }; + callbackContext = () -> WriterImpl.this; } else { callbackContext = null; } @@ -149,22 +189,6 @@ public Writer getWriter() { this.rowIndexStride = opts.getRowIndexStride(); this.memoryManager = opts.getMemoryManager(); buildIndex = rowIndexStride > 0; - int numColumns = schema.getMaximumId() + 1; - if (opts.isEnforceBufferSize()) { - OutStream.assertBufferSizeValid(opts.getBufferSize()); - compress = new StreamOptions(opts.getBufferSize()); - } else { - compress = new StreamOptions(getEstimatedBufferSize(adjustedStripeSize, - numColumns, opts.getBufferSize())); - } - this.physicalWriter = opts.getPhysicalWriter() == null - ? 
new PhysicalFsWriter(fs, path, opts) - : opts.getPhysicalWriter(); - physicalWriter.writeHeader(); - CompressionCodec codec = physicalWriter.getStreamOptions().getCodec(); - if (codec != null) { - compress.withCodec(codec, codec.getDefaultOptions()); - } if (version == OrcFile.Version.FUTURE) { throw new IllegalArgumentException("Can not write in a unknown version."); } else if (version == OrcFile.Version.UNSTABLE_PRE_2_0) { @@ -180,17 +204,17 @@ public Writer getWriter() { OrcUtils.includeColumns(opts.getBloomFilterColumns(), schema); } this.bloomFilterFpp = opts.getBloomFilterFpp(); - treeWriter = TreeWriter.Factory.create(schema, new StreamFactory(), false); + physicalWriter.writeHeader(); + + treeWriter = TreeWriter.Factory.create(schema, null, new StreamFactory()); if (buildIndex && rowIndexStride < MIN_ROW_INDEX_STRIDE) { throw new IllegalArgumentException("Row stride must be at least " + MIN_ROW_INDEX_STRIDE); } - // ensure that we are able to handle callbacks before we register ourselves memoryManager.addWriter(path, opts.getStripeSize(), this); - LOG.info("ORC writer created for path: {} with stripeSize: {} blockSize: {}" + - " compression: {}", path, adjustedStripeSize, opts.getBlockSize(), - compress); + LOG.info("ORC writer created for path: {} with stripeSize: {} options: {}", + path, adjustedStripeSize, unencryptedOptions); } //@VisibleForTesting @@ -207,8 +231,8 @@ public static int getEstimatedBufferSize(long stripeSize, int numColumns, @Override public void increaseCompressionSize(int newSize) { - if (newSize > compress.getBufferSize()) { - compress.bufferSize(newSize); + if (newSize > unencryptedOptions.getBufferSize()) { + unencryptedOptions.bufferSize(newSize); } } @@ -277,19 +301,26 @@ public boolean checkMemory(double newScale) throws IOException { * that the TreeWriters have into the Writer. */ private class StreamFactory implements WriterContext { + /** * Create a stream to store part of a column. - * @param column the column id for the stream - * @param kind the kind of stream + * @param name the name for the stream * @return The output outStream that the section needs to be written to. 
*/ - public OutStream createStream(int column, - OrcProto.Stream.Kind kind - ) throws IOException { - final StreamName name = new StreamName(column, kind); - return new OutStream(physicalWriter.toString(), - SerializationUtils.getCustomizedCodec(compress, compressionStrategy, kind), - physicalWriter.createDataStream(name)); + public OutStream createStream(StreamName name) throws IOException { + StreamOptions options = SerializationUtils.getCustomizedCodec( + unencryptedOptions, compressionStrategy, name.getKind()); + WriterEncryptionVariant encryption = + (WriterEncryptionVariant) name.getEncryption(); + if (encryption != null) { + if (options == unencryptedOptions) { + options = new StreamOptions(options); + } + options.withEncryption(encryption.getKeyDescription().getAlgorithm(), + encryption.getFileFooterKey()) + .modifyIv(CryptoUtils.modifyIvForStream(name, 1)); + } + return new OutStream(name, options, physicalWriter.createDataStream(name)); } /** @@ -312,7 +343,7 @@ public boolean buildIndex() { * @return are the streams compressed */ public boolean isCompressed() { - return physicalWriter.getStreamOptions().getCodec() != null; + return unencryptedOptions.getCodec() != null; } /** @@ -379,6 +410,30 @@ public void writeBloomFilter(StreamName name, physicalWriter.writeBloomFilter(name, bloom); } + @Override + public WriterEncryptionVariant getEncryption(int columnId) { + return columnId < columnEncryption.length ? + columnEncryption[columnId] : null; + } + + @Override + public DataMask getUnencryptedMask(int columnId) { + MaskDescriptionImpl descr = columnMaskDescriptions[columnId]; + return descr == null ? null : + DataMask.Factory.build(descr, schema.findSubtype(columnId), + (type) -> columnMaskDescriptions[type.getId()]); + } + + @Override + public void setEncoding(int column, WriterEncryptionVariant encryption, + OrcProto.ColumnEncoding encoding) { + if (encryption == null) { + unencryptedEncodings.add(encoding); + } else { + encryption.addEncoding(encoding); + } + } + @Override public void writeStatistics(StreamName name, OrcProto.ColumnStatistics.Builder stats @@ -406,6 +461,19 @@ private void createRowIndexEntry() throws IOException { rowsInIndex = 0; } + /** + * Write the encrypted keys into the StripeInformation along with the + * stripe id, so that the readers can decrypt the data. 
+ * @param dirEntry the entry to modify + */ + private void addEncryptedKeys(OrcProto.StripeInformation.Builder dirEntry) { + for(WriterEncryptionVariant variant: encryption) { + dirEntry.addEncryptedLocalKeys(ByteString.copyFrom( + variant.getMaterial().getEncryptedKey())); + } + dirEntry.setEncryptStripeId(1 + stripes.size()); + } + private void flushStripe() throws IOException { if (buildIndex && rowsInIndex != 0) { createRowIndexEntry(); @@ -426,18 +494,27 @@ private void flushStripe() throws IOException { builder.setWriterTimezone(TimeZone.getDefault().getID()); } } - OrcProto.StripeStatistics.Builder stats = - OrcProto.StripeStatistics.newBuilder(); - treeWriter.flushStreams(); - treeWriter.writeStripe(builder, stats, requiredIndexEntries); - + treeWriter.writeStripe(requiredIndexEntries); + // update the encodings + builder.addAllColumns(unencryptedEncodings); + unencryptedEncodings.clear(); + for (WriterEncryptionVariant writerEncryptionVariant : encryption) { + OrcProto.StripeEncryptionVariant.Builder encrypt = + OrcProto.StripeEncryptionVariant.newBuilder(); + encrypt.addAllEncoding(writerEncryptionVariant.getEncodings()); + writerEncryptionVariant.clearEncodings(); + builder.addEncryption(encrypt); + } OrcProto.StripeInformation.Builder dirEntry = OrcProto.StripeInformation.newBuilder() .setNumberOfRows(rowsInStripe); + if (encryption.length > 0 && needKeyFlush) { + addEncryptedKeys(dirEntry); + needKeyFlush = false; + } physicalWriter.finalizeStripe(builder, dirEntry); - fileMetadata.addStripeStats(stats.build()); stripes.add(dirEntry.build()); rowCount += rowsInStripe; rowsInStripe = 0; @@ -460,31 +537,73 @@ private OrcProto.CompressionKind writeCompressionKind(CompressionKind kind) { } } - private void writeFileStatistics(TreeWriter writer) throws IOException { - writer.writeFileStatistics(); - } - private void writeMetadata() throws IOException { - physicalWriter.writeFileMetadata(fileMetadata); + // The physical writer now has the stripe statistics, so we pass a + // new builder in here. + physicalWriter.writeFileMetadata(OrcProto.Metadata.newBuilder()); } private long writePostScript() throws IOException { - CompressionCodec codec = compress.getCodec(); OrcProto.PostScript.Builder builder = OrcProto.PostScript.newBuilder() - .setCompression(writeCompressionKind(codec == null - ? 
CompressionKind.NONE - : codec.getKind())) .setMagic(OrcFile.MAGIC) .addVersion(version.getMajor()) .addVersion(version.getMinor()) .setWriterVersion(writerVersion.getId()); - if (compress.getCodec() != null) { - builder.setCompressionBlockSize(compress.getBufferSize()); + CompressionCodec codec = unencryptedOptions.getCodec(); + if (codec == null) { + builder.setCompression(OrcProto.CompressionKind.NONE); + } else { + builder.setCompression(writeCompressionKind(codec.getKind())) + .setCompressionBlockSize(unencryptedOptions.getBufferSize()); } return physicalWriter.writePostScript(builder); } + private OrcProto.EncryptionKey.Builder writeEncryptionKey(WriterEncryptionKey key) { + OrcProto.EncryptionKey.Builder result = OrcProto.EncryptionKey.newBuilder(); + HadoopShims.KeyMetadata meta = key.getMetadata(); + result.setKeyName(meta.getKeyName()); + result.setKeyVersion(meta.getVersion()); + result.setAlgorithm(OrcProto.EncryptionAlgorithm.valueOf( + meta.getAlgorithm().getSerialization())); + return result; + } + + private OrcProto.EncryptionVariant.Builder + writeEncryptionVariant(WriterEncryptionVariant variant) { + OrcProto.EncryptionVariant.Builder result = + OrcProto.EncryptionVariant.newBuilder(); + result.setRoot(variant.getRoot().getId()); + result.setKey(variant.getKeyDescription().getId()); + result.setEncryptedKey(ByteString.copyFrom(variant.getMaterial().getEncryptedKey())); + return result; + } + + private OrcProto.Encryption.Builder writeEncryptionFooter() { + OrcProto.Encryption.Builder encrypt = OrcProto.Encryption.newBuilder(); + for(MaskDescriptionImpl mask: maskDescriptions) { + OrcProto.DataMask.Builder maskBuilder = OrcProto.DataMask.newBuilder(); + maskBuilder.setName(mask.getName()); + for(String param: mask.getParameters()) { + maskBuilder.addMaskParameters(param); + } + for(TypeDescription column: mask.getColumns()) { + maskBuilder.addColumns(column.getId()); + } + encrypt.addMask(maskBuilder); + } + for(WriterEncryptionKey key: keys.values()) { + encrypt.addKey(writeEncryptionKey(key)); + } + for(WriterEncryptionVariant variant: encryption) { + encrypt.addVariants(writeEncryptionVariant(variant)); + } + encrypt.setKeyProvider(OrcProto.KeyProviderKind.valueOf( + keyProvider.getKind().getValue())); + return encrypt; + } + private long writeFooter() throws IOException { writeMetadata(); OrcProto.Footer.Builder builder = OrcProto.Footer.newBuilder(); @@ -498,12 +617,15 @@ private long writeFooter() throws IOException { builder.addStripes(stripe); } // add the column statistics - writeFileStatistics(treeWriter); + treeWriter.writeFileStatistics(); // add all of the user metadata for(Map.Entry entry: userMetadata.entrySet()) { builder.addMetadata(OrcProto.UserMetadataItem.newBuilder() .setName(entry.getKey()).setValue(entry.getValue())); } + if (encryption.length > 0) { + builder.setEncryption(writeEncryptionFooter()); + } builder.setWriter(OrcFile.WriterImplementation.ORC_JAVA.getId()); physicalWriter.writeFileFooter(builder); return writePostScript(); @@ -521,6 +643,11 @@ public void addUserMetadata(String name, ByteBuffer value) { @Override public void addRowBatch(VectorizedRowBatch batch) throws IOException { + // If this is the first set of rows in this stripe, tell the tree writers + // to prepare the stripe. + if (batch.size != 0 && rowsInStripe == 0) { + treeWriter.prepareStripe(stripes.size() + 1); + } if (buildIndex) { // Batch the writes up to the rowIndexStride so that we can get the // right size indexes. 
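
Taken together, the addRowBatch and addEncryptedKeys hunks above tie every stream's IV to a 1-based stripe id: the first rows of each stripe trigger treeWriter.prepareStripe(stripes.size() + 1), and the same numbering is recorded via setEncryptStripeId so readers can regenerate the IVs. A rough sketch of what a stripe-id IV modifier can look like (the real byte layout lives in CryptoUtils, which is outside this hunk, so the offsets below are only an assumption):

    import java.util.function.Consumer;

    public class StripeIvSketch {
      // Assumption: bytes 3..5 of a 16-byte AES IV carry the stripe id,
      // big-endian; CryptoUtils defines the actual offsets.
      static Consumer<byte[]> modifyIvForStripe(long stripeId) {
        return iv -> {
          iv[3] = (byte) (stripeId >> 16);
          iv[4] = (byte) (stripeId >> 8);
          iv[5] = (byte) stripeId;
        };
      }

      public static void main(String[] args) {
        byte[] iv = new byte[16];
        modifyIvForStripe(1).accept(iv); // stripe ids are 1-based here
        System.out.println(iv[5]);       // prints 1
      }
    }

Folding the stripe id into the IV is what prevents (key, IV) reuse when the same column key encrypts many stripes.
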
@@ -607,6 +734,10 @@ public void appendStripe(byte[] stripe, int offset, int length, checkArgument(stripeStatistics != null, "Stripe statistics must not be null"); + // If we have buffered rows, flush them + if (rowsInStripe > 0) { + flushStripe(); + } rowsInStripe = stripeInfo.getNumberOfRows(); // update stripe information OrcProto.StripeInformation.Builder dirEntry = OrcProto.StripeInformation @@ -620,13 +751,13 @@ public void appendStripe(byte[] stripe, int offset, int length, // since we have already written the stripe, just update stripe statistics treeWriter.updateFileStatistics(stripeStatistics); - fileMetadata.addStripeStats(stripeStatistics); stripes.add(dirEntry.build()); // reset it after writing the stripe rowCount += rowsInStripe; rowsInStripe = 0; + needKeyFlush = encryption.length > 0; } @Override @@ -649,7 +780,7 @@ public ColumnStatistics[] getStatistics() { } public CompressionCodec getCompressionCodec() { - return physicalWriter.getStreamOptions().getCodec(); + return unencryptedOptions.getCodec(); } private static boolean hasTimestamp(TypeDescription schema) { @@ -666,4 +797,79 @@ private static boolean hasTimestamp(TypeDescription schema) { } return false; } + + WriterEncryptionKey getKey(String keyName, + HadoopShims.KeyProvider provider) throws IOException { + WriterEncryptionKey result = keys.get(keyName); + if (result == null) { + result = new WriterEncryptionKey(provider.getCurrentKeyVersion(keyName)); + keys.put(keyName, result); + } + return result; + } + + MaskDescriptionImpl getMask(OrcFile.EncryptionOption opt) { + MaskDescriptionImpl result = new MaskDescriptionImpl(opt.getMask(), + opt.getMaskParameters()); + // if it is already there, get the earlier object + if (!maskDescriptions.add(result)) { + result = maskDescriptions.tailSet(result).first(); + } + return result; + } + + /** + * Iterate through the encryption options given by the user and set up + * our data structures. + * @param provider the KeyProvider to use to generate keys + * @param options the options from the user + */ + WriterEncryptionVariant[] setupEncryption(HadoopShims.KeyProvider provider, + List options + ) throws IOException { + keyProvider = provider != null ? 
provider : + SHIMS.getKeyProvider(conf, new SecureRandom()); + if (keyProvider == null) { + throw new IllegalArgumentException("Encryption requires a KeyProvider."); + } + // fill out the primary encryption keys + int variantCount = 0; + for(OrcFile.EncryptionOption option: options) { + MaskDescriptionImpl mask = getMask(option); + for(TypeDescription col: schema.findSubtypes(option.getColumnNames())) { + mask.addColumn(col); + } + if (option.getKeyName() != null) { + WriterEncryptionKey key = getKey(option.getKeyName(), keyProvider); + HadoopShims.KeyMetadata metadata = key.getMetadata(); + for(TypeDescription rootType: schema.findSubtypes(option.getColumnNames())) { + WriterEncryptionVariant variant = new WriterEncryptionVariant(key, + rootType, keyProvider.createLocalKey(metadata)); + key.addRoot(variant); + variantCount += 1; + } + } + } + // Now that we have de-duped the keys and maskDescriptions, make the arrays + int nextId = 0; + for (MaskDescriptionImpl mask: maskDescriptions) { + mask.setId(nextId++); + for(TypeDescription column: mask.getColumns()) { + this.columnMaskDescriptions[column.getId()] = mask; + } + } + nextId = 0; + int nextVariantId = 0; + WriterEncryptionVariant[] result = new WriterEncryptionVariant[variantCount]; + for(WriterEncryptionKey key: keys.values()) { + key.setId(nextId++); + key.sortRoots(); + for(WriterEncryptionVariant variant: key.getEncryptionRoots()) { + result[nextVariantId] = variant; + columnEncryption[variant.getRoot().getId()] = variant; + variant.setId(nextVariantId++); + } + } + return result; + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/BinaryTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/BinaryTreeWriter.java index 14669c9ebc..0567d43a5d 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/BinaryTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/BinaryTreeWriter.java @@ -23,27 +23,30 @@ import org.apache.orc.BinaryColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.PositionedOutputStream; +import org.apache.orc.impl.StreamName; import java.io.IOException; +import java.util.function.Consumer; public class BinaryTreeWriter extends TreeWriterBase { private final PositionedOutputStream stream; private final IntegerWriter length; private boolean isDirectV2 = true; - public BinaryTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - this.stream = writer.createStream(id, - OrcProto.Stream.Kind.DATA); + public BinaryTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + this.stream = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); this.isDirectV2 = isNewWriteFormat(writer); - this.length = createIntegerWriter(writer.createStream(id, - OrcProto.Stream.Kind.LENGTH), false, isDirectV2, writer); + this.length = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.LENGTH, encryption)), + false, isDirectV2, writer); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -104,10 +107,8 @@ public void writeBatch(ColumnVector vector, int offset, @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - 
OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -141,4 +142,11 @@ public void flushStreams() throws IOException { } + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + Consumer updater = CryptoUtils.modifyIvForStripe(stripeId); + stream.changeIv(updater); + length.changeIv(updater); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/BooleanTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/BooleanTreeWriter.java index 744aaefa0c..5329cf90aa 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/BooleanTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/BooleanTreeWriter.java @@ -24,21 +24,23 @@ import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; import org.apache.orc.impl.BitFieldWriter; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.PositionedOutputStream; +import org.apache.orc.impl.StreamName; import java.io.IOException; +import java.util.function.Consumer; public class BooleanTreeWriter extends TreeWriterBase { private final BitFieldWriter writer; - public BooleanTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - PositionedOutputStream out = writer.createStream(id, - OrcProto.Stream.Kind.DATA); + public BooleanTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + PositionedOutputStream out = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); this.writer = new BitFieldWriter(out, 1); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -70,10 +72,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -101,4 +101,10 @@ public void flushStreams() throws IOException { super.flushStreams(); writer.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + writer.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/ByteTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/ByteTreeWriter.java index a8dc0599bc..a3e1d456df 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/ByteTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/ByteTreeWriter.java @@ -23,21 +23,22 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.RunLengthByteWriter; +import org.apache.orc.impl.StreamName; import java.io.IOException; public class ByteTreeWriter extends 
TreeWriterBase { private final RunLengthByteWriter writer; - public ByteTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - this.writer = new RunLengthByteWriter(writer.createStream(id, - OrcProto.Stream.Kind.DATA)); + public ByteTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + this.writer = new RunLengthByteWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption))); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -80,10 +81,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -111,4 +110,10 @@ public void flushStreams() throws IOException { super.flushStreams(); writer.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + writer.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/CharTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/CharTreeWriter.java index 14e3c26f22..83a72e9298 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/CharTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/CharTreeWriter.java @@ -34,11 +34,10 @@ public class CharTreeWriter extends StringBaseTreeWriter { private final int maxLength; private final byte[] padding; - CharTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + CharTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); maxLength = schema.getMaxLength(); // utf-8 is currently 4 bytes long, but it could be upto 6 padding = new byte[6*maxLength]; diff --git a/java/core/src/java/org/apache/orc/impl/writer/DateTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/DateTreeWriter.java index 209dd0e36b..bc81d456a0 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/DateTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/DateTreeWriter.java @@ -23,9 +23,11 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.OutStream; import org.apache.orc.impl.PositionRecorder; +import org.apache.orc.impl.StreamName; import java.io.IOException; @@ -33,13 +35,12 @@ public class DateTreeWriter extends TreeWriterBase { private final IntegerWriter writer; private final boolean isDirectV2; - public DateTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - OutStream out = writer.createStream(id, - OrcProto.Stream.Kind.DATA); + public DateTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + 
WriterContext writer) throws IOException { + super(schema, encryption, writer); + OutStream out = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); this.isDirectV2 = isNewWriteFormat(writer); this.writer = createIntegerWriter(out, true, isDirectV2, writer); if (rowIndexPosition != null) { @@ -84,10 +85,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -127,4 +126,9 @@ public void flushStreams() throws IOException { writer.flush(); } + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + writer.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/Decimal64TreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/Decimal64TreeWriter.java index 020d8ff4f8..4b3cfdd63f 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/Decimal64TreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/Decimal64TreeWriter.java @@ -25,11 +25,14 @@ import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.OutStream; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.RunLengthIntegerWriterV2; +import org.apache.orc.impl.StreamName; import java.io.IOException; +import java.util.function.Consumer; /** * Writer for short decimals in ORCv2. @@ -38,12 +41,12 @@ public class Decimal64TreeWriter extends TreeWriterBase { private final RunLengthIntegerWriterV2 valueWriter; private final int scale; - public Decimal64TreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - OutStream stream = writer.createStream(id, OrcProto.Stream.Kind.DATA); + public Decimal64TreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + OutStream stream = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); // Use RLEv2 until we have the new RLEv3.
valueWriter = new RunLengthIntegerWriterV2(stream, true, true); scale = schema.getScale(); @@ -121,10 +124,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -151,4 +152,10 @@ public void flushStreams() throws IOException { super.flushStreams(); valueWriter.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + valueWriter.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/DecimalTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/DecimalTreeWriter.java index 822042eef5..be2b2bf3bc 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/DecimalTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/DecimalTreeWriter.java @@ -26,12 +26,15 @@ import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.PositionedOutputStream; import org.apache.orc.impl.SerializationUtils; +import org.apache.orc.impl.StreamName; import java.io.IOException; +import java.util.function.Consumer; public class DecimalTreeWriter extends TreeWriterBase { private final PositionedOutputStream valueStream; @@ -44,17 +47,18 @@ public class DecimalTreeWriter extends TreeWriterBase { private final IntegerWriter scaleStream; private final boolean isDirectV2; - public DecimalTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + public DecimalTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); this.isDirectV2 = isNewWriteFormat(writer); - valueStream = writer.createStream(id, OrcProto.Stream.Kind.DATA); + valueStream = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); scratchLongs = new long[HiveDecimal.SCRATCH_LONGS_LEN]; scratchBuffer = new byte[HiveDecimal.SCRATCH_BUFFER_LEN_TO_BYTES]; - this.scaleStream = createIntegerWriter(writer.createStream(id, - OrcProto.Stream.Kind.SECONDARY), true, isDirectV2, writer); + this.scaleStream = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.SECONDARY, encryption)), + true, isDirectV2, writer); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -161,10 +165,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -195,4 +197,12 @@ public void flushStreams() throws IOException { valueStream.flush(); 
scaleStream.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + Consumer updater = CryptoUtils.modifyIvForStripe(stripeId); + valueStream.changeIv(updater); + scaleStream.changeIv(updater); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/DoubleTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/DoubleTreeWriter.java index 84218ca6fb..17f0f73317 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/DoubleTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/DoubleTreeWriter.java @@ -23,9 +23,11 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.PositionedOutputStream; import org.apache.orc.impl.SerializationUtils; +import org.apache.orc.impl.StreamName; import java.io.IOException; @@ -33,13 +35,12 @@ public class DoubleTreeWriter extends TreeWriterBase { private final PositionedOutputStream stream; private final SerializationUtils utils; - public DoubleTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - this.stream = writer.createStream(id, - OrcProto.Stream.Kind.DATA); + public DoubleTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + this.stream = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); this.utils = new SerializationUtils(); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -83,10 +84,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); stream.flush(); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -115,4 +114,10 @@ public void flushStreams() throws IOException { super.flushStreams(); stream.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + stream.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/EncryptionTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/EncryptionTreeWriter.java new file mode 100644 index 0000000000..981f75a9c5 --- /dev/null +++ b/java/core/src/java/org/apache/orc/impl/writer/EncryptionTreeWriter.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.orc.impl.writer; + +import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; +import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.ColumnStatistics; +import org.apache.orc.DataMask; +import org.apache.orc.OrcProto; +import org.apache.orc.TypeDescription; + +import java.io.IOException; + +/** + * TreeWriter that handles column encryption. + * We create a TreeWriter for each of the alternatives with a WriterContext + * that creates encrypted streams. + */ +public class EncryptionTreeWriter implements TreeWriter { + // the different writers + private final TreeWriter[] childrenWriters; + private final DataMask[] masks; + // a column vector that we use to apply the masks + private final VectorizedRowBatch scratch; + + EncryptionTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext context) throws IOException { + scratch = schema.createRowBatch(); + childrenWriters = new TreeWriterBase[2]; + masks = new DataMask[childrenWriters.length]; + + // no mask, encrypted data + masks[0] = null; + childrenWriters[0] = Factory.createSubtree(schema, encryption, context); + + // masked unencrypted + masks[1] = context.getUnencryptedMask(schema.getId()); + childrenWriters[1] = Factory.createSubtree(schema, null, context); + } + + @Override + public void writeRootBatch(VectorizedRowBatch batch, int offset, + int length) throws IOException { + scratch.ensureSize(length); + for(int alt=0; alt < childrenWriters.length; ++alt) { + // if there is a mask, apply it to each column + if (masks[alt] != null) { + for(int col=0; col < scratch.cols.length; ++col) { + masks[alt].maskData(batch.cols[col], scratch.cols[col], offset, + length); + } + childrenWriters[alt].writeRootBatch(scratch, offset, length); + } else { + childrenWriters[alt].writeRootBatch(batch, offset, length); + } + } + } + + @Override + public void writeBatch(ColumnVector vector, int offset, + int length) throws IOException { + for(int alt=0; alt < childrenWriters.length; ++alt) { + // if there is a mask, apply it to each column + if (masks[alt] != null) { + masks[alt].maskData(vector, scratch.cols[0], offset, length); + childrenWriters[alt].writeBatch(scratch.cols[0], offset, length); + } else { + childrenWriters[alt].writeBatch(vector, offset, length); + } + } + } + + @Override + public void createRowIndexEntry() throws IOException { + for(TreeWriter child: childrenWriters) { + child.createRowIndexEntry(); + } + } + + @Override + public void flushStreams() throws IOException { + for(TreeWriter child: childrenWriters) { + child.flushStreams(); + } + } + + @Override + public void writeStripe(int requiredIndexEntries) throws IOException { + for(TreeWriter child: childrenWriters) { + child.writeStripe(requiredIndexEntries); + } + } + + @Override + public void updateFileStatistics(OrcProto.StripeStatistics stats) { + for(TreeWriter child: childrenWriters) { + child.updateFileStatistics(stats); + } + } + + @Override + public long estimateMemory() { + long result = 0; + for (TreeWriter writer : childrenWriters) { + result += writer.estimateMemory(); + } + return result; + } + + @Override + public long getRawDataSize() { + // return the size of the encrypted data + return childrenWriters[0].getRawDataSize(); + } + + @Override + public void prepareStripe(int stripeId) { + for (TreeWriter writer : childrenWriters) { + writer.prepareStripe(stripeId); + } + }
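
The constructor and write methods above are the heart of the masking design: variant 0 sees the true values and writes them to encrypted streams, while variant 1 masks into the scratch batch and writes the redacted copy in the clear. A toy model of that fan-out, using a plain LongUnaryOperator where the real code uses org.apache.orc.DataMask (the class and names below are illustrative only, not ORC API):

    import java.util.function.LongUnaryOperator;

    public class FanOutSketch {
      // variant 0 = encrypted, unmasked; variant 1 = unencrypted, masked
      private final LongUnaryOperator[] masks = { null, v -> 0L };

      public void write(long[] batch) {
        for (int alt = 0; alt < masks.length; ++alt) {
          long[] out = batch;
          if (masks[alt] != null) {        // mask into scratch space first
            out = new long[batch.length];
            for (int i = 0; i < batch.length; ++i) {
              out[i] = masks[alt].applyAsLong(batch[i]);
            }
          }
          // the real writer hands `out` to childrenWriters[alt] here
        }
      }
    }

Writing both variants costs extra CPU and space, but it lets a single file serve readers with and without the column keys.
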
+ + @Override + public void writeFileStatistics() throws IOException { + for (TreeWriter child : childrenWriters) { + child.writeFileStatistics(); + } + } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + childrenWriters[0].getCurrentStatistics(output); + } +} diff --git a/java/core/src/java/org/apache/orc/impl/writer/FloatTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/FloatTreeWriter.java index e4198a21ac..bc3a15b023 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/FloatTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/FloatTreeWriter.java @@ -23,9 +23,11 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.PositionedOutputStream; import org.apache.orc.impl.SerializationUtils; +import org.apache.orc.impl.StreamName; import java.io.IOException; @@ -33,13 +35,12 @@ public class FloatTreeWriter extends TreeWriterBase { private final PositionedOutputStream stream; private final SerializationUtils utils; - public FloatTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); - this.stream = writer.createStream(id, - OrcProto.Stream.Kind.DATA); + public FloatTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + this.stream = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); this.utils = new SerializationUtils(); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -84,10 +85,8 @@ public void writeBatch(ColumnVector vector, int offset, @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -115,4 +114,10 @@ public void flushStreams() throws IOException { super.flushStreams(); stream.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + stream.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/IntegerTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/IntegerTreeWriter.java index dc0eaad1b9..7f8f21a0bd 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/IntegerTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/IntegerTreeWriter.java @@ -23,9 +23,11 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.OutStream; import org.apache.orc.impl.PositionRecorder; +import org.apache.orc.impl.StreamName; import java.io.IOException; @@ -34,13 +36,12 @@ public class IntegerTreeWriter extends TreeWriterBase { private boolean isDirectV2 = true; private final boolean isLong; - public IntegerTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, 
schema, writer, nullable); - OutStream out = writer.createStream(id, - OrcProto.Stream.Kind.DATA); + public IntegerTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); + OutStream out = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); this.isDirectV2 = isNewWriteFormat(writer); this.writer = createIntegerWriter(out, true, isDirectV2, writer); if (rowIndexPosition != null) { @@ -97,10 +98,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -129,4 +128,10 @@ public void flushStreams() throws IOException { super.flushStreams(); writer.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + writer.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java index a433e83cf7..3cd3ed11b1 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java @@ -23,8 +23,10 @@ import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.PositionRecorder; +import org.apache.orc.impl.StreamName; import java.io.IOException; @@ -33,15 +35,15 @@ public class ListTreeWriter extends TreeWriterBase { private final boolean isDirectV2; private final TreeWriter childWriter; - ListTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + ListTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); this.isDirectV2 = isNewWriteFormat(writer); - childWriter = Factory.create(schema.getChildren().get(0), writer, true); - lengths = createIntegerWriter(writer.createStream(columnId, - OrcProto.Stream.Kind.LENGTH), false, isDirectV2, writer); + childWriter = Factory.create(schema.getChildren().get(0), encryption, writer); + lengths = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.LENGTH, encryption)), + false, isDirectV2, writer); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -120,11 +122,9 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); - childWriter.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); + childWriter.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -171,4 +171,11 @@ 
public void getCurrentStatistics(ColumnStatistics[] output) { super.getCurrentStatistics(output); childWriter.getCurrentStatistics(output); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + lengths.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + childWriter.prepareStripe(stripeId); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java index 5c09b365af..02191adc7b 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java @@ -22,8 +22,10 @@ import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.PositionRecorder; +import org.apache.orc.impl.StreamName; import java.io.IOException; import java.util.List; @@ -34,17 +36,17 @@ public class MapTreeWriter extends TreeWriterBase { private final TreeWriter keyWriter; private final TreeWriter valueWriter; - MapTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + MapTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); this.isDirectV2 = isNewWriteFormat(writer); List children = schema.getChildren(); - keyWriter = Factory.create(children.get(0), writer, true); - valueWriter = Factory.create(children.get(1), writer, true); - lengths = createIntegerWriter(writer.createStream(columnId, - OrcProto.Stream.Kind.LENGTH), false, isDirectV2, writer); + keyWriter = Factory.create(children.get(0), encryption, writer); + valueWriter = Factory.create(children.get(1), encryption, writer); + lengths = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.LENGTH, encryption)), + false, isDirectV2, writer); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -129,12 +131,10 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); - keyWriter.writeStripe(builder, stats, requiredIndexEntries); - valueWriter.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); + keyWriter.writeStripe(requiredIndexEntries); + valueWriter.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -185,4 +185,12 @@ public void getCurrentStatistics(ColumnStatistics[] output) { keyWriter.getCurrentStatistics(output); valueWriter.getCurrentStatistics(output); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + lengths.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + keyWriter.prepareStripe(stripeId); + valueWriter.prepareStripe(stripeId); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/StringBaseTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/StringBaseTreeWriter.java index e7d32593d3..c3f56a6d03 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/StringBaseTreeWriter.java +++ 
b/java/core/src/java/org/apache/orc/impl/writer/StringBaseTreeWriter.java @@ -25,16 +25,19 @@ import org.apache.orc.OrcProto; import org.apache.orc.StringColumnStatistics; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.DynamicIntArray; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.OutStream; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.PositionedOutputStream; +import org.apache.orc.impl.StreamName; import org.apache.orc.impl.StringRedBlackTree; import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; public abstract class StringBaseTreeWriter extends TreeWriterBase { private static final int INITIAL_DICTIONARY_SIZE = 4096; @@ -57,17 +60,18 @@ public abstract class StringBaseTreeWriter extends TreeWriterBase { private boolean doneDictionaryCheck; private final boolean strideDictionaryCheck; - StringBaseTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + StringBaseTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); this.isDirectV2 = isNewWriteFormat(writer); - directStreamOutput = writer.createStream(id, OrcProto.Stream.Kind.DATA); - stringOutput = writer.createStream(id, - OrcProto.Stream.Kind.DICTIONARY_DATA); - lengthOutput = createIntegerWriter(writer.createStream(id, - OrcProto.Stream.Kind.LENGTH), false, isDirectV2, writer); + directStreamOutput = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)); + stringOutput = writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DICTIONARY_DATA, encryption)); + lengthOutput = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.LENGTH, encryption)), + false, isDirectV2, writer); rowOutput = createIntegerWriter(directStreamOutput, false, isDirectV2, writer); if (rowIndexPosition != null) { @@ -76,7 +80,7 @@ public abstract class StringBaseTreeWriter extends TreeWriterBase { rowIndexValueCount.add(0L); buildIndex = writer.buildIndex(); Configuration conf = writer.getConfiguration(); - dictionaryKeySizeThreshold = writer.getDictionaryKeySizeThreshold(columnId); + dictionaryKeySizeThreshold = writer.getDictionaryKeySizeThreshold(id); strideDictionaryCheck = OrcConf.ROW_INDEX_STRIDE_DICTIONARY_CHECK.getBoolean(conf); if (dictionaryKeySizeThreshold <= 0.0) { @@ -99,9 +103,9 @@ private void checkDictionaryEncoding() { } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { + public void writeStripe(int requiredIndexEntries) throws IOException { + // if the stripe had fewer rows than dictionaryCheckAfterRows, the + // dictionary check has not happened yet, so do it here. checkDictionaryEncoding(); if (!useDictionaryEncoding) { @@ -110,8 +115,7 @@ public void writeStripe(OrcProto.StripeFooter.Builder builder, // we need to build the rowindex before calling super, since it // writes it out. - super.writeStripe(builder, stats, requiredIndexEntries); - + super.writeStripe(requiredIndexEntries); // reset all of the fields to be ready for the next stripe.
dictionary.clear(); savedRowIndex.clear(); @@ -297,7 +301,16 @@ public void flushStreams() throws IOException { directStreamOutput.flush(); lengthOutput.flush(); } + } + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + Consumer updater = CryptoUtils.modifyIvForStripe(stripeId); + stringOutput.changeIv(updater); + lengthOutput.changeIv(updater); + rowOutput.changeIv(updater); + directStreamOutput.changeIv(updater); } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/StringTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/StringTreeWriter.java index ab6f38f9c6..ed1de950fd 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/StringTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/StringTreeWriter.java @@ -26,11 +26,10 @@ import java.nio.charset.StandardCharsets; public class StringTreeWriter extends StringBaseTreeWriter { - StringTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + StringTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); } @Override diff --git a/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java index d66f8468ec..a78b387533 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java @@ -31,15 +31,14 @@ public class StructTreeWriter extends TreeWriterBase { final TreeWriter[] childrenWriters; - public StructTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + public StructTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); List children = schema.getChildren(); - childrenWriters = new TreeWriterBase[children.size()]; + childrenWriters = new TreeWriter[children.size()]; for (int i = 0; i < childrenWriters.length; ++i) { - childrenWriters[i] = Factory.create(children.get(i), writer, true); + childrenWriters[i] = Factory.create(children.get(i), encryption, writer); } if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -109,12 +108,10 @@ public void createRowIndexEntry() throws IOException { } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); for (TreeWriter child : childrenWriters) { - child.writeStripe(builder, stats, requiredIndexEntries); + child.writeStripe(requiredIndexEntries); } if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -170,4 +167,12 @@ public void getCurrentStatistics(ColumnStatistics[] output) { child.getCurrentStatistics(output); } } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + for (TreeWriter child: childrenWriters) { + child.prepareStripe(stripeId); + } + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/TimestampTreeWriter.java 
b/java/core/src/java/org/apache/orc/impl/writer/TimestampTreeWriter.java index 0f30d07757..3ba2dbeb9c 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/TimestampTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/TimestampTreeWriter.java @@ -23,15 +23,18 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.SerializationUtils; +import org.apache.orc.impl.StreamName; import java.io.IOException; import java.text.DateFormat; import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.TimeZone; +import java.util.function.Consumer; public class TimestampTreeWriter extends TreeWriterBase { public static final int MILLIS_PER_SECOND = 1000; @@ -45,16 +48,17 @@ public class TimestampTreeWriter extends TreeWriterBase { private final long baseEpochSecsLocalTz; private final long baseEpochSecsUTC; - public TimestampTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + public TimestampTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); this.isDirectV2 = isNewWriteFormat(writer); - this.seconds = createIntegerWriter(writer.createStream(id, - OrcProto.Stream.Kind.DATA), true, isDirectV2, writer); - this.nanos = createIntegerWriter(writer.createStream(id, - OrcProto.Stream.Kind.SECONDARY), false, isDirectV2, writer); + this.seconds = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption)), + true, isDirectV2, writer); + this.nanos = createIntegerWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.SECONDARY, encryption)), + false, isDirectV2, writer); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -152,10 +156,8 @@ public void writeBatch(ColumnVector vector, int offset, } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -202,4 +204,12 @@ public void flushStreams() throws IOException { seconds.flush(); nanos.flush(); } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + Consumer updater = CryptoUtils.modifyIvForStripe(stripeId); + seconds.changeIv(updater); + nanos.changeIv(updater); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java index 662c57db48..680cf8cebf 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java @@ -46,6 +46,12 @@ public interface TreeWriter { */ long getRawDataSize(); + /** + * Set up for the next stripe. + * @param stripeId the next stripe id + */ + void prepareStripe(int stripeId); + /** * Write a VectorizedRowBatch to the file. This is called by the WriterImplV2 * at the top level.
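
With prepareStripe added, the interface implies a fixed per-stripe call order. A minimal harness showing that order (the harness itself is assumed for illustration and is not ORC API; WriterImpl drives the same sequence from addRowBatch and flushStripe):

    import java.io.IOException;
    import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
    import org.apache.orc.impl.writer.TreeWriter;

    public class StripeLifecycleSketch {
      static void writeOneStripe(TreeWriter root, VectorizedRowBatch batch,
                                 int stripeId, int indexEntries) throws IOException {
        root.prepareStripe(stripeId);              // roll the stream IVs first
        root.writeRootBatch(batch, 0, batch.size); // buffer the rows
        root.flushStreams();                       // flush the data streams
        root.writeStripe(indexEntries);            // then indexes and encodings
      }
    }
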
@@ -79,17 +85,11 @@ void writeBatch(ColumnVector vector, int offset, /** * Write the stripe out to the file. - * @param stripeFooter the stripe footer that contains the information about the - * layout of the stripe. The TreeWriterBase is required to update - * the footer with its information. - * @param stats the stripe statistics information * @param requiredIndexEntries the number of index entries that are * required. this is to check to make sure the * row index is well formed. */ - void writeStripe(OrcProto.StripeFooter.Builder stripeFooter, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException; + void writeStripe(int requiredIndexEntries) throws IOException; /** * During a stripe append, we need to update the file statistics. @@ -98,7 +98,7 @@ void writeStripe(OrcProto.StripeFooter.Builder stripeFooter, void updateFileStatistics(OrcProto.StripeStatistics stripeStatistics); /** - * Add the file statistics to the file footer. + * Write the FileStatistics for each column in each encryption variant. */ void writeFileStatistics() throws IOException; @@ -110,71 +110,81 @@ void writeStripe(OrcProto.StripeFooter.Builder stripeFooter, void getCurrentStatistics(ColumnStatistics[] output); class Factory { + /** + * Create a new tree writer for the given types and insert encryption if + * required. + * @param schema the type to build a writer for + * @param encryption the encryption status + * @param streamFactory the writer context + * @return a new tree writer + */ public static TreeWriter create(TypeDescription schema, - WriterContext streamFactory, - boolean nullable) throws IOException { + WriterEncryptionVariant encryption, + WriterContext streamFactory) throws IOException { + if (encryption == null) { + // If we are the root of an encryption variant, create a special writer. 
+ encryption = streamFactory.getEncryption(schema.getId()); + if (encryption != null) { + return new EncryptionTreeWriter(schema, encryption, streamFactory); + } + } + return createSubtree(schema, encryption, streamFactory); + } + + /** + * Create a subtree without inserting encryption nodes + * @param schema the schema to create + * @param encryption the encryption variant + * @param streamFactory the writer context + * @return a new tree writer + */ + static TreeWriter createSubtree(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext streamFactory) throws IOException { OrcFile.Version version = streamFactory.getVersion(); switch (schema.getCategory()) { - case BOOLEAN: - return new BooleanTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case BYTE: - return new ByteTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case SHORT: - case INT: - case LONG: - return new IntegerTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case FLOAT: - return new FloatTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case DOUBLE: - return new DoubleTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case STRING: - return new StringTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case CHAR: - return new CharTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case VARCHAR: - return new VarcharTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case BINARY: - return new BinaryTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case TIMESTAMP: - return new TimestampTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case DATE: - return new DateTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case DECIMAL: - if (version == OrcFile.Version.UNSTABLE_PRE_2_0 && - schema.getPrecision() <= TypeDescription.MAX_DECIMAL64_PRECISION) { - return new Decimal64TreeWriter(schema.getId(), - schema, streamFactory, nullable); - } - return new DecimalTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case STRUCT: - return new StructTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case MAP: - return new MapTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case LIST: - return new ListTreeWriter(schema.getId(), - schema, streamFactory, nullable); - case UNION: - return new UnionTreeWriter(schema.getId(), - schema, streamFactory, nullable); - default: - throw new IllegalArgumentException("Bad category: " + - schema.getCategory()); + case BOOLEAN: + return new BooleanTreeWriter(schema, encryption, streamFactory); + case BYTE: + return new ByteTreeWriter(schema, encryption, streamFactory); + case SHORT: + case INT: + case LONG: + return new IntegerTreeWriter(schema, encryption, streamFactory); + case FLOAT: + return new FloatTreeWriter(schema, encryption, streamFactory); + case DOUBLE: + return new DoubleTreeWriter(schema, encryption, streamFactory); + case STRING: + return new StringTreeWriter(schema, encryption, streamFactory); + case CHAR: + return new CharTreeWriter(schema, encryption, streamFactory); + case VARCHAR: + return new VarcharTreeWriter(schema, encryption, streamFactory); + case BINARY: + return new BinaryTreeWriter(schema, encryption, streamFactory); + case TIMESTAMP: + return new TimestampTreeWriter(schema, encryption, streamFactory); + case DATE: + return new DateTreeWriter(schema, encryption, streamFactory); + case DECIMAL: + if (version == OrcFile.Version.UNSTABLE_PRE_2_0 && + 
schema.getPrecision() <= TypeDescription.MAX_DECIMAL64_PRECISION) { + return new Decimal64TreeWriter(schema, encryption, streamFactory); + } + return new DecimalTreeWriter(schema, encryption, streamFactory); + case STRUCT: + return new StructTreeWriter(schema, encryption, streamFactory); + case MAP: + return new MapTreeWriter(schema, encryption, streamFactory); + case LIST: + return new ListTreeWriter(schema, encryption, streamFactory); + case UNION: + return new UnionTreeWriter(schema, encryption, streamFactory); + default: + throw new IllegalArgumentException("Bad category: " + + schema.getCategory()); } } - } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java b/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java index 4c36d09956..17a2a5fbb4 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java +++ b/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java @@ -29,6 +29,7 @@ import org.apache.orc.TypeDescription; import org.apache.orc.impl.BitFieldWriter; import org.apache.orc.impl.ColumnStatisticsImpl; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.IntegerWriter; import org.apache.orc.impl.OutStream; import org.apache.orc.impl.PositionRecorder; @@ -49,6 +50,8 @@ public abstract class TreeWriterBase implements TreeWriter { protected final int id; protected final BitFieldWriter isPresent; + protected final TypeDescription schema; + protected final WriterEncryptionVariant encryption; private final boolean isCompressed; protected final ColumnStatisticsImpl indexStatistics; protected final ColumnStatisticsImpl stripeColStatistics; @@ -64,37 +67,31 @@ public abstract class TreeWriterBase implements TreeWriter { protected final OrcProto.BloomFilter.Builder bloomFilterEntry; private boolean foundNulls; private OutStream isPresentOutStream; - private final WriterContext streamFactory; - private final TypeDescription schema; + protected final WriterContext context; /** * Create a tree writer. - * @param columnId the column id of the column to write * @param schema the row schema - * @param streamFactory limited access to the Writer's data. - * @param nullable can the value be null? + * @param encryption the encryption variant or null if it is unencrypted + * @param context limited access to the Writer's data. 
*/ - TreeWriterBase(int columnId, - TypeDescription schema, - WriterContext streamFactory, - boolean nullable) throws IOException { + TreeWriterBase(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext context) throws IOException { this.schema = schema; - this.streamFactory = streamFactory; - this.isCompressed = streamFactory.isCompressed(); - this.id = columnId; - if (nullable) { - isPresentOutStream = streamFactory.createStream(id, - OrcProto.Stream.Kind.PRESENT); - isPresent = new BitFieldWriter(isPresentOutStream, 1); - } else { - isPresent = null; - } + this.encryption = encryption; + this.context = context; + this.isCompressed = context.isCompressed(); + this.id = schema.getId(); + isPresentOutStream = context.createStream(new StreamName(id, + OrcProto.Stream.Kind.PRESENT, encryption)); + isPresent = new BitFieldWriter(isPresentOutStream, 1); this.foundNulls = false; - createBloomFilter = streamFactory.getBloomFilterColumns()[columnId]; + createBloomFilter = context.getBloomFilterColumns()[id]; indexStatistics = ColumnStatisticsImpl.create(schema); stripeColStatistics = ColumnStatisticsImpl.create(schema); fileStatistics = ColumnStatisticsImpl.create(schema); - if (streamFactory.buildIndex()) { + if (context.buildIndex()) { rowIndex = OrcProto.RowIndex.newBuilder(); rowIndexEntry = OrcProto.RowIndexEntry.newBuilder(); rowIndexPosition = new RowIndexPositionRecorder(rowIndexEntry); @@ -105,16 +102,16 @@ public abstract class TreeWriterBase implements TreeWriter { } if (createBloomFilter) { bloomFilterEntry = OrcProto.BloomFilter.newBuilder(); - if (streamFactory.getBloomFilterVersion() == OrcFile.BloomFilterVersion.ORIGINAL) { - bloomFilter = new BloomFilter(streamFactory.getRowIndexStride(), - streamFactory.getBloomFilterFPP()); + if (context.getBloomFilterVersion() == OrcFile.BloomFilterVersion.ORIGINAL) { + bloomFilter = new BloomFilter(context.getRowIndexStride(), + context.getBloomFilterFPP()); bloomFilterIndex = OrcProto.BloomFilterIndex.newBuilder(); } else { bloomFilter = null; bloomFilterIndex = null; } - bloomFilterUtf8 = new BloomFilterUtf8(streamFactory.getRowIndexStride(), - streamFactory.getBloomFilterFPP()); + bloomFilterUtf8 = new BloomFilterUtf8(context.getRowIndexStride(), + context.getBloomFilterFPP()); bloomFilterIndexUtf8 = OrcProto.BloomFilterIndex.newBuilder(); } else { bloomFilterEntry = null; @@ -233,17 +230,21 @@ private void removeIsPresentPositions() { } @Override - public void flushStreams() throws IOException { + public void prepareStripe(int stripeId) { + if (isPresent != null) { + isPresent.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + } + } + @Override + public void flushStreams() throws IOException { if (isPresent != null) { isPresent.flush(); } - } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, int requiredIndexEntries) throws IOException { + public void writeStripe(int requiredIndexEntries) throws IOException { // if no nulls are found in a stream, then suppress the stream if (isPresent != null && !foundNulls) { @@ -253,50 +254,47 @@ public void writeStripe(OrcProto.StripeFooter.Builder builder, if (rowIndex != null) { removeIsPresentPositions(); } - } /* Update byte count */ - final long byteCount = streamFactory.getPhysicalWriter().getFileBytes(id, null); + final long byteCount = context.getPhysicalWriter().getFileBytes(id, encryption); stripeColStatistics.updateByteCount(byteCount); // merge stripe-level column statistics to file statistics and 
write it to // stripe statistics fileStatistics.merge(stripeColStatistics); - streamFactory.writeStatistics( - new StreamName(id, OrcProto.Stream.Kind.STRIPE_STATISTICS, null), + context.writeStatistics( + new StreamName(id, OrcProto.Stream.Kind.STRIPE_STATISTICS, encryption), stripeColStatistics.serialize()); - stats.addColStats(stripeColStatistics.serialize()); stripeColStatistics.reset(); // reset the flag for next stripe foundNulls = false; - builder.addColumns(getEncoding()); + context.setEncoding(id, encryption, getEncoding().build()); if (rowIndex != null) { if (rowIndex.getEntryCount() != requiredIndexEntries) { throw new IllegalArgumentException("Column has wrong number of " + "index entries found: " + rowIndex.getEntryCount() + " expected: " + requiredIndexEntries); } - streamFactory.writeIndex(new StreamName(id, OrcProto.Stream.Kind.ROW_INDEX), rowIndex); + context.writeIndex(new StreamName(id, OrcProto.Stream.Kind.ROW_INDEX), rowIndex); rowIndex.clear(); rowIndexEntry.clear(); } // write the bloom filter to out stream if (bloomFilterIndex != null) { - streamFactory.writeBloomFilter(new StreamName(id, + context.writeBloomFilter(new StreamName(id, OrcProto.Stream.Kind.BLOOM_FILTER), bloomFilterIndex); bloomFilterIndex.clear(); } // write the bloom filter to out stream if (bloomFilterIndexUtf8 != null) { - streamFactory.writeBloomFilter(new StreamName(id, + context.writeBloomFilter(new StreamName(id, OrcProto.Stream.Kind.BLOOM_FILTER_UTF8), bloomFilterIndexUtf8); bloomFilterIndexUtf8.clear(); } - } /** @@ -374,9 +372,9 @@ public long estimateMemory() { @Override public void writeFileStatistics() throws IOException { - streamFactory.writeStatistics(new StreamName(id, - OrcProto.Stream.Kind.FILE_STATISTICS, null), - fileStatistics.serialize()); + context.writeStatistics(new StreamName(id, + OrcProto.Stream.Kind.FILE_STATISTICS, encryption), + fileStatistics.serialize()); } static class RowIndexPositionRecorder implements PositionRecorder { diff --git a/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java index 4456b62f55..df4dfef123 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java @@ -23,8 +23,10 @@ import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; +import org.apache.orc.impl.CryptoUtils; import org.apache.orc.impl.PositionRecorder; import org.apache.orc.impl.RunLengthByteWriter; +import org.apache.orc.impl.StreamName; import java.io.IOException; import java.util.List; @@ -33,19 +35,18 @@ public class UnionTreeWriter extends TreeWriterBase { private final RunLengthByteWriter tags; private final TreeWriter[] childrenWriters; - UnionTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + UnionTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); List<TypeDescription> children = schema.getChildren(); childrenWriters = new TreeWriterBase[children.size()]; for (int i = 0; i < childrenWriters.length; ++i) { - childrenWriters[i] = Factory.create(children.get(i), writer, true); + childrenWriters[i] = Factory.create(children.get(i), encryption, writer); } tags = - new RunLengthByteWriter(writer.createStream(columnId, - OrcProto.Stream.Kind.DATA)); + new
RunLengthByteWriter(writer.createStream( + new StreamName(id, OrcProto.Stream.Kind.DATA, encryption))); if (rowIndexPosition != null) { recordPosition(rowIndexPosition); } @@ -121,12 +122,10 @@ public void createRowIndexEntry() throws IOException { } @Override - public void writeStripe(OrcProto.StripeFooter.Builder builder, - OrcProto.StripeStatistics.Builder stats, - int requiredIndexEntries) throws IOException { - super.writeStripe(builder, stats, requiredIndexEntries); + public void writeStripe(int requiredIndexEntries) throws IOException { + super.writeStripe(requiredIndexEntries); for (TreeWriter child : childrenWriters) { - child.writeStripe(builder, stats, requiredIndexEntries); + child.writeStripe(requiredIndexEntries); } if (rowIndexPosition != null) { recordPosition(rowIndexPosition); @@ -185,8 +184,17 @@ public void flushStreams() throws IOException { @Override public void getCurrentStatistics(ColumnStatistics[] output) { super.getCurrentStatistics(output); - for (TreeWriter child: childrenWriters) { + for(TreeWriter child: childrenWriters) { child.getCurrentStatistics(output); } } + + @Override + public void prepareStripe(int stripeId) { + super.prepareStripe(stripeId); + tags.changeIv(CryptoUtils.modifyIvForStripe(stripeId)); + for (TreeWriter child: childrenWriters) { + child.prepareStripe(stripeId); + } + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/VarcharTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/VarcharTreeWriter.java index b08ef437cf..29a6ab75d4 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/VarcharTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/VarcharTreeWriter.java @@ -32,11 +32,10 @@ public class VarcharTreeWriter extends StringBaseTreeWriter { private final int maxLength; - VarcharTreeWriter(int columnId, - TypeDescription schema, - WriterContext writer, - boolean nullable) throws IOException { - super(columnId, schema, writer, nullable); + VarcharTreeWriter(TypeDescription schema, + WriterEncryptionVariant encryption, + WriterContext writer) throws IOException { + super(schema, encryption, writer); maxLength = schema.getMaxLength(); } diff --git a/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java b/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java index cfd2ec7ab0..73542ad976 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java +++ b/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java @@ -19,9 +19,12 @@ package org.apache.orc.impl.writer; import org.apache.hadoop.conf.Configuration; +import org.apache.orc.CompressionCodec; +import org.apache.orc.DataMask; import org.apache.orc.OrcFile; import org.apache.orc.OrcProto; import org.apache.orc.PhysicalWriter; +import org.apache.orc.TypeDescription; import org.apache.orc.impl.OutStream; import org.apache.orc.impl.StreamName; @@ -30,88 +33,115 @@ public interface WriterContext { /** - * Create a stream to store part of a column. - * @param column the column id for the stream - * @param kind the kind of stream - * @return The output outStream that the section needs to be written to. - */ - OutStream createStream(int column, - OrcProto.Stream.Kind kind - ) throws IOException; - - /** - * Get the stride rate of the row index. - */ - int getRowIndexStride(); - - /** - * Should be building the row index. - * @return true if we are building the index - */ - boolean buildIndex(); - - /** - * Is the ORC file compressed? 
- * @return are the streams compressed - */ - boolean isCompressed(); - - /** - * Get the encoding strategy to use. - * @return encoding strategy - */ - OrcFile.EncodingStrategy getEncodingStrategy(); - - /** - * Get the bloom filter columns - * @return bloom filter columns - */ - boolean[] getBloomFilterColumns(); - - /** - * Get bloom filter false positive percentage. - * @return fpp - */ - double getBloomFilterFPP(); - - /** - * Get the writer's configuration. - * @return configuration - */ - Configuration getConfiguration(); - - /** - * Get the version of the file to write. - */ - OrcFile.Version getVersion(); - - /** - * Get the PhysicalWriter. - * - * @return the file's physical writer. - */ - PhysicalWriter getPhysicalWriter(); - - - OrcFile.BloomFilterVersion getBloomFilterVersion(); - - void writeIndex(StreamName name, - OrcProto.RowIndex.Builder index) throws IOException; - - void writeBloomFilter(StreamName name, - OrcProto.BloomFilterIndex.Builder bloom - ) throws IOException; - - /** - * Set the column statistics for the stripe or file. - * @param name the name of the statistics stream - * @param stats the statistics for this column in this stripe - */ - void writeStatistics(StreamName name, - OrcProto.ColumnStatistics.Builder stats - ) throws IOException; - - boolean getUseUTCTimestamp(); - - double getDictionaryKeySizeThreshold(int column); + * Create a stream to store part of a column. + * @param name the name of the stream + * @return The output outStream that the section needs to be written to. + */ + OutStream createStream(StreamName name) throws IOException; + + /** + * Get the stride rate of the row index. + */ + int getRowIndexStride(); + + /** + * Should be building the row index. + * @return true if we are building the index + */ + boolean buildIndex(); + + /** + * Is the ORC file compressed? + * @return are the streams compressed + */ + boolean isCompressed(); + + /** + * Get the encoding strategy to use. + * @return encoding strategy + */ + OrcFile.EncodingStrategy getEncodingStrategy(); + + /** + * Get the bloom filter columns + * @return bloom filter columns + */ + boolean[] getBloomFilterColumns(); + + /** + * Get bloom filter false positive percentage. + * @return fpp + */ + double getBloomFilterFPP(); + + /** + * Get the writer's configuration. + * @return configuration + */ + Configuration getConfiguration(); + + /** + * Get the version of the file to write. + */ + OrcFile.Version getVersion(); + + OrcFile.BloomFilterVersion getBloomFilterVersion(); + + void writeIndex(StreamName name, + OrcProto.RowIndex.Builder index) throws IOException; + + void writeBloomFilter(StreamName name, + OrcProto.BloomFilterIndex.Builder bloom + ) throws IOException; + + /** + * Get the mask for the unencrypted variant. + * @param columnId the column id + * @return the mask to apply to the unencrypted data or null if there is none + */ + DataMask getUnencryptedMask(int columnId); + + /** + * Get the encryption for the given column. + * @param columnId the root column id + * @return the column encryption or null if it isn't encrypted + */ + WriterEncryptionVariant getEncryption(int columnId); + + /** + * Get the PhysicalWriter. + * @return the file's physical writer. + */ + PhysicalWriter getPhysicalWriter(); + + /** + * Set the encoding for the current stripe. 
+ * @param column the column identifier + * @param variant the encryption variant + * @param encoding the encoding for this stripe + */ + void setEncoding(int column, WriterEncryptionVariant variant, + OrcProto.ColumnEncoding encoding); + + /** + * Set the column statistics for the stripe or file. + * @param name the name of the statistics stream + * @param stats the statistics for this column in this stripe + */ + void writeStatistics(StreamName name, + OrcProto.ColumnStatistics.Builder stats + ) throws IOException; + + /** + * Should the writer use UTC as the timezone? + */ + boolean getUseUTCTimestamp(); + + /** + * Get the dictionary key size threshold. + * @param columnId the column id + * @return the minimum ratio for using a dictionary + */ + double getDictionaryKeySizeThreshold(int columnId); + } diff --git a/java/core/src/test/org/apache/orc/TestStringDictionary.java b/java/core/src/test/org/apache/orc/TestStringDictionary.java index 2ab15064c0..b0d39a0931 100644 --- a/java/core/src/test/org/apache/orc/TestStringDictionary.java +++ b/java/core/src/test/org/apache/orc/TestStringDictionary.java @@ -40,6 +40,7 @@ import org.apache.orc.impl.writer.StringTreeWriter; import org.apache.orc.impl.writer.TreeWriter; import org.apache.orc.impl.writer.WriterContext; +import org.apache.orc.impl.writer.WriterEncryptionVariant; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -172,9 +173,9 @@ static class WriterContextImpl implements WriterContext { } @Override - public OutStream createStream(int column, OrcProto.Stream.Kind kind) throws IOException { + public OutStream createStream(StreamName name) throws IOException { TestInStream.OutputCollector collect = new TestInStream.OutputCollector(); - streams.put(new StreamName(column, kind), collect); + streams.put(name, collect); return new OutStream("test", new StreamOptions(1000), collect); } @@ -223,6 +224,16 @@ public PhysicalWriter getPhysicalWriter() { return null; } + @Override + public void setEncoding(int column, WriterEncryptionVariant variant, OrcProto.ColumnEncoding encoding) { + + } + + @Override + public void writeStatistics(StreamName name, OrcProto.ColumnStatistics.Builder stats) throws IOException { + + } + @Override public OrcFile.BloomFilterVersion getBloomFilterVersion() { return OrcFile.BloomFilterVersion.UTF8; @@ -240,8 +251,13 @@ public void writeBloomFilter(StreamName name, } @Override - public void writeStatistics(StreamName name, OrcProto.ColumnStatistics.Builder stats) throws IOException { + public DataMask getUnencryptedMask(int columnId) { + return null; + } + @Override + public WriterEncryptionVariant getEncryption(int columnId) { + return null; } @Override @@ -262,7 +278,7 @@ public void testNonDistinctDisabled() throws Exception { conf.set(OrcConf.DICTIONARY_KEY_SIZE_THRESHOLD.getAttribute(), "0.0"); WriterContextImpl writerContext = new WriterContextImpl(schema, conf); StringTreeWriter writer = (StringTreeWriter) - TreeWriter.Factory.create(schema, writerContext, true); + TreeWriter.Factory.create(schema, null, writerContext); VectorizedRowBatch batch = schema.createRowBatch(); BytesColumnVector col = (BytesColumnVector) batch.cols[0]; From a68ce245dbccbd246ce7070af7d9a36ff4277f1b Mon Sep 17 00:00:00 2001 From: Owen O'Malley Date: Mon, 24 Jun 2019 13:05:55 -0700 Subject: [PATCH 3/4] ORC-516: Update InStream for column compression. 
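This patch threads a stream-level offset through InStream: positions recorded in indexes stay relative to the stream (0..length), while each stream remembers the file offset where its bytes begin. A minimal usage sketch against the new factory signature; the stream name and values here are illustrative only.

    import java.nio.ByteBuffer;
    import org.apache.orc.impl.BufferChunk;
    import org.apache.orc.impl.InStream;

    public class InStreamOffsetSketch {
      public static void main(String[] args) throws Exception {
        byte[] bytes = {10, 20, 30, 40};
        // The disk range knows the file offset where these bytes live...
        BufferChunk range = new BufferChunk(ByteBuffer.wrap(bytes), 100);
        // ...and the stream is created with the same offset, so stream
        // position 0 maps to file offset 100.
        InStream in = InStream.create("sketch", range, 100, bytes.length);
        System.out.println(in.read()); // prints 10, the byte at position 0
      }
    }

The existing call sites below all pass an offset of 0, so their behavior is unchanged; the non-zero case is what the column encryption work needs, since an encrypted stream's IV counter is keyed to stream-relative positions rather than file offsets.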
--- .../java/org/apache/orc/impl/BufferChunk.java | 29 +- .../org/apache/orc/impl/BufferChunkList.java | 14 + .../java/org/apache/orc/impl/InStream.java | 297 +++++++++++------- .../java/org/apache/orc/impl/ReaderImpl.java | 4 +- .../org/apache/orc/impl/RecordReaderImpl.java | 4 +- .../apache/orc/impl/RecordReaderUtils.java | 6 +- .../apache/orc/impl/TestBitFieldReader.java | 6 +- .../test/org/apache/orc/impl/TestBitPack.java | 2 +- .../org/apache/orc/impl/TestInStream.java | 26 +- .../impl/TestIntegerCompressionReader.java | 4 +- .../org/apache/orc/impl/TestOutStream.java | 15 +- .../apache/orc/impl/TestPhysicalFsWriter.java | 19 +- .../orc/impl/TestRunLengthByteReader.java | 6 +- .../orc/impl/TestRunLengthIntegerReader.java | 8 +- .../apache/orc/impl/TestSchemaEvolution.java | 15 +- 15 files changed, 282 insertions(+), 173 deletions(-) diff --git a/java/core/src/java/org/apache/orc/impl/BufferChunk.java b/java/core/src/java/org/apache/orc/impl/BufferChunk.java index afde82f50c..951a6d8bce 100644 --- a/java/core/src/java/org/apache/orc/impl/BufferChunk.java +++ b/java/core/src/java/org/apache/orc/impl/BufferChunk.java @@ -1,6 +1,4 @@ -package org.apache.orc.impl; - -/** +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -18,6 +16,8 @@ * limitations under the License. */ +package org.apache.orc.impl; + import org.apache.hadoop.hive.common.io.DiskRange; import org.apache.hadoop.hive.common.io.DiskRangeList; import org.slf4j.Logger; @@ -34,15 +34,20 @@ public class BufferChunk extends DiskRangeList { private static final Logger LOG = LoggerFactory.getLogger(BufferChunk.class); - final ByteBuffer chunk; + private ByteBuffer chunk; + + public BufferChunk(long offset, int length) { + super(offset, offset + length); + chunk = null; + } public BufferChunk(ByteBuffer chunk, long offset) { super(offset, offset + chunk.remaining()); this.chunk = chunk; } - public ByteBuffer getChunk() { - return chunk; + public void setChunk(ByteBuffer chunk) { + this.chunk = chunk; } @Override @@ -52,10 +57,14 @@ public boolean hasData() { @Override public final String toString() { - boolean makesSense = chunk.remaining() == (end - offset); - return "data range [" + offset + ", " + end + "), size: " + chunk.remaining() - + (makesSense ? "" : "(!)") + " type: " + - (chunk.isDirect() ? "direct" : "array-backed"); + if (chunk == null) { + return "data range[" + offset + ", " + end +")"; + } else { + boolean makesSense = chunk.remaining() == (end - offset); + return "data range [" + offset + ", " + end + "), size: " + chunk.remaining() + + (makesSense ? "" : "(!)") + " type: " + + (chunk.isDirect() ? "direct" : "array-backed"); + } } @Override diff --git a/java/core/src/java/org/apache/orc/impl/BufferChunkList.java b/java/core/src/java/org/apache/orc/impl/BufferChunkList.java index d8a89dbce6..1a0aea9661 100644 --- a/java/core/src/java/org/apache/orc/impl/BufferChunkList.java +++ b/java/core/src/java/org/apache/orc/impl/BufferChunkList.java @@ -32,6 +32,7 @@ public void add(BufferChunk value) { } else { tail.next = value; value.prev = tail; + value.next = null; tail = value; } } @@ -40,6 +41,19 @@ public BufferChunk get() { return head; } + /** + * Get the nth element of the list + * @param chunk the element number to get from 0 + * @return the given element number + */ + public BufferChunk get(int chunk) { + BufferChunk ptr = head; + for(int i=0; i < chunk; ++i) { + ptr = ptr == null ? 
null : (BufferChunk) ptr.next; + } + return ptr; + } + public void clear() { head = null; tail = null; diff --git a/java/core/src/java/org/apache/orc/impl/InStream.java b/java/core/src/java/org/apache/orc/impl/InStream.java index 06f439effe..dd4429e3e0 100644 --- a/java/core/src/java/org/apache/orc/impl/InStream.java +++ b/java/core/src/java/org/apache/orc/impl/InStream.java @@ -23,6 +23,7 @@ import java.security.InvalidAlgorithmParameterException; import java.security.InvalidKeyException; import java.security.Key; +import java.util.function.Consumer; import org.apache.hadoop.hive.common.io.DiskRangeList; import org.apache.orc.CompressionCodec; @@ -41,25 +42,25 @@ public abstract class InStream extends InputStream { private static final Logger LOG = LoggerFactory.getLogger(InStream.class); public static final int PROTOBUF_MESSAGE_MAX_LIMIT = 1024 << 20; // 1GB - protected final String name; - protected long length; + protected final Object name; + protected final long offset; + protected final long length; - public InStream(String name, long length) { + public InStream(Object name, long offset, long length) { this.name = name; + this.offset = offset; this.length = length; } - public String getStreamName() { - return name; - } - - public long getStreamLength() { - return length; + public String toString() { + return name.toString(); } @Override public abstract void close(); + public abstract void changeIv(Consumer<byte[]> modifier); + static int getRangeNumber(DiskRangeList list, DiskRangeList current) { int result = 0; DiskRangeList range = list; @@ -75,8 +76,8 @@ static int getRangeNumber(DiskRangeList list, DiskRangeList current) { */ public static class UncompressedStream extends InStream { private DiskRangeList bytes; - private long length; - protected long currentOffset; + // position in the stream (0..length) + protected long position; protected ByteBuffer decrypted; protected DiskRangeList currentRange; @@ -86,33 +87,33 @@ public static class UncompressedStream extends InStream { * @param name name of the stream * @param length the number of bytes for the stream */ - public UncompressedStream(String name, long length) { - super(name, length); + public UncompressedStream(Object name, long offset, long length) { + super(name, offset, length); } - public UncompressedStream(String name, + public UncompressedStream(Object name, DiskRangeList input, + long offset, long length) { - super(name, length); - reset(input, length); + super(name, offset, length); + reset(input); } - protected void reset(DiskRangeList input, long length) { + protected void reset(DiskRangeList input) { this.bytes = input; - this.length = length; - currentOffset = input == null ? 0 : input.getOffset(); + position = input == null ? 0 : input.getOffset() - offset; setCurrent(input, true); } @Override public int read() { if (decrypted == null || decrypted.remaining() == 0) { - if (currentOffset == length) { + if (position == length) { return -1; } setCurrent(currentRange.next, false); } - currentOffset += 1; + position += 1; return 0xff & decrypted.get(); } @@ -122,21 +123,21 @@ protected void setCurrent(DiskRangeList newRange, boolean isJump) { decrypted = newRange.getData().slice(); // Move the position in the ByteBuffer to match the currentOffset, // which is relative to the stream.
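The position arithmetic in the hunk below is the crux of the change: position is stream-relative, the stream's offset locates its first byte in the file, and each range's getOffset() is also a file offset, so the index into a range's buffer is position + offset - range.getOffset(). A tiny worked example with assumed values:

    public class PositionMathSketch {
      public static void main(String[] args) {
        long offset = 1000;      // file offset where the stream begins
        long position = 250;     // current stream-relative position
        long rangeOffset = 1200; // file offset where the current range begins
        // index into the range's ByteBuffer: 250 + 1000 - 1200 = 50
        System.out.println((int) (position + offset - rangeOffset));
      }
    }

The old code computed currentOffset - newRange.getOffset(), which was equivalent only because currentOffset was tracked as an absolute file offset.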
- decrypted.position((int) (currentOffset - newRange.getOffset())); + decrypted.position((int) (position + offset - newRange.getOffset())); } } @Override public int read(byte[] data, int offset, int length) { if (decrypted == null || decrypted.remaining() == 0) { - if (currentOffset == this.length) { + if (position == this.length) { return -1; } setCurrent(currentRange.next, false); } int actualLength = Math.min(length, decrypted.remaining()); decrypted.get(data, offset, actualLength); - currentOffset += actualLength; + position += actualLength; return actualLength; } @@ -145,18 +146,23 @@ public int available() { if (decrypted != null && decrypted.remaining() > 0) { return decrypted.remaining(); } - return (int) (length - currentOffset); + return (int) (length - position); } @Override public void close() { currentRange = null; - currentOffset = length; + position = length; // explicit de-ref of bytes[] decrypted = null; bytes = null; } + @Override + public void changeIv(Consumer<byte[]> modifier) { + // nothing to do + } + @Override public void seek(PositionProvider index) throws IOException { seek(index.getNext()); } @@ -166,18 +172,20 @@ public void seek(long desired) throws IOException { if (desired == 0 && bytes == null) { return; } + // compute the position of the desired point in file + long posn = desired + offset; // If we are seeking inside of the current range, just reposition. - if (currentRange != null && desired >= currentRange.getOffset() && - desired < currentRange.getEnd()) { - decrypted.position((int) (desired - currentRange.getOffset())); - currentOffset = desired; + if (currentRange != null && posn >= currentRange.getOffset() && + posn < currentRange.getEnd()) { + decrypted.position((int) (posn - currentRange.getOffset())); + position = desired; } else { for (DiskRangeList curRange = bytes; curRange != null; curRange = curRange.next) { - if (curRange.getOffset() <= desired && - (curRange.next == null ? desired <= curRange.getEnd() : - desired < curRange.getEnd())) { - currentOffset = desired; + if (curRange.getOffset() <= posn && + (curRange.next == null ? posn <= curRange.getEnd() : + posn < curRange.getEnd())) { + position = desired; setCurrent(curRange, true); return; } @@ -189,7 +197,7 @@ public void seek(long desired) throws IOException { @Override public String toString() { - return "uncompressed stream " + name + " position: " + currentOffset + + return "uncompressed stream " + name + " position: " + position + " length: " + length + " range: " + getRangeNumber(bytes, currentRange) + " offset: " + (decrypted == null ? 0 : decrypted.position()) + " limit: " + (decrypted == null ? 0 : decrypted.limit()); @@ -209,21 +217,39 @@ private static ByteBuffer allocateBuffer(int size, boolean isDirect) { * Manage the state of the decryption, including the ability to seek.
*/ static class EncryptionState { - private final String name; + private final Object name; private final EncryptionAlgorithm algorithm; private final Key key; private final byte[] iv; private final Cipher cipher; + private final long offset; private ByteBuffer decrypted; - EncryptionState(String name, StreamOptions options) { + EncryptionState(Object name, long offset, StreamOptions options) { this.name = name; - algorithm = options.algorithm; - key = options.key; - iv = options.iv; + this.offset = offset; + algorithm = options.getAlgorithm(); + key = options.getKey(); + iv = options.getIv(); cipher = algorithm.createCipher(); } + void changeIv(Consumer<byte[]> modifier) { + modifier.accept(iv); + updateIv(); + OutStream.logKeyAndIv(name, key, iv); + } + + private void updateIv() { + try { + cipher.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv)); + } catch (InvalidKeyException e) { + throw new IllegalArgumentException("Invalid key on " + name, e); + } catch (InvalidAlgorithmParameterException e) { + throw new IllegalArgumentException("Invalid iv on " + name, e); + } + } + /** * We are seeking to a new range, so update the cipher to change the IV * to match. This code assumes that we only support encryption in CTR mode. @@ -233,35 +259,27 @@ void changeIv(long offset) { int blockSize = cipher.getBlockSize(); long encryptionBlocks = offset / blockSize; long extra = offset % blockSize; - byte[] advancedIv; - if (encryptionBlocks == 0) { - advancedIv = iv; - } else { + CryptoUtils.clearCounter(iv); + if (encryptionBlocks != 0) { // Add the encryption blocks into the initial iv, to compensate for // skipping over decrypting those bytes. - advancedIv = new byte[iv.length]; - System.arraycopy(iv, 0, advancedIv, 0, iv.length); int posn = iv.length - 1; while (encryptionBlocks > 0) { - long sum = (advancedIv[posn] & 0xff) + encryptionBlocks; - advancedIv[posn--] = (byte) sum; + long sum = (iv[posn] & 0xff) + encryptionBlocks; + iv[posn--] = (byte) sum; encryptionBlocks = sum / 0x100; } } - try { - cipher.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(advancedIv)); - // If the range starts at an offset that doesn't match the encryption - // block, we need to advance some bytes within an encryption block. - if (extra > 0) { + updateIv(); + // If the range starts at an offset that doesn't match the encryption + // block, we need to advance some bytes within an encryption block.
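This is the standard random-access trick for CTR mode: to jump N bytes into the keystream, add N / blockSize to the IV's big-endian counter, then decrypt and throw away N % blockSize bytes so the partial block lines up. A self-contained sketch of the same logic against the JDK cipher API, assuming AES with the counter in the low-order bytes of the IV:

    import javax.crypto.Cipher;
    import javax.crypto.spec.IvParameterSpec;
    import javax.crypto.spec.SecretKeySpec;

    public class CtrSeekSketch {
      // Return a cipher positioned at byte `offset` of the decrypted stream.
      static Cipher cipherAt(byte[] key, byte[] iv, long offset) throws Exception {
        byte[] adjusted = iv.clone();
        int blockSize = 16;                     // AES block size
        long blocks = offset / blockSize;       // whole blocks to skip
        int extra = (int) (offset % blockSize); // bytes into the next block
        // add the block count into the big-endian counter, carrying as needed
        for (int posn = adjusted.length - 1; blocks > 0; --posn) {
          long sum = (adjusted[posn] & 0xff) + blocks;
          adjusted[posn] = (byte) sum;
          blocks = sum >> 8;
        }
        Cipher cipher = Cipher.getInstance("AES/CTR/NoPadding");
        cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key, "AES"),
            new IvParameterSpec(adjusted));
        if (extra > 0) {
          cipher.update(new byte[extra]); // burn the partial block
        }
        return cipher;
      }
    }

The patch's version differs in one respect: it mutates the stream's own iv array in place (after CryptoUtils.clearCounter zeroes the counter bytes) rather than cloning, since EncryptionState owns that IV for the life of the stream.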
+ if (extra > 0) { + try { byte[] wasted = new byte[(int) extra]; cipher.update(wasted, 0, wasted.length, wasted, 0); + } catch (ShortBufferException e) { + throw new IllegalArgumentException("Short buffer in " + name, e); } - } catch (InvalidKeyException e) { - throw new IllegalArgumentException("Invalid key on " + name, e); - } catch (InvalidAlgorithmParameterException e) { - throw new IllegalArgumentException("Invalid iv on " + name, e); - } catch (ShortBufferException e) { - throw new IllegalArgumentException("Short buffer in " + name, e); } } @@ -305,11 +323,11 @@ void close() { public static class EncryptedStream extends UncompressedStream { private final EncryptionState encrypt; - public EncryptedStream(String name, DiskRangeList input, long length, + public EncryptedStream(Object name, DiskRangeList input, long offset, long length, StreamOptions options) { - super(name, length); - encrypt = new EncryptionState(name, options); - reset(input, length); + super(name, offset, length); + encrypt = new EncryptionState(name, offset, options); + reset(input); } @Override protected void setCurrent(DiskRangeList newRange, boolean isJump) { currentRange = newRange; if (newRange != null) { if (isJump) { - encrypt.changeIv(newRange.getOffset()); + encrypt.changeIv(newRange.getOffset() - offset); } decrypted = encrypt.decrypt(newRange); - decrypted.position((int) (currentOffset - newRange.getOffset())); + decrypted.position((int) (position + offset - newRange.getOffset())); } } @@ -330,6 +348,11 @@ public void close() { encrypt.close(); } + @Override + public void changeIv(Consumer<byte[]> modifier) { + encrypt.changeIv(modifier); + } + @Override public String toString() { return "encrypted " + super.toString(); @@ -342,7 +365,7 @@ private static class CompressedStream extends InStream { private ByteBuffer uncompressed; private final CompressionCodec codec; protected ByteBuffer compressed; - protected long currentOffset; + protected long position; protected DiskRangeList currentRange; private boolean isUncompressedOriginal; @@ -354,10 +377,11 @@ private static class CompressedStream extends InStream { * @param length the total number of bytes in the stream * @param options the options used to read the stream */ - public CompressedStream(String name, + public CompressedStream(Object name, + long offset, long length, StreamOptions options) { - super(name, length); + super(name, offset, length); this.codec = options.codec; this.bufferSize = options.bufferSize; } @@ -369,25 +393,24 @@ public CompressedStream(String name, * @param length the total length of the stream * @param options the options to read the data with */ - public CompressedStream(String name, + public CompressedStream(Object name, DiskRangeList input, + long offset, long length, StreamOptions options) { - super(name, length); + super(name, offset, length); this.codec = options.codec; this.bufferSize = options.bufferSize; - reset(input, length); + reset(input); } /** * Reset the input to a new set of data. * @param input the input data - * @param length the number of bytes in the stream */ - void reset(DiskRangeList input, long length) { + void reset(DiskRangeList input) { bytes = input; - this.length = length; - currentOffset = input == null ? 0 : input.getOffset(); + position = input == null ?
0 : input.getOffset() - offset; setCurrent(input, true); } @@ -400,7 +423,7 @@ protected void setCurrent(DiskRangeList newRange, currentRange = newRange; if (newRange != null) { compressed = newRange.getData().slice(); - compressed.position((int) (currentOffset - newRange.getOffset())); + compressed.position((int) (position + offset - newRange.getOffset())); } } @@ -417,11 +440,11 @@ private void readHeader() throws IOException { if (chunkLength > bufferSize) { throw new IllegalArgumentException("Buffer size too small. size = " + - bufferSize + " needed = " + chunkLength); + bufferSize + " needed = " + chunkLength + " in " + name); } // read 3 bytes, which should be equal to OutStream.HEADER_SIZE always assert OutStream.HEADER_SIZE == 3 : "The Orc HEADER_SIZE must be the same in OutStream and InStream"; - currentOffset += OutStream.HEADER_SIZE; + position += OutStream.HEADER_SIZE; ByteBuffer slice = this.slice(chunkLength); @@ -464,7 +487,7 @@ public int read(byte[] data, int offset, int length) throws IOException { private boolean ensureUncompressed() throws IOException { while (uncompressed == null || uncompressed.remaining() == 0) { - if (currentOffset == this.length) { + if (position == this.length) { return false; } readHeader(); @@ -485,10 +508,15 @@ public void close() { uncompressed = null; compressed = null; currentRange = null; - currentOffset = length; + position = length; bytes = null; } + @Override + public void changeIv(Consumer<byte[]> modifier) { + // nothing to do + } + @Override public void seek(PositionProvider index) throws IOException { seek(index.getNext()); } @@ -507,13 +535,13 @@ public void seek(PositionProvider index) throws IOException { private ByteBuffer slice(int chunkLength) throws IOException { int len = chunkLength; final DiskRangeList oldRange = currentRange; - final long oldOffset = currentOffset; + final long oldPosition = position; ByteBuffer slice; if (compressed.remaining() >= len) { slice = compressed.slice(); // simple case slice.limit(len); - currentOffset += len; + position += len; compressed.position(compressed.position() + len); return slice; } else if (currentRange.next == null) { @@ -531,7 +559,7 @@ private ByteBuffer slice(int chunkLength) throws IOException { // we need to consolidate 2 or more buffers into 1 // first copy out compressed buffers ByteBuffer copy = allocateBuffer(chunkLength, compressed.isDirect()); - currentOffset += compressed.remaining(); + position += compressed.remaining(); len -= compressed.remaining(); copy.put(compressed); @@ -544,18 +572,18 @@ private ByteBuffer slice(int chunkLength) throws IOException { slice = compressed.slice(); slice.limit(len); copy.put(slice); - currentOffset += len; + position += len; compressed.position(compressed.position() + len); copy.flip(); return copy; } - currentOffset += compressed.remaining(); + position += compressed.remaining(); len -= compressed.remaining(); copy.put(compressed); } // restore offsets for exception clarity - currentOffset = oldOffset; + position = oldPosition; setCurrent(oldRange, true); throw new IOException("EOF in " + this + " while trying to read " + chunkLength + " bytes"); @@ -565,11 +593,12 @@ void seek(long desired) throws IOException { if (desired == 0 && bytes == null) { return; } + long posn = desired + offset; for (DiskRangeList range = bytes; range != null; range = range.next) { - if (range.getOffset() <= desired && - (range.next == null ?
desired <= range.getEnd() : - desired < range.getEnd())) { - currentOffset = desired; + if (range.getOffset() <= posn && + (range.next == null ? posn <= range.getEnd() : + posn < range.getEnd())) { + position = desired; setCurrent(range, true); return; } @@ -597,7 +626,7 @@ private String rangeString() { @Override public String toString() { - return "compressed stream " + name + " position: " + currentOffset + + return "compressed stream " + name + " position: " + position + " length: " + length + " range: " + getRangeNumber(bytes, currentRange) + " offset: " + (compressed == null ? 0 : compressed.position()) + " limit: " + (compressed == null ? 0 : compressed.limit()) + @@ -611,13 +640,14 @@ public String toString() { private static class EncryptedCompressedStream extends CompressedStream { private final EncryptionState encrypt; - public EncryptedCompressedStream(String name, + public EncryptedCompressedStream(Object name, DiskRangeList input, + long offset, long length, StreamOptions options) { - super(name, length, options); - encrypt = new EncryptionState(name, options); - reset(input, length); + super(name, offset, length, options); + encrypt = new EncryptionState(name, offset, options); + reset(input); } @Override protected void setCurrent(DiskRangeList newRange, boolean isJump) { currentRange = newRange; if (newRange != null) { if (isJump) { - encrypt.changeIv(newRange.getOffset()); + encrypt.changeIv(newRange.getOffset() - offset); } compressed = encrypt.decrypt(newRange); - compressed.position((int) (currentOffset - newRange.getOffset())); + compressed.position((int) (position + offset - newRange.getOffset())); } } @@ -638,6 +668,11 @@ public void close() { encrypt.close(); } + @Override + public void changeIv(Consumer<byte[]> modifier) { + encrypt.changeIv(modifier); + } + @Override public String toString() { return "encrypted " + super.toString(); @@ -653,6 +688,17 @@ public static class StreamOptions implements Cloneable { private Key key; private byte[] iv; + public StreamOptions(StreamOptions other) { + codec = other.codec; + bufferSize = other.bufferSize; + algorithm = other.algorithm; + key = other.key; + iv = other.iv == null ? null : other.iv.clone(); + } + + public StreamOptions() { + } + public StreamOptions withCodec(CompressionCodec value) { this.codec = value; return this; @@ -672,10 +718,30 @@ public StreamOptions withEncryption(EncryptionAlgorithm algorithm, return this; } + public boolean isCompressed() { + return codec != null; + } + public CompressionCodec getCodec() { return codec; } + public int getBufferSize() { + return bufferSize; + } + + public EncryptionAlgorithm getAlgorithm() { + return algorithm; + } + + public Key getKey() { + return key; + } + + public byte[] getIv() { + return iv; + } + @Override public StreamOptions clone() { try { @@ -689,6 +755,20 @@ public StreamOptions clone() { throw new UnsupportedOperationException("uncloneable", e); } } + + @Override + public String toString() { + StringBuilder buffer = new StringBuilder(); + buffer.append("compress: "); + buffer.append(codec == null ? "none" : codec.getKind()); + buffer.append(", buffer size: "); + buffer.append(bufferSize); + if (key != null) { + buffer.append(", encryption: "); + buffer.append(algorithm); + } + return buffer.toString(); + } } public static StreamOptions options() { @@ -699,24 +779,30 @@ public static StreamOptions options() { * Create an input stream from a list of disk ranges with data.
* @param name the name of the stream * @param input the list of ranges of bytes for the stream; from disk or cache + * @param offset the first byte offset of the stream * @param length the length in bytes of the stream * @param options the options to read with * @return an input stream */ - public static InStream create(String name, + public static InStream create(Object name, DiskRangeList input, + long offset, long length, StreamOptions options) { + LOG.debug("Reading {} with {} from {} for {}", name, options, offset, + length); if (options == null || options.codec == null) { if (options == null || options.key == null) { - return new UncompressedStream(name, input, length); + return new UncompressedStream(name, input, offset, length); } else { - return new EncryptedStream(name, input, length, options); + OutStream.logKeyAndIv(name, options.getKey(), options.getIv()); + return new EncryptedStream(name, input, offset, length, options); } } else if (options.key == null) { - return new CompressedStream(name, input, length, options); + return new CompressedStream(name, input, offset, length, options); } else { - return new EncryptedCompressedStream(name, input, length, options); + OutStream.logKeyAndIv(name, options.getKey(), options.getIv()); + return new EncryptedCompressedStream(name, input, offset, length, options); } } @@ -727,10 +813,11 @@ public static InStream create(String name, * @param length the length in bytes of the stream * @return an input stream */ - public static InStream create(String name, + public static InStream create(Object name, DiskRangeList input, - long length) throws IOException { - return create(name, input, length, null); + long offset, + long length) { + return create(name, input, offset, length, null); } /** diff --git a/java/core/src/java/org/apache/orc/impl/ReaderImpl.java b/java/core/src/java/org/apache/orc/impl/ReaderImpl.java index 8edfaee6e7..ddd53b7312 100644 --- a/java/core/src/java/org/apache/orc/impl/ReaderImpl.java +++ b/java/core/src/java/org/apache/orc/impl/ReaderImpl.java @@ -412,7 +412,7 @@ private static OrcProto.Footer extractFooter(ByteBuffer bb, int footerAbsPos, bb.position(footerAbsPos); bb.limit(footerAbsPos + footerSize); return OrcProto.Footer.parseFrom(InStream.createCodedInputStream( - InStream.create("footer", new BufferChunk(bb, 0), footerSize, options))); + InStream.create("footer", new BufferChunk(bb, 0), 0, footerSize, options))); } public static OrcProto.Metadata extractMetadata(ByteBuffer bb, int metadataAbsPos, @@ -420,7 +420,7 @@ public static OrcProto.Metadata extractMetadata(ByteBuffer bb, int metadataAbsPo bb.position(metadataAbsPos); bb.limit(metadataAbsPos + metadataSize); return OrcProto.Metadata.parseFrom(InStream.createCodedInputStream( - InStream.create("metadata", new BufferChunk(bb, 0), metadataSize, options))); + InStream.create("metadata", new BufferChunk(bb, 0), 0, metadataSize, options))); } private static OrcProto.PostScript extractPostScript(ByteBuffer bb, Path path, diff --git a/java/core/src/java/org/apache/orc/impl/RecordReaderImpl.java b/java/core/src/java/org/apache/orc/impl/RecordReaderImpl.java index 786f9631f5..28445529d0 100644 --- a/java/core/src/java/org/apache/orc/impl/RecordReaderImpl.java +++ b/java/core/src/java/org/apache/orc/impl/RecordReaderImpl.java @@ -1061,7 +1061,7 @@ private void clearStreams() { if (!(range instanceof BufferChunk)) { continue; } - dataReader.releaseBuffer(((BufferChunk) range).getChunk()); + dataReader.releaseBuffer(range.getData()); } } } @@ -1213,7 +1213,7 @@ void 
createStreams(List streamDescriptions, ranges, streamOffset, streamDesc.getLength()); StreamName name = new StreamName(column, streamDesc.getKind()); streams.put(name, InStream.create(name.toString(), buffers, - streamDesc.getLength(), options)); + 0, streamDesc.getLength(), options)); streamOffset += streamDesc.getLength(); } } diff --git a/java/core/src/java/org/apache/orc/impl/RecordReaderUtils.java b/java/core/src/java/org/apache/orc/impl/RecordReaderUtils.java index d4c57e57d7..aeb33d4948 100644 --- a/java/core/src/java/org/apache/orc/impl/RecordReaderUtils.java +++ b/java/core/src/java/org/apache/orc/impl/RecordReaderUtils.java @@ -234,7 +234,7 @@ public OrcIndex readRowIndex(StripeInformation stripe, indexes[column] = OrcProto.RowIndex.parseFrom( InStream.createCodedInputStream(InStream.create("index", new BufferChunk(bb, 0), - stream.getLength(), options))); + 0, stream.getLength(), options))); } break; case BLOOM_FILTER: @@ -246,7 +246,7 @@ public OrcIndex readRowIndex(StripeInformation stripe, bloomFilterIndices[column] = OrcProto.BloomFilterIndex.parseFrom (InStream.createCodedInputStream(InStream.create( "bloom_filter", new BufferChunk(bb, 0), - stream.getLength(), options))); + 0, stream.getLength(), options))); } break; default: @@ -271,7 +271,7 @@ public OrcProto.StripeFooter readStripeFooter(StripeInformation stripe) throws I file.readFully(offset, tailBuf.array(), tailBuf.arrayOffset(), tailLength); return OrcProto.StripeFooter.parseFrom( InStream.createCodedInputStream(InStream.create("footer", - new BufferChunk(tailBuf, 0), tailLength, options))); + new BufferChunk(tailBuf, 0), 0, tailLength, options))); } @Override diff --git a/java/core/src/test/org/apache/orc/impl/TestBitFieldReader.java b/java/core/src/test/org/apache/orc/impl/TestBitFieldReader.java index 4590f9b518..6f54c77f8c 100644 --- a/java/core/src/test/org/apache/orc/impl/TestBitFieldReader.java +++ b/java/core/src/test/org/apache/orc/impl/TestBitFieldReader.java @@ -53,7 +53,7 @@ public void runSeekTest(CompressionCodec codec) throws Exception { collect.buffer.setByteBuffer(inBuf, 0, collect.buffer.size()); inBuf.flip(); BitFieldReader in = new BitFieldReader(InStream.create("test", - new BufferChunk(inBuf, 0), inBuf.remaining(), + new BufferChunk(inBuf, 0), 0, inBuf.remaining(), InStream.options().withCodec(codec).withBufferSize(500))); for(int i=0; i < COUNT; ++i) { int x = in.next(); @@ -102,7 +102,7 @@ public void testSkips() throws Exception { collect.buffer.setByteBuffer(inBuf, 0, collect.buffer.size()); inBuf.flip(); BitFieldReader in = new BitFieldReader(InStream.create("test", - new BufferChunk(inBuf, 0), inBuf.remaining())); + new BufferChunk(inBuf, 0), 0, inBuf.remaining())); for(int i=0; i < COUNT; i += 5) { int x = in.next(); if (i < COUNT/2) { @@ -139,7 +139,7 @@ public void testSeekSkip() throws Exception { collect.buffer.setByteBuffer(inBuf, 0, collect.buffer.size()); inBuf.flip(); BitFieldReader in = new BitFieldReader(InStream.create("test", - new BufferChunk(inBuf, 0), inBuf.remaining())); + new BufferChunk(inBuf, 0), 0, inBuf.remaining())); in.seek(posn); in.skip(10); for(int r = 210; r < COUNT; ++r) { diff --git a/java/core/src/test/org/apache/orc/impl/TestBitPack.java b/java/core/src/test/org/apache/orc/impl/TestBitPack.java index d298ecc98c..a26b1305dc 100644 --- a/java/core/src/test/org/apache/orc/impl/TestBitPack.java +++ b/java/core/src/test/org/apache/orc/impl/TestBitPack.java @@ -110,7 +110,7 @@ private void runTest(int numBits) throws IOException { inBuf.flip(); long[] buff = new 
long[SIZE]; utils.readInts(buff, 0, SIZE, fixedWidth, - InStream.create("test", new BufferChunk(inBuf,0), + InStream.create("test", new BufferChunk(inBuf,0), 0, inBuf.remaining())); for (int i = 0; i < SIZE; i++) { buff[i] = utils.zigzagDecode(buff[i]); diff --git a/java/core/src/test/org/apache/orc/impl/TestInStream.java b/java/core/src/test/org/apache/orc/impl/TestInStream.java index 1881be358f..3b6e28c793 100644 --- a/java/core/src/test/org/apache/orc/impl/TestInStream.java +++ b/java/core/src/test/org/apache/orc/impl/TestInStream.java @@ -108,7 +108,7 @@ public void testUncompressed() throws Exception { collect.buffer.setByteBuffer(inBuf, 0, collect.buffer.size()); inBuf.flip(); InStream in = InStream.create("test", new BufferChunk(inBuf, 0), - inBuf.remaining()); + 0, inBuf.remaining()); assertEquals("uncompressed stream test position: 0 length: 1024" + " range: 0 offset: 0 limit: 1024", in.toString()); @@ -161,7 +161,7 @@ public void testEncrypted() throws Exception { offset += size; } - InStream in = InStream.create("test", list.get(), collect.buffer.size(), + InStream in = InStream.create("test", list.get(), 0, collect.buffer.size(), InStream.options().withEncryption(algorithm, decryptKey, writerOptions.getIv())); assertEquals("encrypted uncompressed stream test position: 0 length: 8192" + @@ -219,7 +219,7 @@ public void testCompressedEncrypted() throws Exception { offset += size; } - InStream in = InStream.create("test", list.get(), collect.buffer.size(), + InStream in = InStream.create("test", list.get(), 0, collect.buffer.size(), InStream.options() .withCodec(new ZlibCodec()).withBufferSize(500) .withEncryption(algorithm, decryptKey, writerOptions.getIv())); @@ -258,7 +258,7 @@ public void testCompressed() throws Exception { ByteBuffer inBuf = ByteBuffer.allocate(collect.buffer.size()); collect.buffer.setByteBuffer(inBuf, 0, collect.buffer.size()); inBuf.flip(); - InStream in = InStream.create("test", new BufferChunk(inBuf, 0), + InStream in = InStream.create("test", new BufferChunk(inBuf, 0), 0, inBuf.remaining(), InStream.options().withCodec(codec).withBufferSize(300)); assertEquals("compressed stream test position: 0 length: 961 range: 0" + @@ -294,7 +294,7 @@ public void testCorruptStream() throws Exception { ByteBuffer inBuf = ByteBuffer.allocate(collect.buffer.size()); collect.buffer.setByteBuffer(inBuf, 0, collect.buffer.size()); inBuf.flip(); - InStream in = InStream.create("test", new BufferChunk(inBuf, 0), + InStream in = InStream.create("test", new BufferChunk(inBuf, 0), 0, inBuf.remaining(), InStream.options().withCodec(codec).withBufferSize(100)); byte[] contents = new byte[1024]; @@ -310,7 +310,7 @@ public void testCorruptStream() throws Exception { inBuf.put((byte) 32); inBuf.put((byte) 0); inBuf.flip(); - in = InStream.create("test2", new BufferChunk(inBuf, 0), + in = InStream.create("test2", new BufferChunk(inBuf, 0), 0, inBuf.remaining(), InStream.options().withCodec(codec).withBufferSize(300)); try { @@ -354,7 +354,7 @@ public void testDisjointBuffers() throws Exception { } InStream.StreamOptions inOptions = InStream.options() .withCodec(codec).withBufferSize(400); - InStream in = InStream.create("test", buffers.get(), 1674, inOptions); + InStream in = InStream.create("test", buffers.get(), 0, 1674, inOptions); assertEquals("compressed stream test position: 0 length: 1674 range: 0" + " offset: 0 limit: 483 range 0 = 0 to 483;" + " range 1 = 483 to 1625; range 2 = 1625 to 1674", @@ -373,7 +373,7 @@ public void testDisjointBuffers() throws Exception { 
buffers.clear(); buffers.add(new BufferChunk(inBuf[1], 483)); buffers.add(new BufferChunk(inBuf[2], 1625)); - in = InStream.create("test", buffers.get(), 1674, inOptions); + in = InStream.create("test", buffers.get(), 0, 1674, inOptions); inStream = new DataInputStream(in); positions[303].reset(); in.seek(positions[303]); @@ -384,7 +384,7 @@ public void testDisjointBuffers() throws Exception { buffers.clear(); buffers.add(new BufferChunk(inBuf[0], 0)); buffers.add(new BufferChunk(inBuf[2], 1625)); - in = InStream.create("test", buffers.get(), 1674, inOptions); + in = InStream.create("test", buffers.get(), 0, 1674, inOptions); inStream = new DataInputStream(in); positions[1001].reset(); for(int i=0; i < 300; ++i) { @@ -424,7 +424,7 @@ public void testUncompressedDisjointBuffers() throws Exception { buffers.add(new BufferChunk(inBuf[0], 0)); buffers.add(new BufferChunk(inBuf[1], 1024)); buffers.add(new BufferChunk(inBuf[2], 3072)); - InStream in = InStream.create("test", buffers.get(), 4096); + InStream in = InStream.create("test", buffers.get(), 0, 4096); assertEquals("uncompressed stream test position: 0 length: 4096" + " range: 0 offset: 0 limit: 1024", in.toString()); @@ -442,7 +442,7 @@ public void testUncompressedDisjointBuffers() throws Exception { buffers.clear(); buffers.add(new BufferChunk(inBuf[1], 1024)); buffers.add(new BufferChunk(inBuf[2], 3072)); - in = InStream.create("test", buffers.get(), 4096); + in = InStream.create("test", buffers.get(), 0, 4096); inStream = new DataInputStream(in); positions[256].reset(); in.seek(positions[256]); @@ -453,7 +453,7 @@ public void testUncompressedDisjointBuffers() throws Exception { buffers.clear(); buffers.add(new BufferChunk(inBuf[0], 0)); buffers.add(new BufferChunk(inBuf[2], 3072)); - in = InStream.create("test", buffers.get(), 4096); + in = InStream.create("test", buffers.get(), 0, 4096); inStream = new DataInputStream(in); positions[768].reset(); for(int i=0; i < 256; ++i) { @@ -468,7 +468,7 @@ public void testUncompressedDisjointBuffers() throws Exception { @Test public void testEmptyDiskRange() throws IOException { DiskRangeList range = new BufferChunk(ByteBuffer.allocate(0), 0); - InStream stream = new InStream.UncompressedStream("test", range, 0); + InStream stream = new InStream.UncompressedStream("test", range, 0, 0); assertEquals(0, stream.available()); stream.seek(new PositionProvider() { @Override diff --git a/java/core/src/test/org/apache/orc/impl/TestIntegerCompressionReader.java b/java/core/src/test/org/apache/orc/impl/TestIntegerCompressionReader.java index 178e768629..3f66ace745 100644 --- a/java/core/src/test/org/apache/orc/impl/TestIntegerCompressionReader.java +++ b/java/core/src/test/org/apache/orc/impl/TestIntegerCompressionReader.java @@ -61,7 +61,7 @@ public void runSeekTest(CompressionCodec codec) throws Exception { inBuf.flip(); RunLengthIntegerReaderV2 in = new RunLengthIntegerReaderV2(InStream.create("test", - new BufferChunk(inBuf, 0), inBuf.remaining(), + new BufferChunk(inBuf, 0), 0, inBuf.remaining(), InStream.options().withCodec(codec).withBufferSize(1000)), true, false); for(int i=0; i < 2048; ++i) { int x = (int) in.next(); @@ -114,7 +114,7 @@ public void testSkips() throws Exception { inBuf.flip(); RunLengthIntegerReaderV2 in = new RunLengthIntegerReaderV2(InStream.create("test", - new BufferChunk(inBuf, 0), + new BufferChunk(inBuf, 0), 0, inBuf.remaining()), true, false); for(int i=0; i < 2048; i += 10) { int x = (int) in.next(); diff --git a/java/core/src/test/org/apache/orc/impl/TestOutStream.java 
b/java/core/src/test/org/apache/orc/impl/TestOutStream.java index b905c7ec34..95dedf6451 100644 --- a/java/core/src/test/org/apache/orc/impl/TestOutStream.java +++ b/java/core/src/test/org/apache/orc/impl/TestOutStream.java @@ -20,6 +20,7 @@ import org.apache.orc.CompressionCodec; import org.apache.orc.EncryptionAlgorithm; +import org.apache.orc.InMemoryKeystore; import org.apache.orc.OrcProto; import org.apache.orc.PhysicalWriter; import org.apache.orc.impl.writer.StreamOptions; @@ -35,22 +36,12 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.Key; -import java.security.NoSuchAlgorithmException; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; public class TestOutStream { - public static final boolean TEST_AES_256; - static { - try { - TEST_AES_256 = Cipher.getMaxAllowedKeyLength("AES") != 128; - } catch (NoSuchAlgorithmException e) { - throw new IllegalArgumentException("Unknown algorithm", e); - } - } - @Test public void testFlush() throws Exception { PhysicalWriter.OutputReceiver receiver = @@ -180,7 +171,7 @@ public void testEncryption() throws Exception { @Test public void testCompression256Encryption() throws Exception { // disable test if AES_256 is not available - Assume.assumeTrue(TEST_AES_256); + Assume.assumeTrue(InMemoryKeystore.SUPPORTS_AES_256); TestInStream.OutputCollector receiver = new TestInStream.OutputCollector(); EncryptionAlgorithm aes256 = EncryptionAlgorithm.AES_CTR_256; byte[] keyBytes = new byte[aes256.keyLength()]; @@ -211,7 +202,7 @@ public void testCompression256Encryption() throws Exception { // use InStream to decompress it BufferChunkList ranges = new BufferChunkList(); ranges.add(new BufferChunk(ByteBuffer.wrap(compressed), 0)); - InStream decompressedStream = InStream.create("test", ranges.get(), + InStream decompressedStream = InStream.create("test", ranges.get(), 0, compressed.length, InStream.options().withCodec(new ZlibCodec()).withBufferSize(1024)); diff --git a/java/core/src/test/org/apache/orc/impl/TestPhysicalFsWriter.java b/java/core/src/test/org/apache/orc/impl/TestPhysicalFsWriter.java index 6d2d298801..333bc980fb 100644 --- a/java/core/src/test/org/apache/orc/impl/TestPhysicalFsWriter.java +++ b/java/core/src/test/org/apache/orc/impl/TestPhysicalFsWriter.java @@ -57,7 +57,7 @@ static class MemoryOutputStream extends OutputStream { } @Override - public void write(int b) throws IOException { + public void write(int b) { contents.add(new byte[]{(byte) b}); } @@ -97,12 +97,12 @@ public FSDataOutputStream create(Path f, FsPermission permission, @Override public FSDataOutputStream append(Path f, int bufferSize, - Progressable progress) throws IOException { + Progressable progress) { throw new UnsupportedOperationException("append not supported"); } @Override - public boolean rename(Path src, Path dst) throws IOException { + public boolean rename(Path src, Path dst) { boolean result = fileContents.containsKey(src) && !fileContents.containsKey(dst); if (result) { @@ -113,14 +113,14 @@ public boolean rename(Path src, Path dst) throws IOException { } @Override - public boolean delete(Path f, boolean recursive) throws IOException { + public boolean delete(Path f, boolean recursive) { boolean result = fileContents.containsKey(f); fileContents.remove(f); return result; } @Override - public FileStatus[] listStatus(Path f) throws IOException { + public FileStatus[] listStatus(Path f) { return new FileStatus[]{getFileStatus(f)}; } @@ -135,12 +135,12 @@ public Path 
getWorkingDirectory() { } @Override - public boolean mkdirs(Path f, FsPermission permission) throws IOException { + public boolean mkdirs(Path f, FsPermission permission) { return false; } @Override - public FileStatus getFileStatus(Path f) throws IOException { + public FileStatus getFileStatus(Path f) { List contents = fileContents.get(f); if (contents != null) { long sum = 0; @@ -262,7 +262,8 @@ public DirectDecompressor getDirectDecompressor(DirectCompressionType codec) { } @Override - public ZeroCopyReaderShim getZeroCopyReader(FSDataInputStream in, ByteBufferPoolShim pool) throws IOException { + public ZeroCopyReaderShim getZeroCopyReader(FSDataInputStream in, + ByteBufferPoolShim pool) { return null; } @@ -276,7 +277,7 @@ public boolean endVariableLengthBlock(OutputStream output) throws IOException { } @Override - public KeyProvider getKeyProvider(Configuration conf, Random random) throws IOException { + public KeyProvider getKeyProvider(Configuration conf, Random random) { return null; } } diff --git a/java/core/src/test/org/apache/orc/impl/TestRunLengthByteReader.java b/java/core/src/test/org/apache/orc/impl/TestRunLengthByteReader.java index 6887f26ae4..cc81dbc0be 100644 --- a/java/core/src/test/org/apache/orc/impl/TestRunLengthByteReader.java +++ b/java/core/src/test/org/apache/orc/impl/TestRunLengthByteReader.java @@ -48,7 +48,7 @@ public void testUncompressedSeek() throws Exception { collect.buffer.setByteBuffer(inBuf, 0, collect.buffer.size()); inBuf.flip(); RunLengthByteReader in = new RunLengthByteReader(InStream.create("test", - new BufferChunk(inBuf, 0), inBuf.remaining())); + new BufferChunk(inBuf, 0), 0, inBuf.remaining())); for(int i=0; i < 2048; ++i) { int x = in.next() & 0xff; if (i < 1024) { @@ -92,7 +92,7 @@ public void testCompressedSeek() throws Exception { collect.buffer.setByteBuffer(inBuf, 0, collect.buffer.size()); inBuf.flip(); RunLengthByteReader in = new RunLengthByteReader(InStream.create("test", - new BufferChunk(inBuf, 0), inBuf.remaining(), + new BufferChunk(inBuf, 0), 0, inBuf.remaining(), InStream.options().withCodec(codec).withBufferSize(500))); for(int i=0; i < 2048; ++i) { int x = in.next() & 0xff; @@ -130,7 +130,7 @@ public void testSkips() throws Exception { collect.buffer.setByteBuffer(inBuf, 0, collect.buffer.size()); inBuf.flip(); RunLengthByteReader in = new RunLengthByteReader(InStream.create("test", - new BufferChunk(inBuf, 0), inBuf.remaining())); + new BufferChunk(inBuf, 0), 0, inBuf.remaining())); for(int i=0; i < 2048; i += 10) { int x = in.next() & 0xff; if (i < 1024) { diff --git a/java/core/src/test/org/apache/orc/impl/TestRunLengthIntegerReader.java b/java/core/src/test/org/apache/orc/impl/TestRunLengthIntegerReader.java index 6b29152453..ac41e6e828 100644 --- a/java/core/src/test/org/apache/orc/impl/TestRunLengthIntegerReader.java +++ b/java/core/src/test/org/apache/orc/impl/TestRunLengthIntegerReader.java @@ -17,8 +17,6 @@ */ package org.apache.orc.impl; -import static junit.framework.Assert.assertEquals; - import java.nio.ByteBuffer; import java.util.Random; @@ -26,6 +24,8 @@ import org.apache.orc.impl.writer.StreamOptions; import org.junit.Test; +import static org.junit.Assert.assertEquals; + public class TestRunLengthIntegerReader { public void runSeekTest(CompressionCodec codec) throws Exception { @@ -60,7 +60,7 @@ public void runSeekTest(CompressionCodec codec) throws Exception { collect.buffer.setByteBuffer(inBuf, 0, collect.buffer.size()); inBuf.flip(); RunLengthIntegerReader in = new RunLengthIntegerReader(InStream.create 
- ("test", new BufferChunk(inBuf, 0), inBuf.remaining(), + ("test", new BufferChunk(inBuf, 0), 0, inBuf.remaining(), InStream.options().withCodec(codec).withBufferSize(1000)), true); for(int i=0; i < 2048; ++i) { int x = (int) in.next(); @@ -112,7 +112,7 @@ public void testSkips() throws Exception { collect.buffer.setByteBuffer(inBuf, 0, collect.buffer.size()); inBuf.flip(); RunLengthIntegerReader in = new RunLengthIntegerReader(InStream.create - ("test", new BufferChunk(inBuf, 0), inBuf.remaining()), true); + ("test", new BufferChunk(inBuf, 0), 0, inBuf.remaining()), true); for(int i=0; i < 2048; i += 10) { int x = (int) in.next(); if (i < 1024) { diff --git a/java/core/src/test/org/apache/orc/impl/TestSchemaEvolution.java b/java/core/src/test/org/apache/orc/impl/TestSchemaEvolution.java index 1102d121af..87478d405b 100644 --- a/java/core/src/test/org/apache/orc/impl/TestSchemaEvolution.java +++ b/java/core/src/test/org/apache/orc/impl/TestSchemaEvolution.java @@ -23,10 +23,8 @@ import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; import org.apache.hadoop.conf.Configuration; @@ -47,7 +45,6 @@ import org.apache.orc.TypeDescription; import org.apache.orc.Writer; import org.junit.Before; -import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TestName; @@ -1581,7 +1578,17 @@ static void createStream(Map streams, buffer[i] = (byte) values[i]; } ranges.add(new BufferChunk(ByteBuffer.wrap(buffer), 0)); - streams.put(name, InStream.create(name.toString(), ranges.get(), values.length)); + streams.put(name, InStream.create(name.toString(), ranges.get(), 0, + values.length)); + } + + static ByteBuffer createBuffer(int... values) { + ByteBuffer result = ByteBuffer.allocate(values.length); + for(int v: values) { + result.put((byte) v); + } + result.flip(); + return result; } @Test From 1a92b864aee5bca59a556c53203f180cc72373ec Mon Sep 17 00:00:00 2001 From: Owen O'Malley Date: Mon, 24 Jun 2019 13:53:19 -0700 Subject: [PATCH 4/4] ORC-523: Update ReaderImpl to work with column encryption. 
--- .../src/java/org/apache/orc/OrcUtils.java | 14 +- .../org/apache/orc/StripeInformation.java | 21 ++ .../src/java/org/apache/orc/impl/OrcTail.java | 29 +-- .../java/org/apache/orc/impl/ReaderImpl.java | 152 +++++++++++--- .../orc/impl/mask/SHA256MaskFactory.java | 14 +- .../orc/impl/reader/ReaderEncryption.java | 145 +++++++++++++ .../orc/impl/reader/ReaderEncryptionKey.java | 132 ++++++++++++ .../impl/reader/ReaderEncryptionVariant.java | 190 ++++++++++++++++++ .../apache/orc/impl/TestRecordReaderImpl.java | 4 +- 9 files changed, 646 insertions(+), 55 deletions(-) create mode 100644 java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java create mode 100644 java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionKey.java create mode 100644 java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java diff --git a/java/core/src/java/org/apache/orc/OrcUtils.java b/java/core/src/java/org/apache/orc/OrcUtils.java index 220fa1338e..0ba46dc032 100644 --- a/java/core/src/java/org/apache/orc/OrcUtils.java +++ b/java/core/src/java/org/apache/orc/OrcUtils.java @@ -616,9 +616,17 @@ TypeDescription convertTypeFromProtobuf(List types, public static List convertProtoStripesToStripes( List stripes) { - List result = new ArrayList(stripes.size()); - for (OrcProto.StripeInformation info : stripes) { - result.add(new ReaderImpl.StripeInformationImpl(info)); + List result = new ArrayList<>(stripes.size()); + long previousStripeId = -1; + byte[][] previousKeys = null; + long stripeId = 0; + for (OrcProto.StripeInformation stripeProto: stripes) { + ReaderImpl.StripeInformationImpl stripe = + new ReaderImpl.StripeInformationImpl(stripeProto, stripeId++, + previousStripeId, previousKeys); + result.add(stripe); + previousStripeId = stripe.getEncryptionStripeId(); + previousKeys = stripe.getEncryptedLocalKeys(); } return result; } diff --git a/java/core/src/java/org/apache/orc/StripeInformation.java b/java/core/src/java/org/apache/orc/StripeInformation.java index 38f7eba31e..6490d6bbe5 100644 --- a/java/core/src/java/org/apache/orc/StripeInformation.java +++ b/java/core/src/java/org/apache/orc/StripeInformation.java @@ -56,4 +56,25 @@ public interface StripeInformation { * @return a count of the number of rows */ long getNumberOfRows(); + + /** + * Get the index of this stripe in the current file. + * @return 0 to number_of_stripes - 1 + */ + long getStripeId(); + + /** + * Get the original stripe id that was used when the stripe was originally + * written. This is only different than getStripeId in merged files. + * @return the original stripe id + */ + long getEncryptionStripeId(); + + /** + * Get the encrypted keys starting from this stripe until overridden by + * a new set in a following stripe. The top level array is one for each + * encryption variant. Each element is an encrypted key. + * @return the array of encrypted keys + */ + byte[][] getEncryptedLocalKeys(); } diff --git a/java/core/src/java/org/apache/orc/impl/OrcTail.java b/java/core/src/java/org/apache/orc/impl/OrcTail.java index 9e8a5f21e0..2765473486 100644 --- a/java/core/src/java/org/apache/orc/impl/OrcTail.java +++ b/java/core/src/java/org/apache/orc/impl/OrcTail.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements.
See the NOTICE file * distributed with this work for additional information @@ -27,6 +27,7 @@ import org.apache.orc.CompressionKind; import org.apache.orc.OrcFile; import org.apache.orc.OrcProto; +import org.apache.orc.OrcUtils; import org.apache.orc.StripeInformation; import org.apache.orc.StripeStatistics; @@ -77,11 +78,7 @@ public OrcFile.WriterVersion getWriterVersion() { } public List getStripes() { - List result = new ArrayList<>(fileTail.getFooter().getStripesCount()); - for (OrcProto.StripeInformation stripeProto : fileTail.getFooter().getStripesList()) { - result.add(new ReaderImpl.StripeInformationImpl(stripeProto)); - } - return result; + return OrcUtils.convertProtoStripesToStripes(getFooter().getStripesList()); } public CompressionKind getCompressionKind() { @@ -92,9 +89,9 @@ public int getCompressionBufferSize() { return (int) fileTail.getPostscript().getCompressionBlockSize(); } - public List getStripeStatistics() throws IOException { + public List getStripeStatistics(InStream.StreamOptions options) throws IOException { List result = new ArrayList<>(); - List ssProto = getStripeStatisticsProto(); + List ssProto = getStripeStatisticsProto(options); if (ssProto != null) { for (OrcProto.StripeStatistics ss : ssProto) { result.add(new StripeStatistics(ss.getColStatsList())); @@ -103,17 +100,12 @@ public List getStripeStatistics() throws IOException { return result; } - public List getStripeStatisticsProto() throws IOException { + public List getStripeStatisticsProto(InStream.StreamOptions options) throws IOException { if (serializedTail == null) return null; if (metadata == null) { - CompressionCodec codec = OrcCodecPool.getCodec(getCompressionKind()); - try { - metadata = extractMetadata(serializedTail, 0, - (int) fileTail.getPostscript().getMetadataLength(), - InStream.options().withCodec(codec).withBufferSize(getCompressionBufferSize())); - } finally { - OrcCodecPool.returnCodec(getCompressionKind(), codec); - } + metadata = extractMetadata(serializedTail, 0, + (int) fileTail.getPostscript().getMetadataLength(), + options); // clear does not clear the contents but sets position to 0 and limit = capacity serializedTail.clear(); } @@ -137,7 +129,6 @@ public OrcProto.FileTail getMinimalFileTail() { OrcProto.Footer.Builder footerBuilder = OrcProto.Footer.newBuilder(fileTail.getFooter()); footerBuilder.clearStatistics(); fileTailBuilder.setFooter(footerBuilder.build()); - OrcProto.FileTail result = fileTailBuilder.build(); - return result; + return fileTailBuilder.build(); } } diff --git a/java/core/src/java/org/apache/orc/impl/ReaderImpl.java b/java/core/src/java/org/apache/orc/impl/ReaderImpl.java index ddd53b7312..d1311b9a81 100644 --- a/java/core/src/java/org/apache/orc/impl/ReaderImpl.java +++ b/java/core/src/java/org/apache/orc/impl/ReaderImpl.java @@ -20,13 +20,15 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.security.Key; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; -import org.apache.hadoop.fs.FileStatus; +import org.apache.orc.EncryptionAlgorithm; +import org.apache.orc.EncryptionKey; import org.apache.orc.CompressionKind; import org.apache.orc.DataMaskDescription; import org.apache.orc.EncryptionKey; @@ -43,8 +45,8 @@ import org.apache.orc.StripeInformation; import org.apache.orc.StripeStatistics; import org.apache.orc.UnknownFormatException; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import org.apache.orc.impl.reader.ReaderEncryption; +import 
org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FSDataInputStream; import org.apache.hadoop.fs.FileSystem; @@ -54,6 +56,9 @@ import org.apache.orc.OrcProto; import com.google.protobuf.CodedInputStream; +import org.apache.orc.impl.reader.ReaderEncryptionVariant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class ReaderImpl implements Reader { @@ -78,6 +83,7 @@ public class ReaderImpl implements Reader { private final List stripes; protected final int rowIndexStride; private final long contentLength, numberOfRows; + private final ReaderEncryption encryption; private long deserializedSize = -1; protected final Configuration conf; @@ -89,10 +95,30 @@ public class ReaderImpl implements Reader { public static class StripeInformationImpl implements StripeInformation { + private final long stripeId; + private final long originalStripeId; + private final byte[][] encryptedKeys; private final OrcProto.StripeInformation stripe; - public StripeInformationImpl(OrcProto.StripeInformation stripe) { + public StripeInformationImpl(OrcProto.StripeInformation stripe, + long stripeId, + long previousOriginalStripeId, + byte[][] previousKeys) { this.stripe = stripe; + this.stripeId = stripeId; + if (stripe.hasEncryptStripeId()) { + originalStripeId = stripe.getEncryptStripeId(); + } else { + originalStripeId = previousOriginalStripeId + 1; + } + if (stripe.getEncryptedLocalKeysCount() != 0) { + encryptedKeys = new byte[stripe.getEncryptedLocalKeysCount()][]; + for(int v=0; v < encryptedKeys.length; ++v) { + encryptedKeys[v] = stripe.getEncryptedLocalKeys(v).toByteArray(); + } + } else { + encryptedKeys = previousKeys; + } } @Override @@ -125,6 +151,21 @@ public long getNumberOfRows() { return stripe.getNumberOfRows(); } + @Override + public long getStripeId() { + return stripeId; + } + + @Override + public long getEncryptionStripeId() { + return originalStripeId; + } + + @Override + public byte[][] getEncryptedLocalKeys() { + return encryptedKeys; + } + @Override public String toString() { return "offset: " + getOffset() + " data: " + getDataLength() + @@ -221,20 +262,25 @@ public OrcProto.FileTail getFileTail() { @Override public EncryptionKey[] getColumnEncryptionKeys() { - // TODO - return new EncryptionKey[0]; + return encryption.getKeys(); } @Override public DataMaskDescription[] getDataMasks() { - // TODO - return new DataMaskDescription[0]; + return encryption.getMasks(); } @Override - public EncryptionVariant[] getEncryptionVariants() { - // TODO - return new EncryptionVariant[0]; + public ReaderEncryptionVariant[] getEncryptionVariants() { + return encryption.getVariants(); + } + + /** + * Internal access to our view of the encryption. + * @return the encryption information for this reader. 
+ */ + public ReaderEncryption getEncryption() { + return encryption; + } @Override @@ -244,7 +290,61 @@ public int getRowIndexStride() { @Override public ColumnStatistics[] getStatistics() { - return deserializeStats(schema, fileStats); + ColumnStatistics[] result = deserializeStats(schema, fileStats); + try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) { + InStream.StreamOptions compression = InStream.options(); + if (codec != null) { + compression.withCodec(codec).withBufferSize(bufferSize); + } + for (ReaderEncryptionVariant variant : encryption.getVariants()) { + ColumnStatistics[] overrides; + try { + overrides = decryptFileStats(variant, compression, + tail.getFooter()); + } catch (IOException e) { + throw new RuntimeException("Can't decrypt file stats for " + path + + " with " + variant.getKeyDescription(), e); + } + if (overrides != null) { + for (int i = 0; i < overrides.length; ++i) { + result[variant.getRoot().getId() + i] = overrides[i]; + } + } + } + } + return result; + } + + public static ColumnStatistics[] decryptFileStats(ReaderEncryptionVariant encryption, + InStream.StreamOptions compression, + OrcProto.Footer footer + ) throws IOException { + Key key = encryption.getFileFooterKey(); + if (key == null) { + return null; + } else { + OrcProto.EncryptionVariant protoVariant = + footer.getEncryption().getVariants(encryption.getVariantId()); + byte[] bytes = protoVariant.getFileStatistics().toByteArray(); + BufferChunk buffer = new BufferChunk(ByteBuffer.wrap(bytes), 0); + EncryptionAlgorithm algorithm = encryption.getKeyDescription().getAlgorithm(); + byte[] iv = new byte[algorithm.getIvLength()]; + CryptoUtils.modifyIvForStream(encryption.getRoot().getId(), + OrcProto.Stream.Kind.FILE_STATISTICS, footer.getStripesCount()) + .accept(iv); + InStream.StreamOptions options = new InStream.StreamOptions(compression) + .withEncryption(algorithm, key, iv); + InStream in = InStream.create("encrypted file stats", buffer, + 0, bytes.length, options); + OrcProto.FileStatistics decrypted = OrcProto.FileStatistics.parseFrom(in); + ColumnStatistics[] result = new ColumnStatistics[decrypted.getColumnCount()]; + TypeDescription root = encryption.getRoot(); + for(int i = 0; i < result.length; ++i) { + result[i] = ColumnStatisticsImpl.deserialize(root.findSubtype(root.getId() + i), + decrypted.getColumn(i)); + } + return result; + } } public static ColumnStatistics[] deserializeStats( @@ -351,12 +451,17 @@ public ReaderImpl(Path path, OrcFile.ReaderOptions options) throws IOException { this.writerVersion = OrcFile.WriterVersion.from(writer, fileMetadata.getWriterVersionNum()); this.types = fileMetadata.getTypes(); + OrcUtils.isValidTypeTree(this.types, 0); + this.schema = OrcUtils.convertTypeFromProtobuf(this.types, 0); this.rowIndexStride = fileMetadata.getRowIndexStride(); this.contentLength = fileMetadata.getContentLength(); this.numberOfRows = fileMetadata.getNumberOfRows(); this.fileStats = fileMetadata.getFileStats(); this.stripes = fileMetadata.getStripes(); this.userMetadata = null; // not cached and not needed here + // FileMetadata is obsolete and doesn't support encryption + this.encryption = new ReaderEncryption(null, schema, stripes, + options.getKeyProvider(), conf); } else { OrcTail orcTail = options.getOrcTail(); if (orcTail == null) { @@ -371,6 +476,8 @@ public ReaderImpl(Path path, OrcFile.ReaderOptions options) throws IOException { this.metadataSize = tail.getMetadataSize(); this.versionList = tail.getPostScript().getVersionList(); this.types =
tail.getFooter().getTypesList(); + OrcUtils.isValidTypeTree(this.types, 0); + this.schema = OrcUtils.convertTypeFromProtobuf(this.types, 0); this.rowIndexStride = tail.getFooter().getRowIndexStride(); this.contentLength = tail.getFooter().getContentLength(); this.numberOfRows = tail.getFooter().getNumberOfRows(); @@ -378,10 +485,14 @@ public ReaderImpl(Path path, OrcFile.ReaderOptions options) throws IOException { this.fileStats = tail.getFooter().getStatisticsList(); this.writerVersion = tail.getWriterVersion(); this.stripes = tail.getStripes(); - this.stripeStats = tail.getStripeStatisticsProto(); + try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) { + InStream.StreamOptions compress = InStream.options().withCodec(codec) + .withBufferSize(bufferSize); + this.stripeStats = tail.getStripeStatisticsProto(compress); + } + this.encryption = new ReaderEncryption(tail.getFooter(), schema, + stripes, options.getKeyProvider(), conf); } - OrcUtils.isValidTypeTree(this.types, 0); - this.schema = OrcUtils.convertTypeFromProtobuf(this.types, 0); } protected FileSystem getFileSystem() throws IOException { @@ -559,12 +670,9 @@ protected OrcTail extractFileTail(FileSystem fs, Path path, ByteBuffer footerBuffer = buffer.slice(); buffer.reset(); OrcProto.Footer footer; - CompressionCodec codec = OrcCodecPool.getCodec(compressionKind); - try { + try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) { footer = extractFooter(footerBuffer, 0, footerSize, InStream.options().withCodec(codec).withBufferSize(bufferSize)); - } finally { - OrcCodecPool.returnCodec(compressionKind, codec); } fileTailBuilder.setFooter(footer); } catch (Throwable thr) { @@ -604,7 +712,6 @@ public RecordReader rows(Options options) throws IOException { return new RecordReaderImpl(this, options); } - @Override public long getRawDataSize() { // if the deserializedSize is not computed, then compute it, else @@ -757,12 +864,9 @@ public List getOrcProtoFileStatistics() { @Override public List getStripeStatistics() throws IOException { if (metadata == null) { - CompressionCodec codec = OrcCodecPool.getCodec(compressionKind); - try { + try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) { metadata = extractMetadata(tail.getSerializedTail(), 0, metadataSize, InStream.options().withCodec(codec).withBufferSize(bufferSize)); - } finally { - OrcCodecPool.returnCodec(compressionKind, codec); } } if (stripeStats == null) { diff --git a/java/core/src/java/org/apache/orc/impl/mask/SHA256MaskFactory.java b/java/core/src/java/org/apache/orc/impl/mask/SHA256MaskFactory.java index c300500384..b3445f98f3 100644 --- a/java/core/src/java/org/apache/orc/impl/mask/SHA256MaskFactory.java +++ b/java/core/src/java/org/apache/orc/impl/mask/SHA256MaskFactory.java @@ -62,9 +62,9 @@ */ public class SHA256MaskFactory extends MaskFactory { - final MessageDigest md; + private final MessageDigest md; - public SHA256MaskFactory(final String... params) { + SHA256MaskFactory() { super(); try { md = MessageDigest.getInstance("SHA-256"); @@ -138,9 +138,9 @@ void maskString(final BytesColumnVector source, final int row, /** * Helper function to mask binary data with its SHA-256 hash.
* - * @param source - * @param row - * @param target + * @param source the source data + * @param row the row that we are translating + * @param target the output data */ void maskBinary(final BytesColumnVector source, final int row, final BytesColumnVector target) { @@ -207,7 +207,7 @@ class StringMask implements DataMask { final TypeDescription schema; /* create an instance */ - public StringMask(TypeDescription schema) { + StringMask(TypeDescription schema) { super(); this.schema = schema; } @@ -254,7 +254,7 @@ public void maskData(final ColumnVector original, final ColumnVector masked, class BinaryMask implements DataMask { /* create an instance */ - public BinaryMask() { + BinaryMask() { super(); } diff --git a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java new file mode 100644 index 0000000000..fe54c49b0f --- /dev/null +++ b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryption.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *
+ * http://www.apache.org/licenses/LICENSE-2.0 + *
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.orc.impl.reader; + +import org.apache.hadoop.conf.Configuration; +import org.apache.orc.OrcProto; +import org.apache.orc.StripeInformation; +import org.apache.orc.TypeDescription; +import org.apache.orc.impl.HadoopShims; +import org.apache.orc.impl.HadoopShimsFactory; +import org.apache.orc.impl.MaskDescriptionImpl; + +import java.io.IOException; +import java.security.SecureRandom; +import java.util.Arrays; +import java.util.List; + +public class ReaderEncryption { + private final HadoopShims.KeyProvider keyProvider; + private final ReaderEncryptionKey[] keys; + private final MaskDescriptionImpl[] masks; + private final ReaderEncryptionVariant[] variants; + // Mapping from each column to the next variant to try for that column. + // A value of null means no encryption. + private final ReaderEncryptionVariant[] columnVariants; + + public ReaderEncryption() throws IOException { + this(null, null, null, null, null); + } + + public ReaderEncryption(OrcProto.Footer footer, + TypeDescription schema, + List stripes, + HadoopShims.KeyProvider provider, + Configuration conf) throws IOException { + if (footer == null || !footer.hasEncryption()) { + keyProvider = null; + keys = new ReaderEncryptionKey[0]; + masks = new MaskDescriptionImpl[0]; + variants = new ReaderEncryptionVariant[0]; + columnVariants = null; + } else { + keyProvider = provider != null ? provider : + HadoopShimsFactory.get().getKeyProvider(conf, new SecureRandom()); + OrcProto.Encryption encrypt = footer.getEncryption(); + masks = new MaskDescriptionImpl[encrypt.getMaskCount()]; + for(int m=0; m < masks.length; ++m) { + masks[m] = new MaskDescriptionImpl(m, encrypt.getMask(m)); + } + keys = new ReaderEncryptionKey[encrypt.getKeyCount()]; + for(int k=0; k < keys.length; ++k) { + keys[k] = new ReaderEncryptionKey(encrypt.getKey(k)); + } + variants = new ReaderEncryptionVariant[encrypt.getVariantsCount()]; + for(int v=0; v < variants.length; ++v) { + OrcProto.EncryptionVariant variant = encrypt.getVariants(v); + variants[v] = new ReaderEncryptionVariant(keys[variant.getKey()], v, + variant, schema, stripes, keyProvider); + } + columnVariants = new ReaderEncryptionVariant[schema.getMaximumId() + 1]; + for(int v = 0; v < variants.length; ++v) { + TypeDescription root = variants[v].getRoot(); + for(int c = root.getId(); c <= root.getMaximumId(); ++c) { + if (columnVariants[c] == null) { + columnVariants[c] = variants[v]; + } + } + } + } + } + + public MaskDescriptionImpl[] getMasks() { + return masks; + } + + public ReaderEncryptionKey[] getKeys() { + return keys; + } + + public ReaderEncryptionVariant[] getVariants() { + return variants; + } + + /** + * Find the next possible variant in this file for the given column.
+ * @param column the column to find a variant for + * @param lastVariant the previous variant that we looked at + * @return the next variant or null if there are none + */ + private ReaderEncryptionVariant findNextVariant(int column, + int lastVariant) { + for(int v = lastVariant + 1; v < variants.length; ++v) { + TypeDescription root = variants[v].getRoot(); + if (root.getId() <= column && column <= root.getMaximumId()) { + return variants[v]; + } + } + return null; + } + + /** + * Get the variant for a given column that the user has access to. + * If we haven't tried a given key, try to decrypt this variant's footer key + * to see if the KeyProvider will give it to us. If not, continue to the + * next variant. + * @param column the column id + * @return null for no encryption or the encryption variant + */ + public ReaderEncryptionVariant getVariant(int column) throws IOException { + if (keyProvider != null) { + while (columnVariants[column] != null) { + ReaderEncryptionVariant result = columnVariants[column]; + switch (result.getKeyDescription().getState()) { + case FAILURE: + break; + case SUCCESS: + return result; + case UNTRIED: + // try to get the footer key, to see if we have access + if (result.getFileFooterKey() != null) { + return result; + } + } + columnVariants[column] = findNextVariant(column, result.getVariantId()); + } + } + return null; + } +} diff --git a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionKey.java b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionKey.java new file mode 100644 index 0000000000..407bebb36a --- /dev/null +++ b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionKey.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *
+ * http://www.apache.org/licenses/LICENSE-2.0 + *
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.orc.impl.reader; + +import org.apache.orc.EncryptionKey; +import org.apache.orc.EncryptionAlgorithm; +import org.apache.orc.OrcProto; +import org.apache.orc.impl.HadoopShims; +import org.jetbrains.annotations.NotNull; + +import java.util.ArrayList; +import java.util.List; + +/** + * This tracks the keys for reading encrypted columns. + */ +public class ReaderEncryptionKey implements EncryptionKey { + private final String name; + private final int version; + private final EncryptionAlgorithm algorithm; + private final List roots = new ArrayList<>(); + + /** + * Store the state of whether we've tried to decrypt a local key using this + * key or not. If it fails the first time, we assume the user doesn't have + * permission and move on. However, we don't want to retry the same failed + * key over and over again. + */ + public enum State { + UNTRIED, + FAILURE, + SUCCESS + } + + private State state = State.UNTRIED; + + public ReaderEncryptionKey(OrcProto.EncryptionKey key) { + name = key.getKeyName(); + version = key.getKeyVersion(); + algorithm = + EncryptionAlgorithm.fromSerialization(key.getAlgorithm().getNumber()); + } + + @Override + public String getKeyName() { + return name; + } + + @Override + public int getKeyVersion() { + return version; + } + + @Override + public EncryptionAlgorithm getAlgorithm() { + return algorithm; + } + + @Override + public ReaderEncryptionVariant[] getEncryptionRoots() { + return roots.toArray(new ReaderEncryptionVariant[roots.size()]); + } + + public HadoopShims.KeyMetadata getMetadata() { + return new HadoopShims.KeyMetadata(name, version, algorithm); + } + + public State getState() { + return state; + } + + public void setFailure() { + state = State.FAILURE; + } + + public void setSuccess() { + if (state == State.FAILURE) { + throw new IllegalStateException("Key " + name + " had already failed."); + } + state = State.SUCCESS; + } + + void addVariant(ReaderEncryptionVariant newVariant) { + roots.add(newVariant); + } + + @Override + public boolean equals(Object other) { + if (other == null || getClass() != other.getClass()) { + return false; + } else if (other == this) { + return true; + } else { + return compareTo((EncryptionKey) other) == 0; + } + } + + @Override + public int hashCode() { + return name.hashCode() * 127 + version * 7 + algorithm.hashCode(); + } + + @Override + public int compareTo(@NotNull EncryptionKey other) { + int result = name.compareTo(other.getKeyName()); + if (result == 0) { + result = Integer.compare(version, other.getKeyVersion()); + } + return result; + } + + @Override + public String toString() { + return name + "@" + version + " w/ " + algorithm; + } +} diff --git a/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java new file mode 100644 index 0000000000..255952d15c --- /dev/null +++ b/java/core/src/java/org/apache/orc/impl/reader/ReaderEncryptionVariant.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *
+ * http://www.apache.org/licenses/LICENSE-2.0 + *
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.orc.impl.reader; + +import org.apache.hadoop.io.BytesWritable; +import org.apache.orc.EncryptionAlgorithm; +import org.apache.orc.EncryptionVariant; +import org.apache.orc.OrcProto; +import org.apache.orc.StripeInformation; +import org.apache.orc.TypeDescription; +import org.apache.orc.impl.HadoopShims; +import org.apache.orc.impl.LocalKey; +import org.jetbrains.annotations.NotNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.security.Key; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Information about an encrypted column. + */ +public class ReaderEncryptionVariant implements EncryptionVariant { + private static final Logger LOG = + LoggerFactory.getLogger(ReaderEncryptionVariant.class); + private final HadoopShims.KeyProvider provider; + private final ReaderEncryptionKey key; + private final TypeDescription column; + private final int variantId; + private final LocalKey[] localKeys; + private final LocalKey footerKey; + + /** + * Create a reader's view of an encryption variant. + * @param key the encryption key description + * @param variantId the id of the variant (0..N-1) + * @param proto the serialized description of the variant + * @param schema the file schema + * @param stripes the stripe information + * @param provider the key provider + */ + public ReaderEncryptionVariant(ReaderEncryptionKey key, + int variantId, + OrcProto.EncryptionVariant proto, + TypeDescription schema, + List stripes, + HadoopShims.KeyProvider provider) { + this.key = key; + this.variantId = variantId; + this.provider = provider; + this.column = proto.hasRoot() ? schema.findSubtype(proto.getRoot()) : null; + this.localKeys = new LocalKey[stripes.size()]; + HashMap cache = new HashMap<>(); + for(int s=0; s < localKeys.length; ++s) { + StripeInformation stripe = stripes.get(s); + localKeys[s] = getCachedKey(cache, key.getAlgorithm(), + stripe.getEncryptedLocalKeys()[variantId]); + } + if (proto.hasEncryptedKey()) { + footerKey = getCachedKey(cache, key.getAlgorithm(), + proto.getEncryptedKey().toByteArray()); + } else { + footerKey = null; + } + key.addVariant(this); + } + + @Override + public ReaderEncryptionKey getKeyDescription() { + return key; + } + + @Override + public TypeDescription getRoot() { + return column; + } + + @Override + public int getVariantId() { + return variantId; + } + + /** + * Deduplicate the local keys so that we only decrypt each local key once.
+ * @param cache the cache to use + * @param encrypted the encrypted key + * @return the local key + */ + private static LocalKey getCachedKey(Map cache, + EncryptionAlgorithm algorithm, + byte[] encrypted) { + // wrap byte array in BytesWritable to get equality and hash + BytesWritable wrap = new BytesWritable(encrypted); + LocalKey result = cache.get(wrap); + if (result == null) { + result = new LocalKey(algorithm, null, encrypted); + cache.put(wrap, result); + } + return result; + } + + private Key getDecryptedKey(LocalKey localKey) throws IOException { + Key result = localKey.getDecryptedKey(); + if (result == null) { + switch (this.key.getState()) { + case UNTRIED: + try { + result = provider.decryptLocalKey(key.getMetadata(), + localKey.getEncryptedKey()); + } catch (IOException ioe) { + LOG.info("Can't decrypt using key {}", key); + } + if (result != null) { + localKey.setDecryptedKey(result); + key.setSuccess(); + } else { + key.setFailure(); + } + break; + case SUCCESS: + result = provider.decryptLocalKey(key.getMetadata(), + localKey.getEncryptedKey()); + if (result == null) { + throw new IOException("Can't decrypt local key " + key); + } + localKey.setDecryptedKey(result); + break; + case FAILURE: + return null; + } + } + return result; + } + + @Override + public Key getFileFooterKey() throws IOException { + return footerKey == null ? null : getDecryptedKey(footerKey); + } + + @Override + public Key getStripeKey(long stripe) throws IOException { + return getDecryptedKey(localKeys[(int) stripe]); + } + + @Override + public boolean equals(Object other) { + if (other == null || other.getClass() != getClass()) { + return false; + } else { + return compareTo((EncryptionVariant) other) == 0; + } + } + + @Override + public int hashCode() { + return key.hashCode() * 127 + column.getId(); + } + + @Override + public int compareTo(@NotNull EncryptionVariant other) { + if (other == this) { + return 0; + } else if (key == other.getKeyDescription()) { + return Integer.compare(column.getId(), other.getRoot().getId()); + } else if (key == null) { + return -1; + } else { + return key.compareTo(other.getKeyDescription()); + } + } +} diff --git a/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java b/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java index 2e78196ce7..9eaaa61679 100644 --- a/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java +++ b/java/core/src/test/org/apache/orc/impl/TestRecordReaderImpl.java @@ -2104,7 +2104,7 @@ public void testPickRowGroups() throws Exception { .setKind(OrcProto.ColumnEncoding.Kind.DIRECT_V2).build()); encodings.add(OrcProto.ColumnEncoding.newBuilder() .setKind(OrcProto.ColumnEncoding.Kind.DIRECT_V2).build()); - boolean[] rows = applier.pickRowGroups(new ReaderImpl.StripeInformationImpl(stripe), + boolean[] rows = applier.pickRowGroups(new ReaderImpl.StripeInformationImpl(stripe, 0, 1, null), indexes, null, encodings, null, false); assertEquals(4, rows.length); assertEquals(false, rows[0]); @@ -2150,7 +2150,7 @@ public void testPickRowGroupsError() throws Exception { .setKind(OrcProto.ColumnEncoding.Kind.DIRECT_V2).build()); encodings.add(OrcProto.ColumnEncoding.newBuilder() .setKind(OrcProto.ColumnEncoding.Kind.DIRECT_V2).build()); - boolean[] rows = applier.pickRowGroups(new ReaderImpl.StripeInformationImpl(stripe), + boolean[] rows = applier.pickRowGroups(new ReaderImpl.StripeInformationImpl(stripe, 0, 1, null), indexes, null, encodings, null, false); assertEquals(3, rows.length); assertEquals(false, rows[0]);
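
Across this series, InStream.create now takes the stream's offset within its
buffer chunks ahead of the stream length, which is what the mechanical test
updates above reflect. A minimal sketch of the five-argument form (the bytes
here are made up for illustration):

    import java.nio.ByteBuffer;
    import org.apache.orc.impl.BufferChunk;
    import org.apache.orc.impl.InStream;

    public class CreateInStream {
      public static void main(String[] args) throws Exception {
        // stand-in bytes for an uncompressed, unencrypted ORC stream
        ByteBuffer data = ByteBuffer.wrap(new byte[]{1, 2, 3, 4});
        InStream in = InStream.create("example",
            new BufferChunk(data, 0), // chunk list holding the bytes
            0,                        // offset of the stream within the chunks
            data.remaining(),         // length of the stream
            InStream.options());      // defaults: no codec, no encryption
        System.out.println(in.available() + " bytes available");
      }
    }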