Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@
import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static java.util.stream.Collectors.toList;

abstract class BaseCluster implements Cluster {
@VisibleForTesting(otherwise = PRIVATE)
public abstract class BaseCluster implements Cluster {
private static final Logger LOGGER = Loggers.getLogger("cluster");
private static final StructuredLogger STRUCTURED_LOGGER = new StructuredLogger("cluster");

Expand All @@ -112,10 +113,11 @@ abstract class BaseCluster implements Cluster {
private volatile boolean isClosed;
private volatile ClusterDescription description;

BaseCluster(final ClusterId clusterId,
final ClusterSettings settings,
final ClusterableServerFactory serverFactory,
final ClientMetadata clientMetadata) {
@VisibleForTesting(otherwise = PRIVATE)
protected BaseCluster(final ClusterId clusterId,
final ClusterSettings settings,
final ClusterableServerFactory serverFactory,
final ClientMetadata clientMetadata) {
this.clusterId = notNull("clusterId", clusterId);
this.settings = notNull("settings", settings);
this.serverFactory = notNull("serverFactory", serverFactory);
Expand Down Expand Up @@ -361,8 +363,7 @@ private static ServerSelector getCompleteServerSelector(
final ClusterSettings settings) {
List<ServerSelector> selectors = Stream.of(
getRaceConditionPreFilteringSelector(serversSnapshot),
serverSelector,
serverDeprioritization.getServerSelector(),
serverDeprioritization.applyDeprioritization(serverSelector),
settings.getServerSelector(), // may be null
new LatencyMinimizingServerSelector(settings.getLocalThreshold(MILLISECONDS), MILLISECONDS),
AtMostTwoRandomServerSelector.instance(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import com.mongodb.ServerAddress;
import com.mongodb.ServerApi;
import com.mongodb.connection.ClusterDescription;
import com.mongodb.connection.ClusterType;
import com.mongodb.connection.ServerDescription;
import com.mongodb.internal.IgnorableRequestContext;
import com.mongodb.internal.TimeoutContext;
Expand All @@ -40,6 +39,8 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

import static com.mongodb.internal.VisibleForTesting.AccessModifier.PACKAGE;
import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE;
import static java.util.stream.Collectors.toList;

/**
Expand Down Expand Up @@ -113,6 +114,13 @@ public OperationContext withOperationName(final String operationName) {
operationName, tracingSpan);
}

// TODO-JAVA-6058: This method enables overriding the ServerDeprioritization state.
// It is a temporary solution to handle cases where deprioritization state persists across operations.
public OperationContext withServerDeprioritization(final ServerDeprioritization serverDeprioritization) {
return new OperationContext(id, requestContext, sessionContext, timeoutContext, serverDeprioritization, tracingManager, serverApi,
operationName, tracingSpan);
}

public long getId() {
return id;
}
Expand Down Expand Up @@ -228,24 +236,26 @@ public static final class ServerDeprioritization {
@Nullable
private ServerAddress candidate;
private final Set<ServerAddress> deprioritized;
private final DeprioritizingSelector selector;

private ServerDeprioritization() {
@VisibleForTesting(otherwise = PRIVATE)
public ServerDeprioritization() {
candidate = null;
deprioritized = new HashSet<>();
selector = new DeprioritizingSelector();
}

/**
* The returned {@link ServerSelector} tries to {@linkplain ServerSelector#select(ClusterDescription) select}
* only the {@link ServerDescription}s that do not have deprioritized {@link ServerAddress}es.
* If no such {@link ServerDescription} can be selected, then it selects {@link ClusterDescription#getServerDescriptions()}.
* The returned {@link ServerSelector} wraps the provided selector and attempts server selection in two passes:
* <ol>
* <li>First pass: calls the wrapped selector with only non-deprioritized {@link ServerDescription}s</li>
* <li>Second pass: if the first pass returns no servers, calls the wrapped selector again with all servers (including deprioritized ones)</li>
* </ol>
*/
ServerSelector getServerSelector() {
return selector;
ServerSelector applyDeprioritization(final ServerSelector wrappedSelector) {
return new DeprioritizingSelector(wrappedSelector);
}

void updateCandidate(final ServerAddress serverAddress) {
@VisibleForTesting(otherwise = PACKAGE)
public void updateCandidate(final ServerAddress serverAddress) {
candidate = serverAddress;
}

Expand All @@ -263,24 +273,35 @@ public void onAttemptFailure(final Throwable failure) {
* which indeed may be used concurrently. {@link DeprioritizingSelector} does not need to be thread-safe.
*/
private final class DeprioritizingSelector implements ServerSelector {
private DeprioritizingSelector() {
private final ServerSelector wrappedSelector;

private DeprioritizingSelector(final ServerSelector wrappedSelector) {
this.wrappedSelector = wrappedSelector;
}

@Override
public List<ServerDescription> select(final ClusterDescription clusterDescription) {
List<ServerDescription> serverDescriptions = clusterDescription.getServerDescriptions();
if (!isEnabled(clusterDescription.getType())) {
return serverDescriptions;

if (serverDescriptions.size() == 1 || deprioritized.isEmpty()) {
return wrappedSelector.select(clusterDescription);
}

List<ServerDescription> nonDeprioritizedServerDescriptions = serverDescriptions
.stream()
.filter(serverDescription -> !deprioritized.contains(serverDescription.getAddress()))
.collect(toList());
return nonDeprioritizedServerDescriptions.isEmpty() ? serverDescriptions : nonDeprioritizedServerDescriptions;
}

private boolean isEnabled(final ClusterType clusterType) {
return clusterType == ClusterType.SHARDED;
if (nonDeprioritizedServerDescriptions.isEmpty()) {
return wrappedSelector.select(clusterDescription);
}

List<ServerDescription> selected = wrappedSelector.select(
new ClusterDescription(clusterDescription.getConnectionMode(), clusterDescription.getType(),
nonDeprioritizedServerDescriptions,
clusterDescription.getClusterSettings(),
clusterDescription.getServerSettings()));
return selected.isEmpty() ? wrappedSelector.select(clusterDescription) : selected;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ final class ChangeStreamBatchCursor<T> implements AggregateResponseBatchCursor<T
final int maxWireVersion) {
this.changeStreamOperation = changeStreamOperation;
this.binding = binding.retain();
this.initialOperationContext = operationContext.withOverride(TimeoutContext::withMaxTimeAsMaxAwaitTimeOverride);
this.initialOperationContext = operationContext
.withOverride(TimeoutContext::withMaxTimeAsMaxAwaitTimeOverride)
//TODO-JAVA-6058. Temporary workaround to reset any server deprioritization
// state from the previous find operation.
.withServerDeprioritization(new OperationContext.ServerDeprioritization());
this.wrapped = wrapped;
this.resumeToken = resumeToken;
this.maxWireVersion = maxWireVersion;
Expand Down
54 changes: 28 additions & 26 deletions driver-core/src/test/functional/com/mongodb/ClusterFixture.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,16 +193,18 @@ public static ServerVersion getServerVersion() {
if (serverVersion == null) {
serverVersion = getVersion(new CommandReadOperation<>("admin",
new BsonDocument("buildInfo", new BsonInt32(1)), new BsonDocumentCodec())
.execute(new ClusterBinding(getCluster(), ReadPreference.nearest()), OPERATION_CONTEXT));
.execute(new ClusterBinding(getCluster(), ReadPreference.nearest()), getOperationContext()));
}
return serverVersion;
}

public static final OperationContext OPERATION_CONTEXT = new OperationContext(
IgnorableRequestContext.INSTANCE,
new ReadConcernAwareNoOpSessionContext(ReadConcern.DEFAULT),
new TimeoutContext(TIMEOUT_SETTINGS),
getServerApi());
public static OperationContext getOperationContext() {
return new OperationContext(
IgnorableRequestContext.INSTANCE,
new ReadConcernAwareNoOpSessionContext(ReadConcern.DEFAULT),
new TimeoutContext(TIMEOUT_SETTINGS),
getServerApi());
}

public static final InternalOperationContextFactory OPERATION_CONTEXT_FACTORY =
new InternalOperationContextFactory(TIMEOUT_SETTINGS, getServerApi());
Expand Down Expand Up @@ -255,7 +257,7 @@ public static boolean hasEncryptionTestsEnabled() {
public static Document getServerStatus() {
return new CommandReadOperation<>("admin", new BsonDocument("serverStatus", new BsonInt32(1)),
new DocumentCodec())
.execute(getBinding(), OPERATION_CONTEXT);
.execute(getBinding(), getOperationContext());
}

public static boolean supportsFsync() {
Expand All @@ -270,7 +272,7 @@ static class ShutdownHook extends Thread {
public void run() {
if (cluster != null) {
try {
new DropDatabaseOperation(getDefaultDatabaseName(), WriteConcern.ACKNOWLEDGED).execute(getBinding(), OPERATION_CONTEXT);
new DropDatabaseOperation(getDefaultDatabaseName(), WriteConcern.ACKNOWLEDGED).execute(getBinding(), getOperationContext());
} catch (MongoCommandException e) {
// if we do not have permission to drop the database, assume it is cleaned up in some other way
if (!e.getMessage().contains("Command dropDatabase requires authentication")) {
Expand Down Expand Up @@ -322,7 +324,7 @@ public static synchronized ConnectionString getConnectionString() {
try {
BsonDocument helloResult = new CommandReadOperation<>("admin",
new BsonDocument(LEGACY_HELLO, new BsonInt32(1)), new BsonDocumentCodec())
.execute(new ClusterBinding(cluster, ReadPreference.nearest()), OPERATION_CONTEXT);
.execute(new ClusterBinding(cluster, ReadPreference.nearest()), getOperationContext());
if (helloResult.containsKey("setName")) {
connectionString = new ConnectionString(DEFAULT_URI + "/?replicaSet="
+ helloResult.getString("setName").getValue());
Expand Down Expand Up @@ -382,11 +384,11 @@ public static ReadWriteBinding getBinding(final OperationContext operationContex
}

public static ReadWriteBinding getBinding(final ReadPreference readPreference) {
return getBinding(getCluster(), readPreference, OPERATION_CONTEXT);
return getBinding(getCluster(), readPreference, getOperationContext());
}

public static OperationContext createNewOperationContext(final TimeoutSettings timeoutSettings) {
return OPERATION_CONTEXT.withTimeoutContext(new TimeoutContext(timeoutSettings));
return getOperationContext().withTimeoutContext(new TimeoutContext(timeoutSettings));
}

private static ReadWriteBinding getBinding(final Cluster cluster,
Expand All @@ -403,23 +405,23 @@ private static ReadWriteBinding getBinding(final Cluster cluster,
}

public static SingleConnectionBinding getSingleConnectionBinding() {
return new SingleConnectionBinding(getCluster(), ReadPreference.primary(), OPERATION_CONTEXT);
return new SingleConnectionBinding(getCluster(), ReadPreference.primary(), getOperationContext());
}

public static AsyncSingleConnectionBinding getAsyncSingleConnectionBinding() {
return getAsyncSingleConnectionBinding(getAsyncCluster());
}

public static AsyncSingleConnectionBinding getAsyncSingleConnectionBinding(final Cluster cluster) {
return new AsyncSingleConnectionBinding(cluster, ReadPreference.primary(), OPERATION_CONTEXT);
return new AsyncSingleConnectionBinding(cluster, ReadPreference.primary(), getOperationContext());
}

public static AsyncReadWriteBinding getAsyncBinding(final Cluster cluster) {
return new AsyncClusterBinding(cluster, ReadPreference.primary());
}

public static AsyncReadWriteBinding getAsyncBinding() {
return getAsyncBinding(getAsyncCluster(), ReadPreference.primary(), OPERATION_CONTEXT);
return getAsyncBinding(getAsyncCluster(), ReadPreference.primary(), getOperationContext());
}

public static AsyncReadWriteBinding getAsyncBinding(final TimeoutSettings timeoutSettings) {
Expand All @@ -431,7 +433,7 @@ public static AsyncReadWriteBinding getAsyncBinding(final OperationContext opera
}

public static AsyncReadWriteBinding getAsyncBinding(final ReadPreference readPreference) {
return getAsyncBinding(getAsyncCluster(), readPreference, OPERATION_CONTEXT);
return getAsyncBinding(getAsyncCluster(), readPreference, getOperationContext());
}

public static AsyncReadWriteBinding getAsyncBinding(
Expand Down Expand Up @@ -605,7 +607,7 @@ public static BsonDocument getServerParameters() {
if (serverParameters == null) {
serverParameters = new CommandReadOperation<>("admin",
new BsonDocument("getParameter", new BsonString("*")), new BsonDocumentCodec())
.execute(getBinding(), OPERATION_CONTEXT);
.execute(getBinding(), getOperationContext());
}
return serverParameters;
}
Expand Down Expand Up @@ -673,7 +675,7 @@ public static void configureFailPoint(final BsonDocument failPointDocument) {
if (!isSharded()) {
try {
new CommandReadOperation<>("admin", failPointDocument, new BsonDocumentCodec())
.execute(getBinding(), OPERATION_CONTEXT);
.execute(getBinding(), getOperationContext());
} catch (MongoCommandException e) {
if (e.getErrorCode() == COMMAND_NOT_FOUND_ERROR_CODE) {
failsPointsSupported = false;
Expand All @@ -689,7 +691,7 @@ public static void disableFailPoint(final String failPoint) {
.append("mode", new BsonString("off"));
try {
new CommandReadOperation<>("admin", failPointDocument, new BsonDocumentCodec())
.execute(getBinding(), OPERATION_CONTEXT);
.execute(getBinding(), getOperationContext());
} catch (MongoCommandException e) {
// ignore
}
Expand All @@ -703,7 +705,7 @@ public static <T> T executeSync(final WriteOperation<T> op) {

@SuppressWarnings("overloads")
public static <T> T executeSync(final WriteOperation<T> op, final ReadWriteBinding binding) {
return op.execute(binding, applySessionContext(OPERATION_CONTEXT, binding.getReadPreference()));
return op.execute(binding, applySessionContext(getOperationContext(), binding.getReadPreference()));
}

@SuppressWarnings("overloads")
Expand All @@ -713,7 +715,7 @@ public static <T> T executeSync(final ReadOperation<T, ?> op) {

@SuppressWarnings("overloads")
public static <T> T executeSync(final ReadOperation<T, ?> op, final ReadWriteBinding binding) {
return op.execute(binding, OPERATION_CONTEXT);
return op.execute(binding, getOperationContext());
}

@SuppressWarnings("overloads")
Expand All @@ -729,7 +731,7 @@ public static <T> T executeAsync(final WriteOperation<T> op) throws Throwable {
@SuppressWarnings("overloads")
public static <T> T executeAsync(final WriteOperation<T> op, final AsyncReadWriteBinding binding) throws Throwable {
FutureResultCallback<T> futureResultCallback = new FutureResultCallback<>();
op.executeAsync(binding, applySessionContext(OPERATION_CONTEXT, binding.getReadPreference()), futureResultCallback);
op.executeAsync(binding, applySessionContext(getOperationContext(), binding.getReadPreference()), futureResultCallback);
return futureResultCallback.get(TIMEOUT, SECONDS);
}

Expand All @@ -741,7 +743,7 @@ public static <T> T executeAsync(final ReadOperation<?, T> op) throws Throwable
@SuppressWarnings("overloads")
public static <T> T executeAsync(final ReadOperation<?, T> op, final AsyncReadBinding binding) throws Throwable {
FutureResultCallback<T> futureResultCallback = new FutureResultCallback<>();
op.executeAsync(binding, OPERATION_CONTEXT, futureResultCallback);
op.executeAsync(binding, getOperationContext(), futureResultCallback);
return futureResultCallback.get(TIMEOUT, SECONDS);
}

Expand Down Expand Up @@ -811,19 +813,19 @@ public static <T> List<T> collectCursorResults(final BatchCursor<T> batchCursor)

public static AsyncConnectionSource getWriteConnectionSource(final AsyncReadWriteBinding binding) throws Throwable {
FutureResultCallback<AsyncConnectionSource> futureResultCallback = new FutureResultCallback<>();
binding.getWriteConnectionSource(OPERATION_CONTEXT, futureResultCallback);
binding.getWriteConnectionSource(getOperationContext(), futureResultCallback);
return futureResultCallback.get(TIMEOUT, SECONDS);
}

public static AsyncConnectionSource getReadConnectionSource(final AsyncReadWriteBinding binding) throws Throwable {
FutureResultCallback<AsyncConnectionSource> futureResultCallback = new FutureResultCallback<>();
binding.getReadConnectionSource(OPERATION_CONTEXT, futureResultCallback);
binding.getReadConnectionSource(getOperationContext(), futureResultCallback);
return futureResultCallback.get(TIMEOUT, SECONDS);
}

public static AsyncConnection getConnection(final AsyncConnectionSource source) throws Throwable {
FutureResultCallback<AsyncConnection> futureResultCallback = new FutureResultCallback<>();
source.getConnection(OPERATION_CONTEXT, futureResultCallback);
source.getConnection(getOperationContext(), futureResultCallback);
return futureResultCallback.get(TIMEOUT, SECONDS);
}

Expand Down Expand Up @@ -867,6 +869,6 @@ private static OperationContext applySessionContext(final OperationContext opera
}

public static OperationContext getOperationContext(final ReadPreference readPreference) {
return applySessionContext(OPERATION_CONTEXT, readPreference);
return applySessionContext(getOperationContext(), readPreference);
}
}
Loading