Skip to content

Commit

Permalink
chore: ensure ParallelCompositeUploadBlobWriteSessionConfig is serial…
Browse files Browse the repository at this point in the history
…izable (#2240)
  • Loading branch information
BenWhitehead authored Oct 3, 2023
1 parent f8f4e22 commit a599c63
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static java.util.Objects.requireNonNull;

import com.google.api.core.ApiFuture;
import com.google.api.core.ApiFutures;
Expand All @@ -37,6 +38,8 @@
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.storage.v2.WriteObjectResponse;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Clock;
Expand Down Expand Up @@ -257,7 +260,8 @@ WriterFactory createFactory(Clock clock) throws IOException {
*/
@BetaApi
@Immutable
public abstract static class BufferStrategy extends Factory<BufferHandlePool> {
public abstract static class BufferStrategy extends Factory<BufferHandlePool>
implements Serializable {

private BufferStrategy() {}

Expand Down Expand Up @@ -289,6 +293,7 @@ public static BufferStrategy fixedPool(int bufferCount, int bufferCapacity) {
}

private static class SimpleBufferStrategy extends BufferStrategy {
private static final long serialVersionUID = 8884826090481043434L;

private final int capacity;

Expand All @@ -303,6 +308,7 @@ BufferHandlePool get() {
}

private static class FixedBufferStrategy extends BufferStrategy {
private static final long serialVersionUID = 3288902741819257066L;

private final int bufferCount;
private final int bufferCapacity;
Expand All @@ -328,7 +334,7 @@ BufferHandlePool get() {
*/
@BetaApi
@Immutable
public abstract static class ExecutorSupplier extends Factory<Executor> {
public abstract static class ExecutorSupplier extends Factory<Executor> implements Serializable {
private static final AtomicInteger INSTANCE_COUNTER = new AtomicInteger(1);

private ExecutorSupplier() {}
Expand All @@ -341,13 +347,7 @@ private ExecutorSupplier() {}
*/
@BetaApi
public static ExecutorSupplier cachedPool() {
return new ExecutorSupplier() {
@Override
Executor get() {
ThreadFactory threadFactory = newThreadFactory();
return Executors.newCachedThreadPool(threadFactory);
}
};
return new CachedSupplier();
}

/**
Expand All @@ -359,13 +359,7 @@ Executor get() {
*/
@BetaApi
public static ExecutorSupplier fixedPool(int poolSize) {
return new ExecutorSupplier() {
@Override
Executor get() {
ThreadFactory threadFactory = newThreadFactory();
return Executors.newFixedThreadPool(poolSize, threadFactory);
}
};
return new FixedSupplier(poolSize);
}

/**
Expand All @@ -380,6 +374,7 @@ Executor get() {
*/
@BetaApi
public static ExecutorSupplier useExecutor(Executor executor) {
requireNonNull(executor, "executor must be non null");
return new SuppliedExecutorSupplier(executor);
}

Expand All @@ -403,6 +398,36 @@ public SuppliedExecutorSupplier(Executor executor) {
Executor get() {
return executor;
}

private void writeObject(ObjectOutputStream out) throws IOException {
throw new java.io.InvalidClassException(this.getClass().getName() + "; Not serializable");
}
}

private static class CachedSupplier extends ExecutorSupplier implements Serializable {
private static final long serialVersionUID = 7768210719775319260L;

@Override
Executor get() {
ThreadFactory threadFactory = newThreadFactory();
return Executors.newCachedThreadPool(threadFactory);
}
}

private static class FixedSupplier extends ExecutorSupplier implements Serializable {
private static final long serialVersionUID = 7771825977551614347L;

private final int poolSize;

public FixedSupplier(int poolSize) {
this.poolSize = poolSize;
}

@Override
Executor get() {
ThreadFactory threadFactory = newThreadFactory();
return Executors.newFixedThreadPool(poolSize, threadFactory);
}
}
}

Expand All @@ -415,7 +440,8 @@ Executor get() {
*/
@BetaApi
@Immutable
public abstract static class PartNamingStrategy {
public abstract static class PartNamingStrategy implements Serializable {
private static final long serialVersionUID = 8343436026774231869L;
private static final String FIELD_SEPARATOR = ";";
private static final Encoder B64 = Base64.getUrlEncoder().withoutPadding();
private static final HashFunction OBJECT_NAME_HASH_FUNCTION = Hashing.goodFastHash(128);
Expand Down Expand Up @@ -496,6 +522,7 @@ public static PartNamingStrategy prefix(String prefixPattern) {
}

static final class WithPrefix extends PartNamingStrategy {
private static final long serialVersionUID = 5709330763161570411L;

private final String prefix;

Expand All @@ -518,6 +545,8 @@ protected String fmtFields(String randomKey, String nameDigest, String partRange
}

static final class NoPrefix extends PartNamingStrategy {
private static final long serialVersionUID = 5202415556658566017L;

public NoPrefix(SecureRandom rand) {
super(rand);
}
Expand Down Expand Up @@ -548,34 +577,35 @@ protected String fmtFields(String randomKey, String nameDigest, String partRange
*/
@BetaApi
@Immutable
public static class PartCleanupStrategy {
private final boolean deleteParts;
private final boolean deleteOnError;

private PartCleanupStrategy(boolean deleteParts, boolean deleteOnError) {
this.deleteParts = deleteParts;
this.deleteOnError = deleteOnError;
public static class PartCleanupStrategy implements Serializable {
private static final long serialVersionUID = -1434253614347199051L;
private final boolean deletePartsOnSuccess;
private final boolean deleteAllOnError;

private PartCleanupStrategy(boolean deletePartsOnSuccess, boolean deleteAllOnError) {
this.deletePartsOnSuccess = deletePartsOnSuccess;
this.deleteAllOnError = deleteAllOnError;
}

boolean isDeleteParts() {
return deleteParts;
public boolean isDeletePartsOnSuccess() {
return deletePartsOnSuccess;
}

boolean isDeleteOnError() {
return deleteOnError;
public boolean isDeleteAllOnError() {
return deleteAllOnError;
}

/**
* If an unrecoverable error is encountered, define whether to attempt to delete any object
* parts already uploaded.
* If an unrecoverable error is encountered, define whether to attempt to delete any objects
* already uploaded.
*
* <p><i>Default:</i> {@code true}
*
* @since 2.28.0 This new api is in preview and is subject to breaking changes.
*/
@BetaApi
PartCleanupStrategy withDeleteOnError(boolean deleteOnError) {
return new PartCleanupStrategy(deleteParts, deleteOnError);
PartCleanupStrategy withDeleteAllOnError(boolean deleteAllOnError) {
return new PartCleanupStrategy(deletePartsOnSuccess, deleteAllOnError);
}

/**
Expand Down Expand Up @@ -615,7 +645,11 @@ public static PartCleanupStrategy never() {
}
}

private abstract static class Factory<T> {
private abstract static class Factory<T> implements Serializable {
private static final long serialVersionUID = 271806144227661056L;

private Factory() {}

abstract T get();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ public synchronized void close() throws IOException {
},
exec);

if (partCleanupStrategy.isDeleteOnError()) {
if (partCleanupStrategy.isDeleteAllOnError()) {
ApiFuture<BlobInfo> cleaningFuture =
ApiFutures.catchingAsync(
validatingTransform, Throwable.class, this::asyncCleanupAfterFailure, exec);
Expand Down Expand Up @@ -316,7 +316,7 @@ private void internalFlush(ByteBuffer buf) {

Throwable cause = e.getCause();
BaseServiceException storageException;
if (partCleanupStrategy.isDeleteOnError()) {
if (partCleanupStrategy.isDeleteAllOnError()) {
storageException = StorageException.coalesce(cause);
ApiFuture<Object> cleanupFutures = asyncCleanupAfterFailure(storageException);
// asynchronously fail the finalObject future
Expand Down Expand Up @@ -394,7 +394,7 @@ private BlobInfo compose(ImmutableList<BlobInfo> parts) {
}

private ApiFuture<BlobInfo> cleanupParts(BlobInfo finalInfo) {
if (!partCleanupStrategy.isDeleteParts()) {
if (!partCleanupStrategy.isDeletePartsOnSuccess()) {
return ApiFutures.immediateFuture(finalInfo);
}
List<ApiFuture<Boolean>> deletes =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;

import com.google.api.services.storage.model.StorageObject;
import com.google.cloud.BaseSerializationTest;
Expand All @@ -32,6 +33,10 @@
import com.google.cloud.storage.BlobReadChannelV2.BlobReadChannelContext;
import com.google.cloud.storage.BlobReadChannelV2.BlobReadChannelV2State;
import com.google.cloud.storage.BlobWriteChannelV2.BlobWriteChannelV2State;
import com.google.cloud.storage.ParallelCompositeUploadBlobWriteSessionConfig.BufferStrategy;
import com.google.cloud.storage.ParallelCompositeUploadBlobWriteSessionConfig.ExecutorSupplier;
import com.google.cloud.storage.ParallelCompositeUploadBlobWriteSessionConfig.PartCleanupStrategy;
import com.google.cloud.storage.ParallelCompositeUploadBlobWriteSessionConfig.PartNamingStrategy;
import com.google.cloud.storage.Storage.BlobTargetOption;
import com.google.cloud.storage.Storage.BucketField;
import com.google.cloud.storage.Storage.ComposeRequest;
Expand All @@ -46,13 +51,15 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InvalidClassException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Base64;
import java.util.Collections;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.Executor;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
Expand Down Expand Up @@ -368,4 +375,38 @@ public void testSerializableObjects() throws Exception {
assertEquals(copy, copy);
}
}

@Test
public void blobWriteSessionConfig_pcu() throws IOException, ClassNotFoundException {
ParallelCompositeUploadBlobWriteSessionConfig pcu1 =
BlobWriteSessionConfigs.parallelCompositeUpload();
ParallelCompositeUploadBlobWriteSessionConfig pcu1copy = serializeAndDeserialize(pcu1);
assertThat(pcu1copy).isNotNull();

ParallelCompositeUploadBlobWriteSessionConfig pcu2 =
BlobWriteSessionConfigs.parallelCompositeUpload()
.withBufferStrategy(BufferStrategy.fixedPool(1, 3))
.withPartCleanupStrategy(PartCleanupStrategy.never())
.withPartNamingStrategy(PartNamingStrategy.prefix("prefix"))
.withExecutorSupplier(ExecutorSupplier.fixedPool(5));
ParallelCompositeUploadBlobWriteSessionConfig pcu2copy = serializeAndDeserialize(pcu2);
assertThat(pcu2copy).isNotNull();

InvalidClassException invalidClassException =
assertThrows(
InvalidClassException.class,
() -> {
Executor executor = command -> {};
ParallelCompositeUploadBlobWriteSessionConfig pcu3 =
BlobWriteSessionConfigs.parallelCompositeUpload()
.withExecutorSupplier(ExecutorSupplier.useExecutor(executor));
// executor is not serializable, this should throw an exception
serializeAndDeserialize(pcu3);
});

assertThat(invalidClassException)
.hasMessageThat()
.isEqualTo(
"com.google.cloud.storage.ParallelCompositeUploadBlobWriteSessionConfig$ExecutorSupplier$SuppliedExecutorSupplier; Not serializable");
}
}

0 comments on commit a599c63

Please sign in to comment.