Ensure Sink.contextView is propagated
Port Flux/Mono improvements to the logic from master.

The context view is propagated via the subscriber, so any nested
subscribe calls need the context explicitly passed through.

JAVA-5345
rozza committed Aug 7, 2024
1 parent 4a18081 commit c17bb47
Showing 4 changed files with 70 additions and 101 deletions.
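
The pattern behind all four files, as a minimal standalone sketch (illustrative only, not driver code; the "tenant" key and class name are invented): when a callback passed to Mono.create starts its own nested subscription, the downstream Reactor context is reachable only through the sink, so it must be written back onto the inner publisher explicitly.

import reactor.core.publisher.Mono;

public final class ContextPropagationSketch {
    public static void main(String[] args) {
        String result = Mono.<String>create(sink ->
                        // The inner publisher is subscribed manually, so it starts with
                        // an empty context unless the sink's view is written back onto it.
                        Mono.deferContextual(ctx -> Mono.just("tenant=" + ctx.get("tenant")))
                                .contextWrite(sink.contextView()) // the fix this commit applies throughout
                                .subscribe(sink::success, sink::error))
                .contextWrite(ctx -> ctx.put("tenant", "acme"))
                .block();
        System.out.println(result); // prints "tenant=acme"
    }
}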
@@ -18,11 +18,9 @@

 import org.reactivestreams.Publisher;
 import org.reactivestreams.Subscriber;
-import reactor.core.CoreSubscriber;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.FluxSink;
 import reactor.core.publisher.Mono;
-import reactor.util.context.Context;

 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -48,19 +46,19 @@ public void subscribe(final Subscriber<? super T> subscriber) {
         if (calculateDemand(demand) > 0 && inProgress.compareAndSet(false, true)) {
             if (batchCursor == null) {
                 int batchSize = calculateBatchSize(sink.requestedFromDownstream());
-                Context initialContext = subscriber instanceof CoreSubscriber<?>
-                        ? ((CoreSubscriber<?>) subscriber).currentContext() : null;
-                batchCursorPublisher.batchCursor(batchSize).subscribe(bc -> {
-                    batchCursor = bc;
-                    inProgress.set(false);
-
-                    // Handle any cancelled subscriptions that happen during the time it takes to get the batchCursor
-                    if (sink.isCancelled()) {
-                        closeCursor();
-                    } else {
-                        recurseCursor();
-                    }
-                }, sink::error, null, initialContext);
+                batchCursorPublisher.batchCursor(batchSize)
+                        .contextWrite(sink.contextView())
+                        .subscribe(bc -> {
+                            batchCursor = bc;
+                            inProgress.set(false);
+
+                            // Handle any cancelled subscriptions that happen during the time it takes to get the batchCursor
+                            if (sink.isCancelled()) {
+                                closeCursor();
+                            } else {
+                                recurseCursor();
+                            }
+                        }, sink::error);
             } else {
                 inProgress.set(false);
                 recurseCursor();
@@ -86,6 +84,7 @@ private void recurseCursor(){
         } else {
             batchCursor.setBatchSize(calculateBatchSize(sink.requestedFromDownstream()));
             Mono.from(batchCursor.next(() -> sink.isCancelled()))
+                    .contextWrite(sink.contextView())
                     .doOnCancel(this::closeCursor)
                     .subscribe(results -> {
                         if (!results.isEmpty()) {
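
For reference, the removed lines above recovered a context by casting the raw Subscriber to CoreSubscriber and fed it to the four-argument subscribe(consumer, errorConsumer, completeConsumer, initialContext) overload, which Reactor has deprecated in favour of contextWrite. A hedged side-by-side sketch (the "k" key is invented):

import reactor.core.publisher.Mono;
import reactor.util.context.Context;

public final class SubscribeOverloadSketch {
    public static void main(String[] args) {
        Mono<String> inner = Mono.deferContextual(ctx -> Mono.just(ctx.getOrDefault("k", "missing")));

        // Old style: hand an initial Context to a deprecated subscribe overload.
        inner.subscribe(System.out::println, Throwable::printStackTrace, null, Context.of("k", "old"));

        // New style: write the context onto the publisher, then subscribe normally.
        inner.contextWrite(Context.of("k", "new")).subscribe(System.out::println);
    }
}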
@@ -103,21 +103,17 @@ public Publisher<T> batchSize(final int batchSize) {

     public Publisher<T> first() {
         return batchCursor(this::asAsyncFirstReadOperation)
-                .flatMap(batchCursor -> Mono.create(sink -> {
-                    batchCursor.setBatchSize(1);
-                    Mono.from(batchCursor.next())
-                            .doOnTerminate(batchCursor::close)
-                            .doOnError(sink::error)
-                            .doOnSuccess(results -> {
-                                if (results == null || results.isEmpty()) {
-                                    sink.success();
-                                } else {
-                                    sink.success(results.get(0));
-                                }
-                            })
-                            .contextWrite(sink.contextView())
-                            .subscribe();
-                }));
+                .flatMap(batchCursor -> {
+                    batchCursor.setBatchSize(1);
+                    return Mono.from(batchCursor.next())
+                            .doOnTerminate(batchCursor::close)
+                            .flatMap(results -> {
+                                if (results == null || results.isEmpty()) {
+                                    return Mono.empty();
+                                }
+                                return Mono.fromCallable(() -> results.get(0));
+                            });
+                });
     }

     @Override
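
The first() rewrite shows the other half of the approach: when the create/sink bridge can be dropped entirely in favour of flatMap composition, the subscriber's context flows through the chain with no manual contextWrite at all. A hedged sketch of the same shape (fetchPage is a hypothetical stand-in for batchCursor.next()):

import java.util.List;
import reactor.core.publisher.Mono;

public final class ComposeInsteadOfCreate {
    // Hypothetical stand-in for batchCursor.next(): emits one page of results.
    static Mono<List<String>> fetchPage() {
        return Mono.deferContextual(ctx -> Mono.just(List.of("first-" + ctx.get("tenant"))));
    }

    public static void main(String[] args) {
        Mono<String> first = fetchPage()
                .flatMap(results -> results.isEmpty()
                        ? Mono.empty()
                        : Mono.fromCallable(() -> results.get(0)));

        // The context written by the final subscriber reaches fetchPage() automatically.
        System.out.println(first.contextWrite(ctx -> ctx.put("tenant", "acme")).block()); // first-acme
    }
}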
@@ -305,6 +305,7 @@ private void collInfo(final MongoCryptContext cryptContext,
             sink.error(new IllegalStateException("Missing database name"));
         } else {
             collectionInfoRetriever.filter(databaseName, cryptContext.getMongoOperation())
+                    .contextWrite(sink.contextView())
                     .doOnSuccess(result -> {
                         if (result != null) {
                             cryptContext.addMongoOperationResult(result);
@@ -326,6 +327,7 @@ private void mark(final MongoCryptContext cryptContext,
             sink.error(wrapInClientException(new IllegalStateException("Missing database name")));
         } else {
             commandMarker.mark(databaseName, cryptContext.getMongoOperation())
+                    .contextWrite(sink.contextView())
                     .doOnSuccess(result -> {
                         cryptContext.addMongoOperationResult(result);
                         cryptContext.completeMongoOperation();
@@ -340,6 +342,7 @@ private void fetchKeys(final MongoCryptContext cryptContext,
                            @Nullable final String databaseName,
                            final MonoSink<RawBsonDocument> sink) {
         keyRetriever.find(cryptContext.getMongoOperation())
+                .contextWrite(sink.contextView())
                 .doOnSuccess(results -> {
                     for (BsonDocument result : results) {
                         cryptContext.addMongoOperationResult(result);
@@ -357,11 +360,13 @@ private void decryptKeys(final MongoCryptContext cryptContext,
         MongoKeyDecryptor keyDecryptor = cryptContext.nextKeyDecryptor();
         if (keyDecryptor != null) {
             keyManagementService.decryptKey(keyDecryptor)
+                    .contextWrite(sink.contextView())
                     .doOnSuccess(r -> decryptKeys(cryptContext, databaseName, sink))
                     .doOnError(e -> sink.error(wrapInClientException(e)))
                     .subscribe();
         } else {
             Mono.fromRunnable(cryptContext::completeKeyDecryptors)
+                    .contextWrite(sink.contextView())
                     .doOnSuccess(r -> executeStateMachineWithSink(cryptContext, databaseName, sink))
                     .doOnError(e -> sink.error(wrapInClientException(e)))
                     .subscribe();
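
Changes like these are straightforward to regression-test with reactor-test, which can seed an initial context and assert it is visible inside the pipeline. A hedged sketch (the pipeline is a stand-in, not the driver's Crypt class; the "tenant" key is invented):

import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;
import reactor.test.StepVerifierOptions;
import reactor.util.context.Context;

public final class ContextPropagationTestSketch {
    public static void main(String[] args) {
        // Stand-in pipeline: a Mono.create that forwards sink.contextView()
        // into its nested subscription, as the methods above now do.
        Mono<String> pipeline = Mono.<String>create(sink ->
                Mono.deferContextual(ctx -> Mono.just(ctx.<String>get("tenant")))
                        .contextWrite(sink.contextView())
                        .subscribe(sink::success, sink::error));

        StepVerifier.create(pipeline,
                        StepVerifierOptions.create().withInitialContext(Context.of("tenant", "acme")))
                .expectAccessibleContext().contains("tenant", "acme").then()
                .expectNext("acme")
                .verifyComplete();
    }
}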
@@ -34,20 +34,17 @@
 import org.reactivestreams.Subscriber;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
+import reactor.util.function.Tuple2;

 import java.nio.ByteBuffer;
 import java.util.Date;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicLong;
-import java.util.function.Consumer;
 import java.util.function.Function;

 import static com.mongodb.ReadPreference.primary;
 import static com.mongodb.assertions.Assertions.notNull;
-

 /**
  * <p>This class is not part of the public API and may be removed or changed at any time</p>
  */
@@ -98,31 +95,22 @@ public BsonValue getId() {

     @Override
     public void subscribe(final Subscriber<? super Void> s) {
-        Mono.<Void>create(sink -> {
+        Mono.deferContextual(ctx -> {
             AtomicBoolean terminated = new AtomicBoolean(false);
-            sink.onCancel(() -> createCancellationMono(terminated).subscribe());
-
-            Consumer<Throwable> errorHandler = e -> createCancellationMono(terminated)
-                    .doOnError(i -> sink.error(e))
-                    .doOnSuccess(i -> sink.error(e))
-                    .subscribe();
-
-            Consumer<Long> saveFileDataMono = l -> createSaveFileDataMono(terminated, l)
-                    .doOnError(errorHandler)
-                    .doOnSuccess(i -> sink.success())
-                    .subscribe();
-
-            Consumer<Void> saveChunksMono = i -> createSaveChunksMono(terminated)
-                    .doOnError(errorHandler)
-                    .doOnSuccess(saveFileDataMono)
-                    .subscribe();
-
-            createCheckAndCreateIndexesMono()
-                    .doOnError(errorHandler)
-                    .doOnSuccess(saveChunksMono)
-                    .subscribe();
-        })
-        .subscribe(s);
+            return createCheckAndCreateIndexesMono()
+                    .then(createSaveChunksMono(terminated))
+                    .flatMap(lengthInBytes -> createSaveFileDataMono(terminated, lengthInBytes))
+                    .onErrorResume(originalError ->
+                            createCancellationMono(terminated)
+                                    .onErrorMap(cancellationError -> {
+                                        // Timeout exception might occur during cancellation. It gets suppressed.
+                                        originalError.addSuppressed(cancellationError);
+                                        return originalError;
+                                    })
+                                    .then(Mono.error(originalError)))
+                    .doOnCancel(() -> createCancellationMono(terminated).contextWrite(ctx).subscribe())
+                    .then();
+        }).subscribe(s);
     }

     public GridFSUploadPublisher<ObjectId> withObjectId() {
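
This rewrite replaces a Mono.create bridge built from hand-rolled Consumer callbacks with one composed chain inside Mono.deferContextual, which captures the downstream ContextView so the separately subscribed cancellation cleanup can have it written back. A hedged, self-contained sketch of that idiom (the cleanup mono is invented, not the driver's createCancellationMono):

import java.time.Duration;
import reactor.core.publisher.Mono;

public final class DeferContextualSketch {
    // Invented stand-in for cancellation cleanup that reads the context.
    static Mono<Void> cleanup() {
        return Mono.deferContextual(ctx -> {
            System.out.println("cleaning up for " + ctx.getOrDefault("tenant", "?"));
            return Mono.<Void>empty();
        });
    }

    public static void main(String[] args) {
        Mono.deferContextual(ctx -> Mono.delay(Duration.ofSeconds(10))
                        // doOnCancel fires outside the chain, so the captured ctx must be
                        // written onto the separately subscribed cleanup publisher.
                        .doOnCancel(() -> cleanup().contextWrite(ctx).subscribe())
                        .then())
                .contextWrite(ctx -> ctx.put("tenant", "acme"))
                .subscribe()
                .dispose(); // cancel immediately; prints "cleaning up for acme"
    }
}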
@@ -156,28 +144,14 @@ private Mono<Void> createCheckAndCreateIndexesMono() {
         } else {
             findPublisher = collection.find();
         }
-        AtomicBoolean collectionExists = new AtomicBoolean(false);
-
-        return Mono.create(sink -> Mono.from(findPublisher.projection(PROJECTION).first())
-                .subscribe(
-                        d -> collectionExists.set(true),
-                        sink::error,
-                        () -> {
-                            if (collectionExists.get()) {
-                                sink.success();
-                            } else {
-                                checkAndCreateIndex(filesCollection.withReadPreference(primary()), FILES_INDEX)
-                                        .doOnError(sink::error)
-                                        .doOnSuccess(i -> {
-                                            checkAndCreateIndex(chunksCollection.withReadPreference(primary()), CHUNKS_INDEX)
-                                                    .doOnError(sink::error)
-                                                    .doOnSuccess(sink::success)
-                                                    .subscribe();
-                                        })
-                                        .subscribe();
-                            }
-                        })
-        );
+        return Mono.from(findPublisher.projection(PROJECTION).first())
+                .switchIfEmpty(Mono.defer(() ->
+                        checkAndCreateIndex(filesCollection.withReadPreference(primary()), FILES_INDEX)
+                                .then(checkAndCreateIndex(chunksCollection.withReadPreference(primary()), CHUNKS_INDEX))
+                                .then(Mono.fromCallable(Document::new))
+                ))
+                .then();
     }

     private <T> Mono<Boolean> hasIndex(final MongoCollection<T> collection, final Document index) {
@@ -189,29 +163,23 @@ private <T> Mono<Boolean> hasIndex(final MongoCollection<T> collection, final Document index) {
         }

         return Flux.from(listIndexesPublisher)
-                .collectList()
-                .map(indexes -> {
-                    boolean hasIndex = false;
-                    for (Document result : indexes) {
-                        Document indexDoc = result.get("key", new Document());
-                        for (final Map.Entry<String, Object> entry : indexDoc.entrySet()) {
-                            if (entry.getValue() instanceof Number) {
-                                entry.setValue(((Number) entry.getValue()).intValue());
-                            }
-                        }
-                        if (indexDoc.equals(index)) {
-                            hasIndex = true;
-                            break;
-                        }
-                    }
-                    return hasIndex;
-                });
+                .filter((result) -> {
+                    Document indexDoc = result.get("key", new Document());
+                    for (final Map.Entry<String, Object> entry : indexDoc.entrySet()) {
+                        if (entry.getValue() instanceof Number) {
+                            entry.setValue(((Number) entry.getValue()).intValue());
+                        }
+                    }
+                    return indexDoc.equals(index);
+                })
+                .take(1)
+                .hasElements();
     }

     private <T> Mono<Void> checkAndCreateIndex(final MongoCollection<T> collection, final Document index) {
         return hasIndex(collection, index).flatMap(hasIndex -> {
             if (!hasIndex) {
-                return createIndexMono(collection, index).flatMap(s -> Mono.empty());
+                return createIndexMono(collection, index).then();
             } else {
                 return Mono.empty();
             }
@@ -223,14 +191,14 @@ private <T> Mono<String> createIndexMono(final MongoCollection<T> collection, final Document index) {
     }

     private Mono<Long> createSaveChunksMono(final AtomicBoolean terminated) {
-        return Mono.create(sink -> {
-            AtomicLong lengthInBytes = new AtomicLong(0);
-            AtomicInteger chunkIndex = new AtomicInteger(0);
-            new ResizingByteBufferFlux(source, chunkSizeBytes)
-                    .flatMap((Function<ByteBuffer, Publisher<InsertOneResult>>) byteBuffer -> {
+        return new ResizingByteBufferFlux(source, chunkSizeBytes)
+                .index()
+                .flatMap((Function<Tuple2<Long, ByteBuffer>, Publisher<Integer>>) indexAndBuffer -> {
                     if (terminated.get()) {
                         return Mono.empty();
                     }
+                    Long index = indexAndBuffer.getT1();
+                    ByteBuffer byteBuffer = indexAndBuffer.getT2();
                     byte[] byteArray = new byte[byteBuffer.remaining()];
                     if (byteBuffer.hasArray()) {
                         System.arraycopy(byteBuffer.array(), byteBuffer.position(), byteArray, 0, byteBuffer.remaining());
@@ -240,18 +208,19 @@ private Mono<Long> createSaveChunksMono(final AtomicBoolean terminated) {
                         byteBuffer.reset();
                     }
                     Binary data = new Binary(byteArray);
-                    lengthInBytes.addAndGet(data.length());

                     Document chunkDocument = new Document("files_id", fileId)
-                            .append("n", chunkIndex.getAndIncrement())
+                            .append("n", index.intValue())
                             .append("data", data);

-                    return clientSession == null ? chunksCollection.insertOne(chunkDocument)
-                            : chunksCollection.insertOne(clientSession, chunkDocument);
+                    Publisher<InsertOneResult> insertOnePublisher = clientSession == null
+                            ? chunksCollection.insertOne(chunkDocument)
+                            : chunksCollection.insertOne(clientSession, chunkDocument);
+
+                    return Mono.from(insertOnePublisher).thenReturn(data.length());
                 })
-                .subscribe(null, sink::error, () -> sink.success(lengthInBytes.get()));
-        });
+                .reduce(0L, Long::sum);
     }

     private Mono<InsertOneResult> createSaveFileDataMono(final AtomicBoolean terminated, final long lengthInBytes) {
         if (terminated.compareAndSet(false, true)) {
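
Finally, the chunk-writing hunk above swaps mutable AtomicInteger/AtomicLong counters for Flux.index(), which pairs each element with its ordinal, plus a terminal reduce to total the bytes written. A hedged miniature of the same shape (byte-array chunks stand in for the real ResizingByteBufferFlux):

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public final class IndexAndReduceSketch {
    public static void main(String[] args) {
        // Stand-in for ResizingByteBufferFlux: three chunks of differing sizes.
        Flux<byte[]> chunks = Flux.just(new byte[64], new byte[64], new byte[10]);

        Mono<Long> totalBytes = chunks
                .index() // emits Tuple2<Long, byte[]>: (0, chunk0), (1, chunk1), ...
                .flatMap(indexAndChunk -> {
                    long n = indexAndChunk.getT1();          // chunk ordinal, replaces AtomicInteger
                    byte[] data = indexAndChunk.getT2();
                    System.out.println("writing chunk n=" + n + " (" + data.length + " bytes)");
                    return Mono.just(data.length);           // per-chunk byte count
                })
                .reduce(0L, Long::sum);                      // replaces the AtomicLong accumulator

        System.out.println("total=" + totalBytes.block());  // total=138
    }
}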