Skip to content

Commit

Permalink
address the comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Nan Zhu committed Dec 18, 2019
1 parent 551db22 commit 6485c89
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 39 deletions.
1 change: 1 addition & 0 deletions jvm-packages/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
<spark.version>2.4.3</spark.version>
<scala.version>2.12.8</scala.version>
<scala.binary.version>2.12</scala.binary.version>
<hadoop.version>2.7.3</hadoop.version>
</properties>
<repositories>
<repository>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ private[this] case class XGBoostExecutionParams(
missing: Float,
trackerConf: TrackerConf,
timeoutRequestWorkers: Long,
checkpointParam: ExternalCheckpointParams,
checkpointParam: Option[ExternalCheckpointParams],
xgbInputParams: XGBoostExecutionInputParams,
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
cacheTrainingSet: Boolean) {
Expand Down Expand Up @@ -342,9 +342,7 @@ object XGBoost extends Serializable {
rabitEnv.put("DMLC_NUM_ATTEMPT", attempt)
rabitEnv.put("DMLC_WORKER_STOP_PROCESS_ON_ERROR", "false")
val numRounds = xgbExecutionParam.numRounds
val makeCheckpoint = xgbExecutionParam.checkpointParam.checkpointPath != null &&
xgbExecutionParam.checkpointParam.checkpointPath.nonEmpty &&
taskId.toInt == 0
val makeCheckpoint = xgbExecutionParam.checkpointParam.isDefined && taskId.toInt == 0
try {
Rabit.init(rabitEnv)
val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds
Expand All @@ -354,7 +352,7 @@ object XGBoost extends Serializable {
SXGBoost.trainAndSaveCheckpoint(
watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
watches.toMap, metrics, obj, eval,
earlyStoppingRound = numEarlyStoppingRounds, prevBooster, Some(externalCheckpointParams))
earlyStoppingRound = numEarlyStoppingRounds, prevBooster, externalCheckpointParams)
} else {
SXGBoost.train(watches.toMap("train"), xgbExecutionParam.toMap, numRounds,
watches.toMap, metrics, obj, eval,
Expand Down Expand Up @@ -523,7 +521,7 @@ object XGBoost extends Serializable {
val xgbRabitParams = xgbParamsFactory.buildRabitParams.asJava
val sc = trainingData.sparkContext
val checkpointManager = new ExternalCheckpointManager(
xgbExecParams.checkpointParam.checkpointPath, FileSystem.get(sc.hadoopConfiguration))
xgbExecParams.checkpointParam.get.checkpointPath, FileSystem.get(sc.hadoopConfiguration))
checkpointManager.cleanUpHigherVersions(xgbExecParams.numRounds)
val transformedTrainingData = composeInputData(trainingData, xgbExecParams.cacheTrainingSet,
hasGroup, xgbExecParams.numWorkers)
Expand Down Expand Up @@ -560,7 +558,7 @@ object XGBoost extends Serializable {
tracker.stop()
}
// we should delete the checkpoint directory after a successful training
if (!xgbExecParams.checkpointParam.skipCleanCheckpoint) {
if (!xgbExecParams.checkpointParam.get.skipCleanCheckpoint) {
checkpointManager.cleanPath()
}
(booster, metrics)
Expand Down
6 changes: 4 additions & 2 deletions jvm-packages/xgboost4j/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-hdfs</artifactId>
<version>2.7.3</version>
<version>${hadoop.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>2.7.3</version>
<version>${hadoop.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,27 @@ public class ExternalCheckpointManager {

private Log logger = LogFactory.getLog("ExternalCheckpointManager");
private String modelSuffix = ".model";
private String checkpointPath;
private Path checkpointPath;
private FileSystem fs;

public ExternalCheckpointManager(String checkpointPath, FileSystem fs) {
this.checkpointPath = checkpointPath;
public ExternalCheckpointManager(String checkpointPath, FileSystem fs) throws XGBoostError {
if (checkpointPath == null || checkpointPath.isEmpty()) {
throw new XGBoostError("cannot create ExternalCheckpointManager with null or" +
" empty checkpoint path");
}
this.checkpointPath = new Path(checkpointPath);
this.fs = fs;
}

private String getPath(int version) {
return checkpointPath + "/" + version + modelSuffix;
return checkpointPath.toUri().getPath() + "/" + version + modelSuffix;
}

private List<Integer> getExistingVersions() throws IOException {
if (checkpointPath == null || checkpointPath.isEmpty() ||
!fs.exists(new Path(checkpointPath))) {
if (!fs.exists(checkpointPath)) {
return new ArrayList<>();
} else {
return Arrays.stream(fs.listStatus(new Path(checkpointPath)))
return Arrays.stream(fs.listStatus(checkpointPath))
.map(path -> path.getPath().getName())
.filter(fileName -> fileName.endsWith(modelSuffix))
.map(fileName -> Integer.valueOf(
Expand All @@ -42,9 +45,7 @@ private List<Integer> getExistingVersions() throws IOException {
}

public void cleanPath() throws IOException {
if (checkpointPath != null) {
fs.delete(new Path(checkpointPath), true);
}
fs.delete(checkpointPath, true);
}

public Booster loadCheckpointAsBooster() throws IOException, XGBoostError {
Expand All @@ -67,33 +68,33 @@ public void updateCheckpoint(Booster boosterToCheckpoint) throws IOException, XG
.map(this::getPath).collect(Collectors.toList());
String eventualPath = getPath(boosterToCheckpoint.getVersion());
String tempPath = eventualPath + "-" + UUID.randomUUID();
OutputStream out = fs.create(new Path(tempPath), true);
boosterToCheckpoint.saveModel(out);
out.close();
fs.rename(new Path(tempPath), new Path(eventualPath));
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
prevModelPaths.stream().forEach(path -> {
try {
fs.delete(new Path(path), true);
} catch (IOException e) {
logger.error("failed to delete outdated checkpoint at " + path, e);
}
});
try (OutputStream out = fs.create(new Path(tempPath), true)) {
boosterToCheckpoint.saveModel(out);
fs.rename(new Path(tempPath), new Path(eventualPath));
logger.info("saving checkpoint with version " + boosterToCheckpoint.getVersion());
prevModelPaths.stream().forEach(path -> {
try {
fs.delete(new Path(path), true);
} catch (IOException e) {
logger.error("failed to delete outdated checkpoint at " + path, e);
}
});
}
}

public void cleanUpHigherVersions(int currentRound) throws IOException {
getExistingVersions().stream().filter(v -> v / 2 >= currentRound).forEach(v -> {
try {
fs.delete(new Path(getPath(v)), true);
} catch (IOException e) {
e.printStackTrace();
logger.error("failed to clean checkpoint from other training instance", e);
}
});
}

public List<Integer> getCheckpointRounds(int checkpointInterval, int numOfRounds)
throws IOException {
if (checkpointPath != null && checkpointInterval > 0) {
if (checkpointInterval > 0) {
List<Integer> prevRounds =
getExistingVersions().stream().map(v -> v / 2).collect(Collectors.toList());
prevRounds.add(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,8 @@ private static void saveCheckpoint(
ecm.updateCheckpoint(booster);
}
} catch (Exception e) {
logger.error("failed to save checkpoint in XGBoost4J", e);
throw new XGBoostError("failed to save checkpoint in XGBoost4J, due to " + e);

logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration" + iter, e);
}
}

Expand Down Expand Up @@ -287,7 +286,7 @@ public static Booster train(
-1, null, null);
} catch (IOException e) {
logger.error("training failed in xgboost4j", e);
throw new XGBoostError("training failed in xgboost4j " + e);
throw new XGBoostError("training failed in xgboost4j ", e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,8 @@ public class XGBoostError extends Exception {
public XGBoostError(String message) {
super(message);
}

public XGBoostError(String message, Throwable cause) {
super(message, cause);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ private[scala] case class ExternalCheckpointParams(
skipCleanCheckpoint: Boolean)

private[scala] object ExternalCheckpointParams {
def extractParams(params: Map[String, Any]): ExternalCheckpointParams = {

def extractParams(params: Map[String, Any]): Option[ExternalCheckpointParams] = {
val checkpointPath: String = params.get("checkpoint_path") match {
case None | Some(null) | Some("") => null
case Some(path: String) => path
Expand All @@ -185,7 +186,11 @@ private[scala] object ExternalCheckpointParams {
case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
" an instance of Boolean")
}
ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile)
if (checkpointPath == null || checkpointInterval == 0) {
None
} else {
Some(ExternalCheckpointParams(checkpointInterval, checkpointPath, skipCleanCheckpointFile))
}
}
}

Expand Down

0 comments on commit 6485c89

Please sign in to comment.