DATAMONGO-2393 - Fix BufferOverflow in GridFS upload.

AsyncInputStreamAdapter now properly splits and buffers incoming DataBuffers according the read requests of AsyncInputStream.read(…) calls.
Previously, the adapter used the input buffer size to be used as the output buffer size. A larger DataBuffer than the transfer buffer handed in through read(…) caused a BufferOverflow.

Original Pull Request: #799
This commit is contained in:
Mark Paluch
2019-10-21 09:17:34 +02:00
committed by Christoph Strobl
parent e73cea0ecf
commit 6cb246c18a

View File

@@ -17,6 +17,8 @@ package org.springframework.data.mongodb.gridfs;
import lombok.RequiredArgsConstructor;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.util.concurrent.Queues;
@@ -25,14 +27,15 @@ import reactor.util.context.Context;
import java.nio.ByteBuffer;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.function.BiConsumer;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
import com.mongodb.reactivestreams.client.Success;
import com.mongodb.reactivestreams.client.gridfs.AsyncInputStream;
@@ -66,15 +69,16 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
private final Publisher<? extends DataBuffer> buffers;
private final Context subscriberContext;
private final DefaultDataBufferFactory factory = new DefaultDataBufferFactory();
private volatile Subscription subscription;
private volatile boolean cancelled;
private volatile boolean complete;
private volatile boolean allDataBuffersReceived;
private volatile Throwable error;
private final Queue<BiConsumer<DataBuffer, Integer>> readRequests = Queues.<BiConsumer<DataBuffer, Integer>> small()
.get();
private final Queue<DataBuffer> bufferQueue = Queues.<DataBuffer> small().get();
// see DEMAND
volatile long demand;
@@ -88,41 +92,75 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
@Override
public Publisher<Integer> read(ByteBuffer dst) {
return Mono.create(sink -> {
return Flux.create(sink -> {
AtomicLong written = new AtomicLong();
readRequests.offer((db, bytecount) -> {
try {
if (error != null) {
sink.error(error);
onError(sink, error);
return;
}
if (bytecount == -1) {
sink.success(-1);
onComplete(sink, written.get() > 0 ? written.intValue() : -1);
return;
}
ByteBuffer byteBuffer = db.asByteBuffer();
int toWrite = byteBuffer.remaining();
int remaining = byteBuffer.remaining();
int writeCapacity = Math.min(dst.remaining(), remaining);
int limit = Math.min(byteBuffer.position() + writeCapacity, byteBuffer.capacity());
int toWrite = limit - byteBuffer.position();
if (toWrite == 0) {
onComplete(sink, written.intValue());
return;
}
int oldPosition = byteBuffer.position();
byteBuffer.limit(toWrite);
dst.put(byteBuffer);
sink.success(toWrite);
byteBuffer.limit(byteBuffer.capacity());
byteBuffer.position(oldPosition);
db.readPosition(db.readPosition() + toWrite);
written.addAndGet(toWrite);
} catch (Exception e) {
sink.error(e);
onError(sink, e);
} finally {
DataBufferUtils.release(db);
if (db != null && db.readableByteCount() == 0) {
DataBufferUtils.release(db);
}
}
});
request(1);
sink.onCancel(this::terminatePendingReads);
sink.onDispose(this::terminatePendingReads);
sink.onRequest(this::request);
});
}
void onError(FluxSink<Integer> sink, Throwable e) {
readRequests.poll();
sink.error(e);
}
void onComplete(FluxSink<Integer> sink, int writtenBytes) {
readRequests.poll();
DEMAND.decrementAndGet(this);
sink.next(writtenBytes);
sink.complete();
}
/*
* (non-Javadoc)
* @see com.mongodb.reactivestreams.client.gridfs.AsyncInputStream#skip(long)
@@ -144,17 +182,19 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
cancelled = true;
if (error != null) {
terminatePendingReads();
sink.error(error);
return;
}
terminatePendingReads();
sink.success(Success.SUCCESS);
});
}
protected void request(int n) {
protected void request(long n) {
if (complete) {
if (allDataBuffersReceived && bufferQueue.isEmpty()) {
terminatePendingReads();
return;
@@ -176,18 +216,51 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
requestFromSubscription(subscription);
}
}
}
void requestFromSubscription(Subscription subscription) {
long demand = DEMAND.get(AsyncInputStreamAdapter.this);
if (cancelled) {
subscription.cancel();
}
if (demand > 0 && DEMAND.compareAndSet(AsyncInputStreamAdapter.this, demand, demand - 1)) {
subscription.request(1);
drainLoop();
}
void drainLoop() {
while (DEMAND.get(AsyncInputStreamAdapter.this) > 0) {
DataBuffer wip = bufferQueue.peek();
if (wip == null) {
break;
}
if (wip.readableByteCount() == 0) {
bufferQueue.poll();
continue;
}
BiConsumer<DataBuffer, Integer> consumer = AsyncInputStreamAdapter.this.readRequests.peek();
if (consumer == null) {
break;
}
consumer.accept(wip, wip.readableByteCount());
}
if (bufferQueue.isEmpty()) {
if (allDataBuffersReceived) {
terminatePendingReads();
return;
}
if (demand > 0) {
subscription.request(1);
}
}
}
@@ -199,7 +272,7 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
BiConsumer<DataBuffer, Integer> readers;
while ((readers = readRequests.poll()) != null) {
readers.accept(factory.wrap(new byte[0]), -1);
readers.accept(null, -1);
}
}
@@ -214,23 +287,21 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
public void onSubscribe(Subscription s) {
AsyncInputStreamAdapter.this.subscription = s;
Operators.addCap(DEMAND, AsyncInputStreamAdapter.this, -1);
s.request(1);
}
@Override
public void onNext(DataBuffer dataBuffer) {
if (cancelled || complete) {
if (cancelled || allDataBuffersReceived) {
DataBufferUtils.release(dataBuffer);
Operators.onNextDropped(dataBuffer, AsyncInputStreamAdapter.this.subscriberContext);
return;
}
BiConsumer<DataBuffer, Integer> poll = AsyncInputStreamAdapter.this.readRequests.poll();
BiConsumer<DataBuffer, Integer> readRequest = AsyncInputStreamAdapter.this.readRequests.peek();
if (poll == null) {
if (readRequest == null) {
DataBufferUtils.release(dataBuffer);
Operators.onNextDropped(dataBuffer, AsyncInputStreamAdapter.this.subscriberContext);
@@ -238,29 +309,31 @@ class AsyncInputStreamAdapter implements AsyncInputStream {
return;
}
poll.accept(dataBuffer, dataBuffer.readableByteCount());
bufferQueue.offer(dataBuffer);
requestFromSubscription(subscription);
drainLoop();
}
@Override
public void onError(Throwable t) {
if (AsyncInputStreamAdapter.this.cancelled || AsyncInputStreamAdapter.this.complete) {
if (AsyncInputStreamAdapter.this.cancelled || AsyncInputStreamAdapter.this.allDataBuffersReceived) {
Operators.onErrorDropped(t, AsyncInputStreamAdapter.this.subscriberContext);
return;
}
AsyncInputStreamAdapter.this.error = t;
AsyncInputStreamAdapter.this.complete = true;
AsyncInputStreamAdapter.this.allDataBuffersReceived = true;
terminatePendingReads();
}
@Override
public void onComplete() {
AsyncInputStreamAdapter.this.complete = true;
terminatePendingReads();
AsyncInputStreamAdapter.this.allDataBuffersReceived = true;
if (bufferQueue.isEmpty()) {
terminatePendingReads();
}
}
}
}