[ML] Stop the ML memory tracker before closing node #39111

Merged
@@ -435,10 +435,10 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
         DatafeedManager datafeedManager = new DatafeedManager(threadPool, client, clusterService, datafeedJobBuilder,
                 System::currentTimeMillis, auditor, autodetectProcessManager);
         this.datafeedManager.set(datafeedManager);
-        MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager,
-                autodetectProcessManager);
         MlMemoryTracker memoryTracker = new MlMemoryTracker(settings, clusterService, threadPool, jobManager, jobResultsProvider);
         this.memoryTracker.set(memoryTracker);
+        MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager,
+                autodetectProcessManager, memoryTracker);
 
         // This object's constructor attaches to the license state, so there's no need to retain another reference to it
         new InvalidLicenseEnforcer(getLicenseState(), threadPool, datafeedManager, autodetectProcessManager);
@@ -9,6 +9,7 @@
 import org.elasticsearch.common.component.LifecycleListener;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.xpack.ml.datafeed.DatafeedManager;
+import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.process.NativeController;
 import org.elasticsearch.xpack.ml.process.NativeControllerHolder;
 import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager;
@@ -20,16 +21,14 @@ public class MlLifeCycleService {
     private final Environment environment;
     private final DatafeedManager datafeedManager;
     private final AutodetectProcessManager autodetectProcessManager;
-
-    public MlLifeCycleService(Environment environment, ClusterService clusterService) {
-        this(environment, clusterService, null, null);
-    }
+    private final MlMemoryTracker memoryTracker;
 
     public MlLifeCycleService(Environment environment, ClusterService clusterService, DatafeedManager datafeedManager,
-                              AutodetectProcessManager autodetectProcessManager) {
+                              AutodetectProcessManager autodetectProcessManager, MlMemoryTracker memoryTracker) {
         this.environment = environment;
         this.datafeedManager = datafeedManager;
         this.autodetectProcessManager = autodetectProcessManager;
+        this.memoryTracker = memoryTracker;
         clusterService.addLifecycleListener(new LifecycleListener() {
             @Override
             public void beforeStop() {
@@ -59,5 +58,8 @@ public synchronized void stop() {
         } catch (IOException e) {
             // We're stopping anyway, so don't let this complicate the shutdown sequence
         }
+        if (memoryTracker != null) {
+            memoryTracker.stop();
+        }
     }
 }
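The constructor shown above registers a LifecycleListener with the ClusterService, and the synchronized stop() method now also stops the memory tracker, behind a null guard in case the tracker is not supplied. The following minimal, self-contained sketch (hypothetical stand-in types, not the Elasticsearch classes, and assuming beforeStop() simply delegates to stop()) shows the shape of that shutdown wiring:

    import java.util.ArrayList;
    import java.util.List;

    final class ShutdownSketch {

        interface StopListener { void beforeStop(); }

        // Stand-in for the ClusterService lifecycle hooks.
        static final class Node {
            private final List<StopListener> listeners = new ArrayList<>();
            void addLifecycleListener(StopListener listener) { listeners.add(listener); }
            void close() { listeners.forEach(StopListener::beforeStop); }
        }

        // Stand-in for MlLifeCycleService: registers itself at construction time and
        // null-guards the optional collaborator, as the real stop() does for memoryTracker.
        static final class LifeCycleService {
            private final Runnable memoryTracker;

            LifeCycleService(Node node, Runnable memoryTracker) {
                this.memoryTracker = memoryTracker;
                node.addLifecycleListener(this::stop);
            }

            synchronized void stop() {
                // datafeeds and autodetect processes would be stopped here first
                if (memoryTracker != null) {
                    memoryTracker.run();   // plays the role of memoryTracker.stop() in the diff
                }
            }
        }

        public static void main(String[] args) {
            Node node = new Node();
            new LifeCycleService(node, () -> System.out.println("memory tracker stopped"));
            node.close();   // simulates node shutdown; prints "memory tracker stopped"
        }
    }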
@@ -32,6 +32,7 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.Phaser;
 import java.util.stream.Collectors;
 
 /**
@@ -55,6 +56,7 @@ public class MlMemoryTracker implements LocalNodeMasterListener {
     private final ClusterService clusterService;
     private final JobManager jobManager;
     private final JobResultsProvider jobResultsProvider;
+    private final Phaser stopPhaser;
     private volatile boolean isMaster;
     private volatile Instant lastUpdateTime;
     private volatile Duration reassignmentRecheckInterval;
@@ -65,6 +67,7 @@ public MlMemoryTracker(Settings settings, ClusterService clusterService, ThreadP
         this.clusterService = clusterService;
         this.jobManager = jobManager;
         this.jobResultsProvider = jobResultsProvider;
+        this.stopPhaser = new Phaser(1);
         setReassignmentRecheckInterval(PersistentTasksClusterService.CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING.get(settings));
         clusterService.addLocalNodeMasterListener(this);
         clusterService.getClusterSettings().addSettingsUpdateConsumer(
@@ -89,6 +92,23 @@ public void offMaster() {
         lastUpdateTime = null;
     }
 
+    /**
+     * Wait for all outstanding searches to complete.
+     * After returning, no new searches can be started.
+     */
+    public void stop() {
+        logger.trace("ML memory tracker stop called");
+        // We never terminate the phaser
+        assert stopPhaser.isTerminated() == false;
+        // If there are no registered parties or no unarrived parties then there is a flaw
+        // in the register/arrive/unregister logic in another method that uses the phaser
+        assert stopPhaser.getRegisteredParties() > 0;
+        assert stopPhaser.getUnarrivedParties() > 0;
+        stopPhaser.arriveAndAwaitAdvance();
+        assert stopPhaser.getPhase() > 0;
+        logger.debug("ML memory tracker stopped");
+    }
+
     @Override
     public String executorName() {
         return MachineLearning.UTILITY_THREAD_POOL_NAME;
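The stop() method above relies on the Phaser created with a single registered party in the constructor (the tracker itself). Each memory refresh registers an extra party before it starts its search and arrives and deregisters when it finishes, so arriveAndAwaitAdvance() blocks until every outstanding search has completed; any register() call made after that sees a phase greater than zero and is rejected. Below is a minimal, self-contained sketch of this pattern with hypothetical names (tryStart and finish are illustrative, not the MlMemoryTracker API):

    import java.util.concurrent.Phaser;

    final class PhaserGateSketch {

        private final Phaser stopPhaser = new Phaser(1);   // party 0 represents the tracker itself

        // Try to start a unit of async work; returns false once stop() has been called.
        boolean tryStart() {
            if (stopPhaser.register() != 0) {              // phase > 0 means stop() has already advanced the phaser
                stopPhaser.arriveAndDeregister();          // undo the registration, as refreshJobMemory() does
                return false;
            }
            return true;
        }

        // Must be called exactly once for every successful tryStart().
        void finish() {
            stopPhaser.arriveAndDeregister();
        }

        // Blocks until every started unit of work has finished; new work is rejected afterwards.
        void stop() {
            stopPhaser.arriveAndAwaitAdvance();
        }

        public static void main(String[] args) throws InterruptedException {
            PhaserGateSketch gate = new PhaserGateSketch();
            Thread worker = new Thread(() -> {
                if (gate.tryStart()) {
                    try {
                        Thread.sleep(200);                 // simulate an in-flight search
                    } catch (InterruptedException ignored) {
                    } finally {
                        gate.finish();
                    }
                }
            });
            worker.start();
            Thread.sleep(50);                              // give the worker time to register first
            gate.stop();                                   // returns only after the worker calls finish()
            System.out.println("new work allowed after stop? " + gate.tryStart());   // false
            worker.join();
        }
    }

If a refresh registered the phaser but never arrived, stop() would block forever; the refreshJobMemory hunk below shows how every completion path deregisters.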
@@ -146,13 +166,13 @@ public boolean asyncRefresh() {
             try {
                 ActionListener<Void> listener = ActionListener.wrap(
                     aVoid -> logger.trace("Job memory requirement refresh request completed successfully"),
-                    e -> logger.error("Failed to refresh job memory requirements", e)
+                    e -> logger.warn("Failed to refresh job memory requirements", e)
                 );
                 threadPool.executor(executorName()).execute(
                     () -> refresh(clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE), listener));
                 return true;
             } catch (EsRejectedExecutionException e) {
-                logger.debug("Couldn't schedule ML memory update - node might be shutting down", e);
+                logger.warn("Couldn't schedule ML memory update - node might be shutting down", e);
             }
         }
 
@@ -246,25 +266,43 @@ public void refreshJobMemory(String jobId, ActionListener<Long> listener) {
             return;
         }
 
+        // The phaser prevents searches being started after the memory tracker's stop() method has returned
+        if (stopPhaser.register() != 0) {
+            // Phases above 0 mean we've been stopped, so don't do any operations that involve external interaction
+            stopPhaser.arriveAndDeregister();
+            listener.onFailure(new EsRejectedExecutionException("Couldn't run ML memory update - node is shutting down"));
+            return;
+        }
+        ActionListener<Long> phaserListener = ActionListener.wrap(
+            r -> {
+                stopPhaser.arriveAndDeregister();
+                listener.onResponse(r);
+            },
+            e -> {
+                stopPhaser.arriveAndDeregister();
+                listener.onFailure(e);
+            }
+        );
+
         try {
             jobResultsProvider.getEstablishedMemoryUsage(jobId, null, null,
                 establishedModelMemoryBytes -> {
                     if (establishedModelMemoryBytes <= 0L) {
-                        setJobMemoryToLimit(jobId, listener);
+                        setJobMemoryToLimit(jobId, phaserListener);
                     } else {
                         Long memoryRequirementBytes = establishedModelMemoryBytes + Job.PROCESS_MEMORY_OVERHEAD.getBytes();
                         memoryRequirementByJob.put(jobId, memoryRequirementBytes);
-                        listener.onResponse(memoryRequirementBytes);
+                        phaserListener.onResponse(memoryRequirementBytes);
                     }
                 },
                 e -> {
                     logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e);
-                    setJobMemoryToLimit(jobId, listener);
+                    setJobMemoryToLimit(jobId, phaserListener);
                 }
             );
         } catch (Exception e) {
             logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e);
-            setJobMemoryToLimit(jobId, listener);
+            setJobMemoryToLimit(jobId, phaserListener);
         }
     }
 
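Every completion path in refreshJobMemory (the success case, the setJobMemoryToLimit fallbacks, and the failure handler) now goes through phaserListener, so each successful stopPhaser.register() is balanced by exactly one arriveAndDeregister(); if any path skipped the wrapper, stop() would wait in arriveAndAwaitAdvance() forever. A rough illustration of the same "release exactly once, on success or failure" idea, using a CompletableFuture rather than the Elasticsearch ActionListener API:

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.Phaser;

    final class ReleaseOnCompletionSketch {
        // Deregisters the party that was registered before the async call started,
        // whether the future completes normally or exceptionally.
        static <T> CompletableFuture<T> releasing(Phaser stopPhaser, CompletableFuture<T> inFlight) {
            return inFlight.whenComplete((result, error) -> stopPhaser.arriveAndDeregister());
        }
    }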
@@ -10,6 +10,7 @@
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeUnit;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.persistent.PersistentTasksClusterService;
 import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
 import org.elasticsearch.test.ESTestCase;
@@ -29,6 +30,7 @@
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
 
+import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.anyString;
@@ -157,6 +159,19 @@ public void testRefreshOne() {
         assertNull(memoryTracker.getJobMemoryRequirement(jobId));
     }
 
+    public void testStop() {
+
+        memoryTracker.onMaster();
+        memoryTracker.stop();
+
+        AtomicReference<Exception> exception = new AtomicReference<>();
+        memoryTracker.refreshJobMemory("job", ActionListener.wrap(ESTestCase::assertNull, exception::set));
+
+        assertNotNull(exception.get());
+        assertThat(exception.get(), instanceOf(EsRejectedExecutionException.class));
+        assertEquals("Couldn't run ML memory update - node is shutting down", exception.get().getMessage());
+    }
+
     private PersistentTasksCustomMetaData.PersistentTask<OpenJobAction.JobParams> makeTestTask(String jobId) {
         return new PersistentTasksCustomMetaData.PersistentTask<>("job-" + jobId, MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams(jobId),
             0, PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT);