From df79ad1457f22346cc195962f2bed481f526a587 Mon Sep 17 00:00:00 2001 From: Owen O'Malley Date: Mon, 17 Jun 2019 16:14:03 +0200 Subject: [PATCH] ORC-484: Update PhysicalWriter to support encryption. Fixes #402 Signed-off-by: Owen O'Malley --- .../java/org/apache/orc/PhysicalWriter.java | 35 +- .../org/apache/orc/impl/PhysicalFsWriter.java | 486 +++++++++++++++--- .../java/org/apache/orc/impl/WriterImpl.java | 35 +- .../orc/impl/writer/ListTreeWriter.java | 13 +- .../apache/orc/impl/writer/MapTreeWriter.java | 16 +- .../orc/impl/writer/StructTreeWriter.java | 14 +- .../apache/orc/impl/writer/TreeWriter.java | 11 +- .../orc/impl/writer/TreeWriterBase.java | 17 +- .../orc/impl/writer/UnionTreeWriter.java | 15 +- .../apache/orc/impl/writer/WriterContext.java | 9 + .../orc/impl/writer/WriterEncryptionKey.java | 105 ++++ .../impl/writer/WriterEncryptionVariant.java | 128 +++++ .../org/apache/orc/TestStringDictionary.java | 5 + .../org/apache/orc/TestVectorOrcFile.java | 2 +- 14 files changed, 772 insertions(+), 119 deletions(-) create mode 100644 java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionKey.java create mode 100644 java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionVariant.java diff --git a/java/core/src/java/org/apache/orc/PhysicalWriter.java b/java/core/src/java/org/apache/orc/PhysicalWriter.java index e25e81c046..872d5dcde7 100644 --- a/java/core/src/java/org/apache/orc/PhysicalWriter.java +++ b/java/core/src/java/org/apache/orc/PhysicalWriter.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -22,6 +22,8 @@ import java.nio.ByteBuffer; import org.apache.orc.impl.StreamName; +import org.apache.orc.impl.writer.StreamOptions; +import org.apache.orc.impl.writer.WriterEncryptionVariant; /** * This interface separates the physical layout of ORC files from the higher @@ -39,7 +41,6 @@ interface OutputReceiver { * Output the given buffer to the final destination * * @param buffer the buffer to output - * @throws IOException */ void output(ByteBuffer buffer) throws IOException; @@ -48,16 +49,15 @@ interface OutputReceiver { */ void suppress(); } + /** * Writes the header of the file, which consists of the magic "ORC" bytes. - * @throws IOException */ void writeHeader() throws IOException; /** * Create an OutputReceiver for the given name. * @param name the name of the stream - * @throws IOException */ OutputReceiver createDataStream(StreamName name) throws IOException; @@ -65,7 +65,6 @@ interface OutputReceiver { * Write an index in the given stream name. * @param name the name of the stream * @param index the bloom filter to write - * @param codec the compression codec to use */ void writeIndex(StreamName name, OrcProto.RowIndex.Builder index) throws IOException; @@ -74,7 +73,6 @@ void writeIndex(StreamName name, * Write a bloom filter index in the given stream name. * @param name the name of the stream * @param bloom the bloom filter to write - * @param codec the compression codec to use */ void writeBloomFilter(StreamName name, OrcProto.BloomFilterIndex.Builder bloom) throws IOException; @@ -89,6 +87,16 @@ void writeBloomFilter(StreamName name, void finalizeStripe(OrcProto.StripeFooter.Builder footer, OrcProto.StripeInformation.Builder dirEntry) throws IOException; + /** + * Write a stripe or file statistics to the file. + * @param name the name of the stream + * @param statistics the statistics to write + * @throws IOException + */ + void writeStatistics(StreamName name, + OrcProto.ColumnStatistics.Builder statistics + ) throws IOException; + /** * Writes out the file metadata. * @param builder Metadata builder to finalize and write. @@ -122,19 +130,24 @@ void finalizeStripe(OrcProto.StripeFooter.Builder footer, * @param stripe Stripe data buffer. * @param dirEntry File metadata entry for the stripe, to be updated with * relevant data. - * @throws IOException */ void appendRawStripe(ByteBuffer stripe, OrcProto.StripeInformation.Builder dirEntry ) throws IOException; - /** Gets a compression codec used by this writer. */ - CompressionCodec getCompressionCodec(); - /** * Get the number of bytes for a file in a givem column. * @param column column from which to get file size + * @param variant the encryption variant to check * @return number of bytes for the given column */ - long getFileBytes(int column); + long getFileBytes(int column, WriterEncryptionVariant variant); + + /** + * Get the unencrypted stream options for this file. This class needs the + * stream options to write the indexes and footers. + * + * Additionally, the LLAP CacheWriter wants to disable the generic compression. + */ + StreamOptions getStreamOptions(); } diff --git a/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java b/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java index aee16b16db..4736a63052 100644 --- a/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java +++ b/java/core/src/java/org/apache/orc/impl/PhysicalFsWriter.java @@ -22,19 +22,25 @@ import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.TreeMap; +import com.google.protobuf.ByteString; import com.google.protobuf.CodedOutputStream; import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.orc.CompressionCodec; +import org.apache.orc.EncryptionVariant; import org.apache.orc.OrcFile; import org.apache.orc.OrcProto; import org.apache.orc.PhysicalWriter; +import org.apache.orc.TypeDescription; +import org.apache.orc.impl.writer.WriterEncryptionKey; +import org.apache.orc.impl.writer.WriterEncryptionVariant; import org.apache.orc.impl.writer.StreamOptions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,10 +52,12 @@ public class PhysicalFsWriter implements PhysicalWriter { private static final int HDFS_BUFFER_SIZE = 256 * 1024; private FSDataOutputStream rawWriter; + private final DirectStream rawStream; + // the compressed metadata information outStream - private OutStream writer; + private OutStream compressStream; // a protobuf outStream around streamFactory - private CodedOutputStream protobufWriter; + private CodedOutputStream codedCompressStream; private final Path path; private final HadoopShims shims; @@ -59,10 +67,7 @@ public class PhysicalFsWriter implements PhysicalWriter { private final OrcFile.CompressionStrategy compressionStrategy; private final boolean addBlockPadding; private final boolean writeVariableLengthBlocks; - - // the streams that make up the current stripe - private final Map streams = - new TreeMap<>(); + private final VariantTracker unencrypted; private long headerLength; private long stripeStart; @@ -70,11 +75,24 @@ public class PhysicalFsWriter implements PhysicalWriter { // natural blocks private long blockOffset; private int metadataLength; + private int stripeStatisticsLength = 0; private int footerLength; + private int stripeNumber = 0; + + private final Map variants = new TreeMap<>(); + + public PhysicalFsWriter(FileSystem fs, + Path path, + OrcFile.WriterOptions opts + ) throws IOException { + this(fs, path, opts, new WriterEncryptionVariant[0]); + } public PhysicalFsWriter(FileSystem fs, Path path, - OrcFile.WriterOptions opts) throws IOException { + OrcFile.WriterOptions opts, + WriterEncryptionVariant[] encryption + ) throws IOException { this.path = path; long defaultStripeSize = opts.getStripeSize(); this.addBlockPadding = opts.getBlockPadding(); @@ -98,16 +116,124 @@ public PhysicalFsWriter(FileSystem fs, rawWriter = fs.create(path, opts.getOverwrite(), HDFS_BUFFER_SIZE, fs.getDefaultReplication(path), blockSize); blockOffset = 0; - writer = new OutStream("metadata", compress, - new DirectStream(rawWriter)); - protobufWriter = CodedOutputStream.newInstance(writer); + unencrypted = new VariantTracker(opts.getSchema(), compress); writeVariableLengthBlocks = opts.getWriteVariableLengthBlocks(); shims = opts.getHadoopShims(); + rawStream = new DirectStream(rawWriter); + compressStream = new OutStream("stripe footer", compress, rawStream); + codedCompressStream = CodedOutputStream.newInstance(compressStream); + for(WriterEncryptionVariant variant: encryption) { + WriterEncryptionKey key = variant.getKeyDescription(); + StreamOptions encryptOptions = + new StreamOptions(unencrypted.options) + .withEncryption(key.getAlgorithm(), variant.getFileFooterKey()); + variants.put(variant, new VariantTracker(variant.getRoot(), encryptOptions)); + } } - @Override - public CompressionCodec getCompressionCodec() { - return compress.getCodec(); + /** + * Record the information about each column encryption variant. + * The unencrypted data and each encrypted column root are variants. + */ + protected static class VariantTracker { + // the streams that make up the current stripe + protected final Map streams = new TreeMap<>(); + private final int rootColumn; + private final int lastColumn; + protected final StreamOptions options; + // a list for each column covered by this variant + // the elements in the list correspond to each stripe in the file + protected final List[] stripeStats; + protected final List stripeStatsStreams = new ArrayList<>(); + protected final OrcProto.ColumnStatistics[] fileStats; + + VariantTracker(TypeDescription schema, StreamOptions options) { + rootColumn = schema.getId(); + lastColumn = schema.getMaximumId(); + this.options = options; + stripeStats = new List[schema.getMaximumId() - schema.getId() + 1]; + for(int i=0; i < stripeStats.length; ++i) { + stripeStats[i] = new ArrayList<>(); + } + fileStats = new OrcProto.ColumnStatistics[stripeStats.length]; + } + + public BufferedStream createStream(StreamName name) { + BufferedStream result = new BufferedStream(); + streams.put(name, result); + return result; + } + + /** + * Place the streams in the appropriate area while updating the sizes + * with the number of bytes in the area. + * @param area the area to write + * @param sizes the sizes of the areas + * @return the list of stream descriptions to add + */ + public List placeStreams(StreamName.Area area, + SizeCounters sizes) { + List result = new ArrayList<>(streams.size()); + for(Map.Entry stream: streams.entrySet()) { + StreamName name = stream.getKey(); + BufferedStream bytes = stream.getValue(); + if (name.getArea() == area && !bytes.isSuppressed) { + OrcProto.Stream.Builder builder = OrcProto.Stream.newBuilder(); + long size = bytes.getOutputSize(); + if (area == StreamName.Area.INDEX) { + sizes.index += size; + } else { + sizes.data += size; + } + builder.setColumn(name.getColumn()) + .setKind(name.getKind()) + .setLength(size); + result.add(builder.build()); + } + } + return result; + } + + /** + * Write the streams in the appropriate area. + * @param area the area to write + * @param raw the raw stream to write to + */ + public void writeStreams(StreamName.Area area, + FSDataOutputStream raw) throws IOException { + for(Map.Entry stream: streams.entrySet()) { + if (stream.getKey().getArea() == area) { + stream.getValue().spillToDiskAndClear(raw); + } + } + } + + /** + * Computed the size of the given column on disk for this stripe. + * It excludes the index streams. + * @param column a column id + * @return the total number of bytes + */ + public long getFileBytes(int column) { + long result = 0; + if (column >= rootColumn && column <= lastColumn) { + for(Map.Entry entry: streams.entrySet()) { + StreamName name = entry.getKey(); + if (name.getColumn() == column && + name.getArea() != StreamName.Area.INDEX) { + result += entry.getValue().getOutputSize(); + } + } + } + return result; + } + } + + VariantTracker getVariant(EncryptionVariant column) { + if (column == null) { + return unencrypted; + } + return variants.get(column); } /** @@ -120,20 +246,13 @@ public CompressionCodec getCompressionCodec() { * @return number of bytes for the given column */ @Override - public long getFileBytes(final int column) { - long size = 0; - for (final Map.Entry pair: streams.entrySet()) { - final BufferedStream receiver = pair.getValue(); - if(!receiver.isSuppressed) { - - final StreamName name = pair.getKey(); - if(name.getColumn() == column && name.getArea() != StreamName.Area.INDEX ) { - size += receiver.getOutputSize(); - } - } + public long getFileBytes(int column, WriterEncryptionVariant variant) { + return getVariant(variant).getFileBytes(column); + } - } - return size; + @Override + public StreamOptions getStreamOptions() { + return unencrypted.options; } private static final byte[] ZEROS = new byte[64*1024]; @@ -198,36 +317,139 @@ public void suppress() { } private void writeStripeFooter(OrcProto.StripeFooter footer, - long dataSize, - long indexSize, + SizeCounters sizes, OrcProto.StripeInformation.Builder dirEntry) throws IOException { - footer.writeTo(protobufWriter); - protobufWriter.flush(); - writer.flush(); + footer.writeTo(codedCompressStream); + codedCompressStream.flush(); + compressStream.flush(); dirEntry.setOffset(stripeStart); - dirEntry.setFooterLength(rawWriter.getPos() - stripeStart - dataSize - indexSize); + dirEntry.setFooterLength(rawWriter.getPos() - stripeStart - sizes.total()); + } + + /** + * Write the saved encrypted stripe statistic in a variant out to the file. + * The streams that are written are added to the tracker.stripeStatsStreams. + * @param output the file we are writing to + * @param stripeNumber the number of stripes in the file + * @param tracker the variant to write out + */ + static void writeEncryptedStripeStatistics(DirectStream output, + int stripeNumber, + VariantTracker tracker + ) throws IOException { + StreamOptions options = new StreamOptions(tracker.options); + tracker.stripeStatsStreams.clear(); + for(int col = tracker.rootColumn; + col < tracker.rootColumn + tracker.stripeStats.length; ++col) { + options.modifyIv(CryptoUtils.modifyIvForStream(col, + OrcProto.Stream.Kind.STRIPE_STATISTICS, stripeNumber)); + OutStream stream = new OutStream("stripe stats for " + col, + options, output); + OrcProto.ColumnarStripeStatistics stats = + OrcProto.ColumnarStripeStatistics.newBuilder() + .addAllColStats(tracker.stripeStats[col - tracker.rootColumn]) + .build(); + long start = output.output.getPos(); + stats.writeTo(stream); + stream.flush(); + OrcProto.Stream description = OrcProto.Stream.newBuilder() + .setColumn(col) + .setKind(OrcProto.Stream.Kind.STRIPE_STATISTICS) + .setLength(output.output.getPos() - start) + .build(); + tracker.stripeStatsStreams.add(description); + } + } + + /** + * Merge the saved unencrypted stripe statistics into the Metadata section + * of the footer. + * @param builder the Metadata section of the file + * @param stripeCount the number of stripes in the file + * @param stats the stripe statistics + */ + static void setUnencryptedStripeStatistics(OrcProto.Metadata.Builder builder, + int stripeCount, + List[] stats) { + // Make the unencrypted stripe stats into lists of StripeStatistics. + builder.clearStripeStats(); + for(int s=0; s < stripeCount; ++s) { + OrcProto.StripeStatistics.Builder stripeStats = + OrcProto.StripeStatistics.newBuilder(); + for(List col: stats) { + stripeStats.addColStats(col.get(s)); + } + builder.addStripeStats(stripeStats.build()); + } + } + + static void setEncryptionStatistics(OrcProto.Encryption.Builder encryption, + int stripeNumber, + Collection variants + ) throws IOException { + int v = 0; + for(VariantTracker variant: variants) { + OrcProto.EncryptionVariant.Builder variantBuilder = + encryption.getVariantsBuilder(v++); + + // Add the stripe statistics streams to the variant description. + variantBuilder.clearStripeStatistics(); + variantBuilder.addAllStripeStatistics(variant.stripeStatsStreams); + + // Serialize and encrypt the file statistics. + OrcProto.FileStatistics.Builder file = OrcProto.FileStatistics.newBuilder(); + for(OrcProto.ColumnStatistics col: variant.fileStats) { + file.addColumn(col); + } + StreamOptions options = new StreamOptions(variant.options); + options.modifyIv(CryptoUtils.modifyIvForStream(variant.rootColumn, + OrcProto.Stream.Kind.FILE_STATISTICS, stripeNumber)); + BufferedStream buffer = new BufferedStream(); + OutStream stream = new OutStream("stats for " + variant, options, buffer); + file.build().writeTo(stream); + stream.flush(); + variantBuilder.setFileStatistics(buffer.getBytes()); + } } @Override public void writeFileMetadata(OrcProto.Metadata.Builder builder) throws IOException { - long startPosn = rawWriter.getPos(); - OrcProto.Metadata metadata = builder.build(); - metadata.writeTo(protobufWriter); - protobufWriter.flush(); - writer.flush(); - this.metadataLength = (int) (rawWriter.getPos() - startPosn); + long stripeStatisticsStart = rawWriter.getPos(); + for(VariantTracker variant: variants.values()) { + writeEncryptedStripeStatistics(rawStream, stripeNumber, variant); + } + setUnencryptedStripeStatistics(builder, stripeNumber, unencrypted.stripeStats); + long metadataStart = rawWriter.getPos(); + builder.build().writeTo(codedCompressStream); + codedCompressStream.flush(); + compressStream.flush(); + this.stripeStatisticsLength = (int) (metadataStart - stripeStatisticsStart); + this.metadataLength = (int) (rawWriter.getPos() - metadataStart); + } + + static void addUnencryptedStatistics(OrcProto.Footer.Builder builder, + OrcProto.ColumnStatistics[] stats) { + for(OrcProto.ColumnStatistics stat: stats) { + builder.addStatistics(stat); + } } @Override public void writeFileFooter(OrcProto.Footer.Builder builder) throws IOException { - long bodyLength = rawWriter.getPos() - metadataLength; + if (variants.size() > 0) { + OrcProto.Encryption.Builder encryption = builder.getEncryptionBuilder(); + setEncryptionStatistics(encryption, stripeNumber, variants.values()); + builder.setStripeStatisticsLength(stripeStatisticsLength); + } + addUnencryptedStatistics(builder, unencrypted.fileStats); + long bodyLength = rawWriter.getPos() - metadataLength - stripeStatisticsLength; builder.setContentLength(bodyLength); builder.setHeaderLength(headerLength); long startPosn = rawWriter.getPos(); OrcProto.Footer footer = builder.build(); - footer.writeTo(protobufWriter); - protobufWriter.flush(); - writer.flush(); + footer.writeTo(codedCompressStream); + codedCompressStream.flush(); + compressStream.flush(); this.footerLength = (int) (rawWriter.getPos() - startPosn); } @@ -300,7 +522,7 @@ public void appendRawStripe(ByteBuffer buffer, * data as buffers fill up and stores them in the output list. When the * stripe is being written, the whole stream is written to the file. */ - private static final class BufferedStream implements OutputReceiver { + static final class BufferedStream implements OutputReceiver { private boolean isSuppressed = false; private final List output = new ArrayList<>(); @@ -319,17 +541,56 @@ public void suppress() { /** * Write any saved buffers to the OutputStream if needed, and clears all the * buffers. + * @return true if the stream was written */ - void spillToDiskAndClear(FSDataOutputStream raw - ) throws IOException { + boolean spillToDiskAndClear(FSDataOutputStream raw) throws IOException { if (!isSuppressed) { for (ByteBuffer buffer: output) { raw.write(buffer.array(), buffer.arrayOffset() + buffer.position(), buffer.remaining()); } output.clear(); + return true; } isSuppressed = false; + return false; + } + + /** + * Get the buffer as a protobuf ByteString and clears the BufferedStream. + * @return the bytes + */ + ByteString getBytes() { + int len = output.size(); + if (len == 0) { + return ByteString.EMPTY; + } else { + ByteString result = ByteString.copyFrom(output.get(0)); + for (int i=1; i < output.size(); ++i) { + result = result.concat(ByteString.copyFrom(output.get(i))); + } + output.clear(); + return result; + } + } + + /** + * Get the stream as a ByteBuffer and clear it. + * @return a single ByteBuffer with the contents of the stream + */ + ByteBuffer getByteBuffer() { + ByteBuffer result; + if (output.size() == 1) { + result = output.get(0); + } else { + result = ByteBuffer.allocate((int) getOutputSize()); + for (ByteBuffer buffer : output) { + result.put(buffer); + } + output.clear(); + result.flip(); + } + return result; } /** @@ -347,38 +608,86 @@ public long getOutputSize() { } } + static class SizeCounters { + long index = 0; + long data = 0; + + long total() { + return index + data; + } + } + + void buildStreamList(OrcProto.StripeFooter.Builder footerBuilder, + SizeCounters sizes + ) throws IOException { + footerBuilder.addAllStreams( + unencrypted.placeStreams(StreamName.Area.INDEX, sizes)); + final long unencryptedIndexSize = sizes.index; + int v = 0; + for (VariantTracker variant: variants.values()) { + OrcProto.StripeEncryptionVariant.Builder builder = + footerBuilder.getEncryptionBuilder(v++); + builder.addAllStreams( + variant.placeStreams(StreamName.Area.INDEX, sizes)); + } + if (sizes.index != unencryptedIndexSize) { + // add a placeholder that covers the hole where the encrypted indexes are + footerBuilder.addStreams(OrcProto.Stream.newBuilder() + .setKind(OrcProto.Stream.Kind.ENCRYPTED_INDEX) + .setLength(sizes.index - unencryptedIndexSize)); + } + footerBuilder.addAllStreams( + unencrypted.placeStreams(StreamName.Area.DATA, sizes)); + final long unencryptedDataSize = sizes.data; + v = 0; + for (VariantTracker variant: variants.values()) { + OrcProto.StripeEncryptionVariant.Builder builder = + footerBuilder.getEncryptionBuilder(v++); + builder.addAllStreams( + variant.placeStreams(StreamName.Area.DATA, sizes)); + } + if (sizes.data != unencryptedDataSize) { + // add a placeholder that covers the hole where the encrypted indexes are + footerBuilder.addStreams(OrcProto.Stream.newBuilder() + .setKind(OrcProto.Stream.Kind.ENCRYPTED_DATA) + .setLength(sizes.data - unencryptedDataSize)); + } + } + @Override public void finalizeStripe(OrcProto.StripeFooter.Builder footerBuilder, OrcProto.StripeInformation.Builder dirEntry ) throws IOException { - long indexSize = 0; - long dataSize = 0; - for (Map.Entry pair: streams.entrySet()) { - BufferedStream receiver = pair.getValue(); - if (!receiver.isSuppressed) { - long streamSize = receiver.getOutputSize(); - StreamName name = pair.getKey(); - footerBuilder.addStreams(OrcProto.Stream.newBuilder().setColumn(name.getColumn()) - .setKind(name.getKind()).setLength(streamSize)); - if (StreamName.Area.INDEX == name.getArea()) { - indexSize += streamSize; - } else { - dataSize += streamSize; - } - } - } - dirEntry.setIndexLength(indexSize).setDataLength(dataSize); + SizeCounters sizes = new SizeCounters(); + buildStreamList(footerBuilder, sizes); OrcProto.StripeFooter footer = footerBuilder.build(); + // Do we need to pad the file so the stripe doesn't straddle a block boundary? - padStripe(indexSize + dataSize + footer.getSerializedSize()); + padStripe(sizes.total() + footer.getSerializedSize()); + + // write the unencrypted index streams + unencrypted.writeStreams(StreamName.Area.INDEX, rawWriter); + // write the encrypted index streams + for (VariantTracker variant: variants.values()) { + variant.writeStreams(StreamName.Area.INDEX, rawWriter); + } - // write out the data streams - for (Map.Entry pair : streams.entrySet()) { - pair.getValue().spillToDiskAndClear(rawWriter); + // write the unencrypted data streams + unencrypted.writeStreams(StreamName.Area.DATA, rawWriter); + // write out the unencrypted data streams + for (VariantTracker variant: variants.values()) { + variant.writeStreams(StreamName.Area.DATA, rawWriter); } + // Write out the footer. - writeStripeFooter(footer, dataSize, indexSize, dirEntry); + writeStripeFooter(footer, sizes, dirEntry); + + // fill in the data sizes + dirEntry.setDataLength(sizes.data); + dirEntry.setIndexLength(sizes.index); + + stripeNumber += 1; } @Override @@ -389,10 +698,11 @@ public void writeHeader() throws IOException { @Override public BufferedStream createDataStream(StreamName name) { - BufferedStream result = streams.get(name); + VariantTracker variant = getVariant(name.getEncryption()); + BufferedStream result = variant.streams.get(name); if (result == null) { result = new BufferedStream(); - streams.put(name, result); + variant.streams.put(name, result); } return result; } @@ -402,11 +712,26 @@ private StreamOptions getOptions(OrcProto.Stream.Kind kind) { kind); } + protected OutputStream createIndexStream(StreamName name) { + BufferedStream buffer = createDataStream(name); + VariantTracker tracker = getVariant(name.getEncryption()); + StreamOptions options = + SerializationUtils.getCustomizedCodec(tracker.options, + compressionStrategy, name.getKind()); + if (options.isEncrypted()) { + if (options == tracker.options) { + options = new StreamOptions(options); + } + options.modifyIv(CryptoUtils.modifyIvForStream(name, stripeNumber)); + } + return new OutStream(name.toString(), options, buffer); + } + @Override public void writeIndex(StreamName name, - OrcProto.RowIndex.Builder index) throws IOException { - OutputStream stream = new OutStream(path.toString(), - getOptions(name.getKind()), createDataStream(name)); + OrcProto.RowIndex.Builder index + ) throws IOException { + OutputStream stream = createIndexStream(name); index.build().writeTo(stream); stream.flush(); } @@ -415,12 +740,25 @@ public void writeIndex(StreamName name, public void writeBloomFilter(StreamName name, OrcProto.BloomFilterIndex.Builder bloom ) throws IOException { - OutputStream stream = new OutStream(path.toString(), - getOptions(name.getKind()), createDataStream(name)); + OutputStream stream = createIndexStream(name); bloom.build().writeTo(stream); stream.flush(); } + @Override + public void writeStatistics(StreamName name, + OrcProto.ColumnStatistics.Builder statistics + ) { + VariantTracker tracker = getVariant(name.getEncryption()); + if (name.getKind() == OrcProto.Stream.Kind.FILE_STATISTICS) { + tracker.fileStats[name.getColumn() - tracker.rootColumn] = + statistics.build(); + } else { + tracker.stripeStats[name.getColumn() - tracker.rootColumn] + .add(statistics.build()); + } + } + @Override public String toString() { return path.toString(); diff --git a/java/core/src/java/org/apache/orc/impl/WriterImpl.java b/java/core/src/java/org/apache/orc/impl/WriterImpl.java index 639f963398..4add5d3298 100644 --- a/java/core/src/java/org/apache/orc/impl/WriterImpl.java +++ b/java/core/src/java/org/apache/orc/impl/WriterImpl.java @@ -161,7 +161,7 @@ public Writer getWriter() { ? new PhysicalFsWriter(fs, path, opts) : opts.getPhysicalWriter(); physicalWriter.writeHeader(); - CompressionCodec codec = physicalWriter.getCompressionCodec(); + CompressionCodec codec = physicalWriter.getStreamOptions().getCodec(); if (codec != null) { compress.withCodec(codec, codec.getDefaultOptions()); } @@ -312,7 +312,7 @@ public boolean buildIndex() { * @return are the streams compressed */ public boolean isCompressed() { - return physicalWriter.getCompressionCodec() != null; + return physicalWriter.getStreamOptions().getCodec() != null; } /** @@ -379,6 +379,13 @@ public void writeBloomFilter(StreamName name, physicalWriter.writeBloomFilter(name, bloom); } + @Override + public void writeStatistics(StreamName name, + OrcProto.ColumnStatistics.Builder stats + ) throws IOException { + physicalWriter.writeStatistics(name, stats); + } + public boolean getUseUTCTimestamp() { return useUTCTimeZone; } @@ -453,9 +460,8 @@ private OrcProto.CompressionKind writeCompressionKind(CompressionKind kind) { } } - private void writeFileStatistics(OrcProto.Footer.Builder builder, - TreeWriter writer) throws IOException { - writer.writeFileStatistics(builder); + private void writeFileStatistics(TreeWriter writer) throws IOException { + writer.writeFileStatistics(); } private void writeMetadata() throws IOException { @@ -492,7 +498,7 @@ private long writeFooter() throws IOException { builder.addStripes(stripe); } // add the column statistics - writeFileStatistics(builder, treeWriter); + writeFileStatistics(treeWriter); // add all of the user metadata for(Map.Entry entry: userMetadata.entrySet()) { builder.addMetadata(OrcProto.UserMetadataItem.newBuilder() @@ -633,18 +639,17 @@ public void appendUserMetadata(List userMetadata) { } @Override - public ColumnStatistics[] getStatistics() - throws IOException { - // Generate the stats - OrcProto.Footer.Builder builder = OrcProto.Footer.newBuilder(); - - // add the column statistics - writeFileStatistics(builder, treeWriter); - return ReaderImpl.deserializeStats(schema, builder.getStatisticsList()); + public ColumnStatistics[] getStatistics() { + // get the column statistics + final ColumnStatistics[] result = + new ColumnStatistics[schema.getMaximumId() + 1]; + // Get the file statistics, preferring the encrypted one. + treeWriter.getCurrentStatistics(result); + return result; } public CompressionCodec getCompressionCodec() { - return physicalWriter.getCompressionCodec(); + return physicalWriter.getStreamOptions().getCodec(); } private static boolean hasTimestamp(TypeDescription schema) { diff --git a/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java index c6068cdec0..a433e83cf7 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/ListTreeWriter.java @@ -20,6 +20,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; import org.apache.orc.impl.IntegerWriter; @@ -153,9 +154,9 @@ public long getRawDataSize() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - super.writeFileStatistics(footer); - childWriter.writeFileStatistics(footer); + public void writeFileStatistics() throws IOException { + super.writeFileStatistics(); + childWriter.writeFileStatistics(); } @Override @@ -164,4 +165,10 @@ public void flushStreams() throws IOException { lengths.flush(); childWriter.flushStreams(); } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + super.getCurrentStatistics(output); + childWriter.getCurrentStatistics(output); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java index 91e56578ae..5c09b365af 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/MapTreeWriter.java @@ -19,6 +19,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; import org.apache.orc.impl.IntegerWriter; @@ -164,10 +165,10 @@ public long getRawDataSize() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - super.writeFileStatistics(footer); - keyWriter.writeFileStatistics(footer); - valueWriter.writeFileStatistics(footer); + public void writeFileStatistics() throws IOException { + super.writeFileStatistics(); + keyWriter.writeFileStatistics(); + valueWriter.writeFileStatistics(); } @Override @@ -177,4 +178,11 @@ public void flushStreams() throws IOException { keyWriter.flushStreams(); valueWriter.flushStreams(); } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + super.getCurrentStatistics(output); + keyWriter.getCurrentStatistics(output); + valueWriter.getCurrentStatistics(output); + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java index ee0b0c041a..d66f8468ec 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/StructTreeWriter.java @@ -21,6 +21,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.StructColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; @@ -147,10 +148,10 @@ public long getRawDataSize() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - super.writeFileStatistics(footer); + public void writeFileStatistics() throws IOException { + super.writeFileStatistics(); for (TreeWriter child : childrenWriters) { - child.writeFileStatistics(footer); + child.writeFileStatistics(); } } @@ -160,6 +161,13 @@ public void flushStreams() throws IOException { for (TreeWriter child : childrenWriters) { child.flushStreams(); } + } + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + super.getCurrentStatistics(output); + for (TreeWriter child: childrenWriters) { + child.getCurrentStatistics(output); + } } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java index bfa403eeff..662c57db48 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/TreeWriter.java @@ -20,6 +20,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcFile; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; @@ -98,9 +99,15 @@ void writeStripe(OrcProto.StripeFooter.Builder stripeFooter, /** * Add the file statistics to the file footer. - * @param footer the file footer builder */ - void writeFileStatistics(OrcProto.Footer.Builder footer); + void writeFileStatistics() throws IOException; + + /** + * Get the current file statistics for each column. If a column is encrypted, + * the encrypted variant statistics are used. + * @param output an array that is filled in with the results + */ + void getCurrentStatistics(ColumnStatistics[] output); class Factory { public static TreeWriter create(TypeDescription schema, diff --git a/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java b/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java index 7934b21755..4c36d09956 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java +++ b/java/core/src/java/org/apache/orc/impl/writer/TreeWriterBase.java @@ -23,6 +23,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcFile; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; @@ -256,12 +257,15 @@ public void writeStripe(OrcProto.StripeFooter.Builder builder, } /* Update byte count */ - final long byteCount = streamFactory.getPhysicalWriter().getFileBytes(id); + final long byteCount = streamFactory.getPhysicalWriter().getFileBytes(id, null); stripeColStatistics.updateByteCount(byteCount); // merge stripe-level column statistics to file statistics and write it to // stripe statistics fileStatistics.merge(stripeColStatistics); + streamFactory.writeStatistics( + new StreamName(id, OrcProto.Stream.Kind.STRIPE_STATISTICS, null), + stripeColStatistics.serialize()); stats.addColStats(stripeColStatistics.serialize()); stripeColStatistics.reset(); @@ -369,8 +373,10 @@ public long estimateMemory() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - footer.addStatistics(fileStatistics.serialize()); + public void writeFileStatistics() throws IOException { + streamFactory.writeStatistics(new StreamName(id, + OrcProto.Stream.Kind.FILE_STATISTICS, null), + fileStatistics.serialize()); } static class RowIndexPositionRecorder implements PositionRecorder { @@ -385,4 +391,9 @@ public void addPosition(long position) { builder.addPositions(position); } } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + output[id] = fileStatistics; + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java b/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java index 54a9a3a6dc..4456b62f55 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java +++ b/java/core/src/java/org/apache/orc/impl/writer/UnionTreeWriter.java @@ -20,6 +20,7 @@ import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.UnionColumnVector; +import org.apache.orc.ColumnStatistics; import org.apache.orc.OrcProto; import org.apache.orc.TypeDescription; import org.apache.orc.impl.PositionRecorder; @@ -165,10 +166,10 @@ public long getRawDataSize() { } @Override - public void writeFileStatistics(OrcProto.Footer.Builder footer) { - super.writeFileStatistics(footer); + public void writeFileStatistics() throws IOException { + super.writeFileStatistics(); for (TreeWriter child : childrenWriters) { - child.writeFileStatistics(footer); + child.writeFileStatistics(); } } @@ -180,4 +181,12 @@ public void flushStreams() throws IOException { child.flushStreams(); } } + + @Override + public void getCurrentStatistics(ColumnStatistics[] output) { + super.getCurrentStatistics(output); + for (TreeWriter child: childrenWriters) { + child.getCurrentStatistics(output); + } + } } diff --git a/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java b/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java index 9ef3ddaf2e..cfd2ec7ab0 100644 --- a/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java +++ b/java/core/src/java/org/apache/orc/impl/writer/WriterContext.java @@ -102,6 +102,15 @@ void writeBloomFilter(StreamName name, OrcProto.BloomFilterIndex.Builder bloom ) throws IOException; + /** + * Set the column statistics for the stripe or file. + * @param name the name of the statistics stream + * @param stats the statistics for this column in this stripe + */ + void writeStatistics(StreamName name, + OrcProto.ColumnStatistics.Builder stats + ) throws IOException; + boolean getUseUTCTimestamp(); double getDictionaryKeySizeThreshold(int column); diff --git a/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionKey.java b/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionKey.java new file mode 100644 index 0000000000..40606aa98c --- /dev/null +++ b/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionKey.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.orc.impl.writer; + +import org.apache.orc.EncryptionAlgorithm; +import org.apache.orc.EncryptionKey; +import org.apache.orc.EncryptionVariant; +import org.apache.orc.impl.HadoopShims; +import org.jetbrains.annotations.NotNull; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class WriterEncryptionKey implements EncryptionKey { + private final HadoopShims.KeyMetadata metadata; + private final List roots = new ArrayList<>(); + private int id; + + public WriterEncryptionKey(HadoopShims.KeyMetadata key) { + this.metadata = key; + } + + public void addRoot(WriterEncryptionVariant root) { + roots.add(root); + } + + public HadoopShims.KeyMetadata getMetadata() { + return metadata; + } + + public void setId(int id) { + this.id = id; + } + + @Override + public String getKeyName() { + return metadata.getKeyName(); + } + + @Override + public int getKeyVersion() { + return metadata.getVersion(); + } + + public EncryptionAlgorithm getAlgorithm() { + return metadata.getAlgorithm(); + } + + @Override + public WriterEncryptionVariant[] getEncryptionRoots() { + return roots.toArray(new WriterEncryptionVariant[roots.size()]); + } + + public int getId() { + return id; + } + + public void sortRoots() { + Collections.sort(roots); + } + + @Override + public int hashCode() { + return id; + } + + @Override + public boolean equals(Object other) { + if (other == null || getClass() != other.getClass()) { + return false; + } + return compareTo((EncryptionKey) other) == 0; + } + + @Override + public int compareTo(@NotNull EncryptionKey other) { + int result = getKeyName().compareTo(other.getKeyName()); + if (result == 0) { + result = Integer.compare(getKeyVersion(), other.getKeyVersion()); + } + return result; + } + + @Override + public String toString() { + return metadata.toString(); + } +} diff --git a/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionVariant.java b/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionVariant.java new file mode 100644 index 0000000000..ed026f7a90 --- /dev/null +++ b/java/core/src/java/org/apache/orc/impl/writer/WriterEncryptionVariant.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.orc.impl.writer; + +import org.apache.orc.EncryptionVariant; +import org.apache.orc.OrcProto; +import org.apache.orc.TypeDescription; +import org.apache.orc.impl.LocalKey; +import org.jetbrains.annotations.NotNull; + +import java.security.Key; +import java.util.ArrayList; +import java.util.List; + +public class WriterEncryptionVariant implements EncryptionVariant { + private int id; + private final WriterEncryptionKey key; + private final TypeDescription root; + private final LocalKey material; + private final OrcProto.FileStatistics.Builder fileStats = + OrcProto.FileStatistics.newBuilder(); + private final List encodings = new ArrayList<>(); + + public WriterEncryptionVariant(WriterEncryptionKey key, + TypeDescription root, + LocalKey columnKey) { + this.key = key; + this.root = root; + this.material = columnKey; + } + + @Override + public WriterEncryptionKey getKeyDescription() { + return key; + } + + public TypeDescription getRoot() { + return root; + } + + public void setId(int id) { + this.id = id; + } + + @Override + public int getVariantId() { + return id; + } + + @Override + public Key getFileFooterKey() { + return material.getDecryptedKey(); + } + + @Override + public Key getStripeKey(long stripe) { + return material.getDecryptedKey(); + } + + public LocalKey getMaterial() { + return material; + } + + public void clearFileStatistics() { + fileStats.clearColumn(); + } + + public void addFileStatistics(OrcProto.ColumnStatistics column) { + fileStats.addColumn(column); + } + + public OrcProto.FileStatistics getFileStatistics() { + return fileStats.build(); + } + + public void addEncoding(OrcProto.ColumnEncoding encoding) { + encodings.add(encoding); + } + + public List getEncodings() { + return encodings; + } + + public void clearEncodings() { + encodings.clear(); + } + + @Override + public int hashCode() { + return key.hashCode() << 16 ^ root.getId(); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } else if (other == null || other.getClass() != getClass()) { + return false; + } + return compareTo((WriterEncryptionVariant) other) == 0; + } + + @Override + public int compareTo(@NotNull EncryptionVariant other) { + int result = key.compareTo(other.getKeyDescription()); + if (result == 0) { + result = Integer.compare(root.getId(), other.getRoot().getId()); + } + return result; + } +} + diff --git a/java/core/src/test/org/apache/orc/TestStringDictionary.java b/java/core/src/test/org/apache/orc/TestStringDictionary.java index 27965fe375..2ab15064c0 100644 --- a/java/core/src/test/org/apache/orc/TestStringDictionary.java +++ b/java/core/src/test/org/apache/orc/TestStringDictionary.java @@ -239,6 +239,11 @@ public void writeBloomFilter(StreamName name, } + @Override + public void writeStatistics(StreamName name, OrcProto.ColumnStatistics.Builder stats) throws IOException { + + } + @Override public boolean getUseUTCTimestamp() { return true; diff --git a/java/core/src/test/org/apache/orc/TestVectorOrcFile.java b/java/core/src/test/org/apache/orc/TestVectorOrcFile.java index 658c1cea71..95f6458926 100644 --- a/java/core/src/test/org/apache/orc/TestVectorOrcFile.java +++ b/java/core/src/test/org/apache/orc/TestVectorOrcFile.java @@ -2099,7 +2099,7 @@ private CompressionCodec writeBatchesAndGetCodec(int count, ) throws IOException { fs.delete(testFilePath, false); PhysicalWriter physical = new PhysicalFsWriter(fs, testFilePath, opts); - CompressionCodec codec = physical.getCompressionCodec(); + CompressionCodec codec = physical.getStreamOptions().getCodec(); Writer writer = OrcFile.createWriter(testFilePath, opts.physicalWriter(physical)); writeRandomIntBytesBatches(writer, batch, count, size);