Skip to content

Commit

Permalink
ORC-523: Update ReaderImpl to work with column encryption.
Browse files Browse the repository at this point in the history
Fixes #408

Signed-off-by: Owen O'Malley <omalley@apache.org>
  • Loading branch information
omalley committed Jul 3, 2019
1 parent 7afd89a commit 33916ab
Show file tree
Hide file tree
Showing 9 changed files with 646 additions and 55 deletions.
14 changes: 11 additions & 3 deletions java/core/src/java/org/apache/orc/OrcUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -616,9 +616,17 @@ TypeDescription convertTypeFromProtobuf(List<OrcProto.Type> types,

public static List<StripeInformation> convertProtoStripesToStripes(
List<OrcProto.StripeInformation> stripes) {
List<StripeInformation> result = new ArrayList<StripeInformation>(stripes.size());
for (OrcProto.StripeInformation info : stripes) {
result.add(new ReaderImpl.StripeInformationImpl(info));
List<StripeInformation> result = new ArrayList<>(stripes.size());
long previousStripeId = -1;
byte[][] previousKeys = null;
long stripeId = 0;
for (OrcProto.StripeInformation stripeProto: stripes) {
ReaderImpl.StripeInformationImpl stripe =
new ReaderImpl.StripeInformationImpl(stripeProto, stripeId++,
previousStripeId, previousKeys);
result.add(stripe);
previousStripeId = stripe.getEncryptionStripeId();
previousKeys = stripe.getEncryptedLocalKeys();
}
return result;
}
Expand Down
21 changes: 21 additions & 0 deletions java/core/src/java/org/apache/orc/StripeInformation.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,25 @@ public interface StripeInformation {
* @return a count of the number of rows
*/
long getNumberOfRows();

/**
* Get the index of this stripe in the current file.
* @return 0 to number_of_stripes - 1
*/
long getStripeId();

/**
* Get the original stripe id that was used when the stripe was originally
* written. This is only different than getStripeId in merged files.
* @return the original stripe id
*/
long getEncryptionStripeId();

/**
* Get the encrypted keys starting from this stripe until overridden by
* a new set in a following stripe. The top level array is one for each
* encryption variant. Each element is an encrypted key.
* @return the array of encrypted keys
*/
byte[][] getEncryptedLocalKeys();
}
29 changes: 10 additions & 19 deletions java/core/src/java/org/apache/orc/impl/OrcTail.java
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/**
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
Expand Down Expand Up @@ -27,6 +27,7 @@
import org.apache.orc.CompressionKind;
import org.apache.orc.OrcFile;
import org.apache.orc.OrcProto;
import org.apache.orc.OrcUtils;
import org.apache.orc.StripeInformation;
import org.apache.orc.StripeStatistics;

Expand Down Expand Up @@ -77,11 +78,7 @@ public OrcFile.WriterVersion getWriterVersion() {
}

/**
 * Get the stripes listed in this tail's footer.
 * The conversion is delegated to OrcUtils.convertProtoStripesToStripes so
 * that there is a single implementation of it.
 * @return the stripe information for each stripe in the footer
 */
public List<StripeInformation> getStripes() {
  return OrcUtils.convertProtoStripesToStripes(getFooter().getStripesList());
}

public CompressionKind getCompressionKind() {
Expand All @@ -92,9 +89,9 @@ public int getCompressionBufferSize() {
return (int) fileTail.getPostscript().getCompressionBlockSize();
}

public List<StripeStatistics> getStripeStatistics() throws IOException {
public List<StripeStatistics> getStripeStatistics(InStream.StreamOptions options) throws IOException {
List<StripeStatistics> result = new ArrayList<>();
List<OrcProto.StripeStatistics> ssProto = getStripeStatisticsProto();
List<OrcProto.StripeStatistics> ssProto = getStripeStatisticsProto(options);
if (ssProto != null) {
for (OrcProto.StripeStatistics ss : ssProto) {
result.add(new StripeStatistics(ss.getColStatsList()));
Expand All @@ -103,17 +100,12 @@ public List<StripeStatistics> getStripeStatistics() throws IOException {
return result;
}

public List<OrcProto.StripeStatistics> getStripeStatisticsProto() throws IOException {
public List<OrcProto.StripeStatistics> getStripeStatisticsProto(InStream.StreamOptions options) throws IOException {
if (serializedTail == null) return null;
if (metadata == null) {
CompressionCodec codec = OrcCodecPool.getCodec(getCompressionKind());
try {
metadata = extractMetadata(serializedTail, 0,
(int) fileTail.getPostscript().getMetadataLength(),
InStream.options().withCodec(codec).withBufferSize(getCompressionBufferSize()));
} finally {
OrcCodecPool.returnCodec(getCompressionKind(), codec);
}
metadata = extractMetadata(serializedTail, 0,
(int) fileTail.getPostscript().getMetadataLength(),
options);
// clear does not clear the contents but sets position to 0 and limit = capacity
serializedTail.clear();
}
Expand All @@ -137,7 +129,6 @@ public OrcProto.FileTail getMinimalFileTail() {
OrcProto.Footer.Builder footerBuilder = OrcProto.Footer.newBuilder(fileTail.getFooter());
footerBuilder.clearStatistics();
fileTailBuilder.setFooter(footerBuilder.build());
OrcProto.FileTail result = fileTailBuilder.build();
return result;
return fileTailBuilder.build();
}
}
152 changes: 128 additions & 24 deletions java/core/src/java/org/apache/orc/impl/ReaderImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.Key;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.hadoop.fs.FileStatus;
import org.apache.orc.EncryptionAlgorithm;
import org.apache.orc.EncryptionKey;
import org.apache.orc.CompressionKind;
import org.apache.orc.DataMaskDescription;
import org.apache.orc.EncryptionKey;
Expand All @@ -43,8 +45,8 @@
import org.apache.orc.StripeInformation;
import org.apache.orc.StripeStatistics;
import org.apache.orc.UnknownFormatException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.orc.impl.reader.ReaderEncryption;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
Expand All @@ -54,6 +56,9 @@
import org.apache.orc.OrcProto;

import com.google.protobuf.CodedInputStream;
import org.apache.orc.impl.reader.ReaderEncryptionVariant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ReaderImpl implements Reader {

Expand All @@ -78,6 +83,7 @@ public class ReaderImpl implements Reader {
private final List<StripeInformation> stripes;
protected final int rowIndexStride;
private final long contentLength, numberOfRows;
private final ReaderEncryption encryption;

private long deserializedSize = -1;
protected final Configuration conf;
Expand All @@ -89,10 +95,30 @@ public class ReaderImpl implements Reader {

public static class StripeInformationImpl
implements StripeInformation {
private final long stripeId;
private final long originalStripeId;
private final byte[][] encryptedKeys;
private final OrcProto.StripeInformation stripe;

/**
 * Build the reader's view of a stripe from its protobuf description.
 * @param stripe the serialized stripe description
 * @param stripeId the value to report from getStripeId
 * @param previousOriginalStripeId the encryption stripe id of the preceding
 *        stripe; this stripe uses that value plus one when the protobuf
 *        doesn't record an encryption stripe id
 * @param previousKeys the encrypted local keys of the preceding stripe;
 *        reused as-is when the protobuf doesn't carry any keys
 */
public StripeInformationImpl(OrcProto.StripeInformation stripe,
                             long stripeId,
                             long previousOriginalStripeId,
                             byte[][] previousKeys) {
  this.stripe = stripe;
  this.stripeId = stripeId;
  if (stripe.hasEncryptStripeId()) {
    originalStripeId = stripe.getEncryptStripeId();
  } else {
    originalStripeId = previousOriginalStripeId + 1;
  }
  if (stripe.getEncryptedLocalKeysCount() != 0) {
    // This stripe has its own keys; copy each one out of the protobuf.
    encryptedKeys = new byte[stripe.getEncryptedLocalKeysCount()][];
    for (int v = 0; v < encryptedKeys.length; ++v) {
      encryptedKeys[v] = stripe.getEncryptedLocalKeys(v).toByteArray();
    }
  } else {
    encryptedKeys = previousKeys;
  }
}

@Override
Expand Down Expand Up @@ -125,6 +151,21 @@ public long getNumberOfRows() {
return stripe.getNumberOfRows();
}

// Reports the stripeId value that was passed to the constructor.
@Override
public long getStripeId() {
return stripeId;
}

// Reports the id resolved in the constructor: the protobuf's encryption
// stripe id when present, otherwise the previous stripe's id plus one.
@Override
public long getEncryptionStripeId() {
return originalStripeId;
}

// Returns the stored array itself, not a copy. It is either built from this
// stripe's protobuf or is the previousKeys array given to the constructor,
// so it can be null if the constructor was given null.
@Override
public byte[][] getEncryptedLocalKeys() {
return encryptedKeys;
}

@Override
public String toString() {
return "offset: " + getOffset() + " data: " + getDataLength() +
Expand Down Expand Up @@ -221,20 +262,25 @@ public OrcProto.FileTail getFileTail() {

@Override
public EncryptionKey[] getColumnEncryptionKeys() {
  // The keys are owned by this reader's ReaderEncryption.
  return encryption.getKeys();
}

@Override
public DataMaskDescription[] getDataMasks() {
  // The masks are owned by this reader's ReaderEncryption.
  return encryption.getMasks();
}

@Override
public EncryptionVariant[] getEncryptionVariants() {
// TODO
return new EncryptionVariant[0];
public ReaderEncryptionVariant[] getEncryptionVariants() {
return encryption.getVariants();
}

/**
 * Internal access to our view of the encryption.
 * This hands out the reader's own ReaderEncryption field rather than a copy.
 * @return the encryption information for this reader.
 */
public ReaderEncryption getEncryption() {
return encryption;
}

@Override
Expand All @@ -244,7 +290,61 @@ public int getRowIndexStride() {

/**
 * Get the file-level column statistics.
 * <p>
 * The statistics stored in the footer are deserialized first. Then, for each
 * encryption variant whose file statistics can be decrypted, the decrypted
 * statistics replace the entries starting at the variant root's column id.
 * @return the statistics for each column in the schema
 * @throws RuntimeException if a variant's file statistics can't be read;
 *         the underlying IOException is attached as the cause
 */
@Override
public ColumnStatistics[] getStatistics() {
  ColumnStatistics[] result = deserializeStats(schema, fileStats);
  try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) {
    InStream.StreamOptions compression = InStream.options();
    if (codec != null) {
      compression.withCodec(codec).withBufferSize(bufferSize);
    }
    for (ReaderEncryptionVariant variant : encryption.getVariants()) {
      ColumnStatistics[] overrides;
      try {
        overrides = decryptFileStats(variant, compression,
            tail.getFooter());
      } catch (IOException e) {
        // Keep the original exception as the cause so that the real
        // failure isn't lost from the stack trace.
        throw new RuntimeException("Can't decrypt file stats for " + path +
            " with " + variant.getKeyDescription(), e);
      }
      // A null result means there are no decrypted statistics for this
      // variant, so the footer's statistics are left in place.
      if (overrides != null) {
        for (int i = 0; i < overrides.length; ++i) {
          result[variant.getRoot().getId() + i] = overrides[i];
        }
      }
    }
  }
  return result;
}

/**
 * Decrypt and deserialize the file statistics stored for one encryption
 * variant.
 * <p>
 * The IV is derived from the variant root's column id, the FILE_STATISTICS
 * stream kind, and the number of stripes in the footer.
 * @param encryption the variant whose statistics should be read
 * @param compression the stream options to extend with the decryption settings
 * @param footer the file footer that holds the encrypted statistics
 * @return the statistics, where element i belongs to column id
 *         root id + i, or null if the variant has no file footer key
 * @throws IOException if the statistics can't be read or parsed
 */
public static ColumnStatistics[] decryptFileStats(ReaderEncryptionVariant encryption,
                                                  InStream.StreamOptions compression,
                                                  OrcProto.Footer footer
                                                  ) throws IOException {
  Key footerKey = encryption.getFileFooterKey();
  if (footerKey == null) {
    return null;
  }

  // Pull the encrypted bytes for this variant out of the footer.
  byte[] encrypted = footer.getEncryption()
      .getVariants(encryption.getVariantId())
      .getFileStatistics()
      .toByteArray();
  BufferChunk chunk = new BufferChunk(ByteBuffer.wrap(encrypted), 0);

  // Build the IV for the file statistics stream of this variant.
  EncryptionAlgorithm cipher = encryption.getKeyDescription().getAlgorithm();
  byte[] iv = new byte[cipher.getIvLength()];
  CryptoUtils.modifyIvForStream(encryption.getRoot().getId(),
      OrcProto.Stream.Kind.FILE_STATISTICS, footer.getStripesCount())
      .accept(iv);

  InStream.StreamOptions decryptOptions =
      new InStream.StreamOptions(compression)
          .withEncryption(cipher, footerKey, iv);
  InStream stream = InStream.create("encrypted file stats", chunk,
      encrypted.length, 0, decryptOptions);
  OrcProto.FileStatistics stats = OrcProto.FileStatistics.parseFrom(stream);

  ColumnStatistics[] result = new ColumnStatistics[stats.getColumnCount()];
  TypeDescription root = encryption.getRoot();
  int column = 0;
  while (column < result.length) {
    TypeDescription type = root.findSubtype(root.getId() + column);
    result[column] = ColumnStatisticsImpl.deserialize(type,
        stats.getColumn(column));
    column += 1;
  }
  return result;
}

public static ColumnStatistics[] deserializeStats(
Expand Down Expand Up @@ -351,12 +451,17 @@ public ReaderImpl(Path path, OrcFile.ReaderOptions options) throws IOException {
this.writerVersion =
OrcFile.WriterVersion.from(writer, fileMetadata.getWriterVersionNum());
this.types = fileMetadata.getTypes();
OrcUtils.isValidTypeTree(this.types, 0);
this.schema = OrcUtils.convertTypeFromProtobuf(this.types, 0);
this.rowIndexStride = fileMetadata.getRowIndexStride();
this.contentLength = fileMetadata.getContentLength();
this.numberOfRows = fileMetadata.getNumberOfRows();
this.fileStats = fileMetadata.getFileStats();
this.stripes = fileMetadata.getStripes();
this.userMetadata = null; // not cached and not needed here
// FileMetadata is obsolete and doesn't support encryption
this.encryption = new ReaderEncryption(null, schema, stripes,
options.getKeyProvider(), conf);
} else {
OrcTail orcTail = options.getOrcTail();
if (orcTail == null) {
Expand All @@ -371,17 +476,23 @@ public ReaderImpl(Path path, OrcFile.ReaderOptions options) throws IOException {
this.metadataSize = tail.getMetadataSize();
this.versionList = tail.getPostScript().getVersionList();
this.types = tail.getFooter().getTypesList();
OrcUtils.isValidTypeTree(this.types, 0);
this.schema = OrcUtils.convertTypeFromProtobuf(this.types, 0);
this.rowIndexStride = tail.getFooter().getRowIndexStride();
this.contentLength = tail.getFooter().getContentLength();
this.numberOfRows = tail.getFooter().getNumberOfRows();
this.userMetadata = tail.getFooter().getMetadataList();
this.fileStats = tail.getFooter().getStatisticsList();
this.writerVersion = tail.getWriterVersion();
this.stripes = tail.getStripes();
this.stripeStats = tail.getStripeStatisticsProto();
try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) {
InStream.StreamOptions compress = InStream.options().withCodec(codec)
.withBufferSize(bufferSize);
this.stripeStats = tail.getStripeStatisticsProto(compress);
}
this.encryption = new ReaderEncryption(tail.getFooter(), schema,
stripes, options.getKeyProvider(), conf);
}
OrcUtils.isValidTypeTree(this.types, 0);
this.schema = OrcUtils.convertTypeFromProtobuf(this.types, 0);
}

protected FileSystem getFileSystem() throws IOException {
Expand Down Expand Up @@ -559,12 +670,9 @@ protected OrcTail extractFileTail(FileSystem fs, Path path,
ByteBuffer footerBuffer = buffer.slice();
buffer.reset();
OrcProto.Footer footer;
CompressionCodec codec = OrcCodecPool.getCodec(compressionKind);
try {
try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)){
footer = extractFooter(footerBuffer, 0, footerSize,
InStream.options().withCodec(codec).withBufferSize(bufferSize));
} finally {
OrcCodecPool.returnCodec(compressionKind, codec);
}
fileTailBuilder.setFooter(footer);
} catch (Throwable thr) {
Expand Down Expand Up @@ -604,7 +712,6 @@ public RecordReader rows(Options options) throws IOException {
return new RecordReaderImpl(this, options);
}


@Override
public long getRawDataSize() {
// if the deserializedSize is not computed, then compute it, else
Expand Down Expand Up @@ -757,12 +864,9 @@ public List<OrcProto.ColumnStatistics> getOrcProtoFileStatistics() {
@Override
public List<StripeStatistics> getStripeStatistics() throws IOException {
if (metadata == null) {
CompressionCodec codec = OrcCodecPool.getCodec(compressionKind);
try {
try (CompressionCodec codec = OrcCodecPool.getCodec(compressionKind)) {
metadata = extractMetadata(tail.getSerializedTail(), 0, metadataSize,
InStream.options().withCodec(codec).withBufferSize(bufferSize));
} finally {
OrcCodecPool.returnCodec(compressionKind, codec);
}
}
if (stripeStats == null) {
Expand Down
Loading

0 comments on commit 33916ab

Please sign in to comment.