Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jvm-packages]Adds the ability to pass multiple custom evaluation #3544

Closed
wants to merge 14 commits into from
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package ml.dmlc.xgboost4j.java;

import java.io.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Holds the outcome of a training run: the {@link Booster} together with an
 * array of log strings associated with it.
 *
 * The class can be serialized through standard Java serialization
 * ({@link Serializable}) and through Kryo ({@link KryoSerializable}).
 */
public class BoosterResults implements Serializable, KryoSerializable {
  private Booster booster;
  private String[] logInfos;

  /**
   * No-arg constructor used only by Kryo.
   *
   * Kryo's default instantiation for {@link KryoSerializable} types creates the
   * object through a zero-argument constructor (which may be private) and then
   * calls {@link #read(Kryo, Input)} to populate it. Without this constructor,
   * deserialization only works when the host application has installed a
   * fallback instantiator strategy.
   */
  private BoosterResults() {
  }

  /**
   * @param booster  the booster to carry
   * @param logInfos the log strings to carry; the array is stored as given,
   *                 not copied
   */
  public BoosterResults(Booster booster, String[] logInfos) {
    this.booster = booster;
    this.logInfos = logInfos;
  }

  /**
   * @return the booster held by this result
   */
  public Booster getBooster() {
    return this.booster;
  }

  /**
   * @return the stored log strings (the internal array itself, not a copy)
   */
  public String[] getLogInfos() {
    return this.logInfos;
  }

  /**
   * Writes the booster followed by the log strings. The order must match
   * {@link #read(Kryo, Input)}.
   */
  @Override
  public void write(Kryo kryo, Output output) {
    kryo.writeObject(output, booster);
    kryo.writeObject(output, logInfos);
  }

  /**
   * Restores the booster and then the log strings, in the order used by
   * {@link #write(Kryo, Output)}.
   */
  @Override
  public void read(Kryo kryo, Input input) {
    booster = kryo.readObject(input, Booster.class);
    logInfos = kryo.readObject(input, String[].class);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ public static Booster train(
IObjective obj,
IEvaluation eval,
int earlyStoppingRound) throws XGBoostError {
return train(dtrain, params, round, watches, metrics, obj, eval, earlyStoppingRound, null);
return train(dtrain, params, round, watches, metrics, obj,
eval, earlyStoppingRound, null);
}

/**
Expand Down Expand Up @@ -136,12 +137,32 @@ public static Booster train(
IEvaluation eval,
int earlyStoppingRound,
Booster booster) throws XGBoostError {
if(eval != null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

space after if

return trainWithMultipleEvals(dtrain, params, round, watches, metrics, obj,
new IEvaluation[]{eval}, earlyStoppingRound, booster);
} else {
return trainWithMultipleEvals(dtrain, params, round, watches, metrics, obj,
null, earlyStoppingRound, booster);
}
}

public static BoosterResults trainWithResults(
DMatrix dtrain,
Map<String, Object> params,
int round,
Map<String, DMatrix> watches,
float[][] metrics,
IObjective obj,
IEvaluation[] evals,
int earlyStoppingRound,
Booster booster) throws XGBoostError {
Helw150 marked this conversation as resolved.
Show resolved Hide resolved

//collect eval matrixs
String[] evalNames;
DMatrix[] evalMats;
List<String> names = new ArrayList<String>();
List<DMatrix> mats = new ArrayList<DMatrix>();
List<String> logLines = new ArrayList<String>();

for (Map.Entry<String, DMatrix> evalEntry : watches.entrySet()) {
names.add(evalEntry.getKey());
Expand Down Expand Up @@ -187,11 +208,16 @@ public static Booster train(
//evaluation
if (evalMats.length > 0) {
float[] metricsOut = new float[evalMats.length];
String evalInfo;
if (eval != null) {
evalInfo = booster.evalSet(evalMats, evalNames, eval, metricsOut);
String evalInfo = "";
if (evals != null && evals.length > 0) {
for (int i = 0; i < evals.length; i++) {
String evalLine = booster.evalSet(evalMats, evalNames, evals[i], metricsOut);
Helw150 marked this conversation as resolved.
Show resolved Hide resolved
evalInfo = evalInfo + " " + evalLine;
}
logLines.add(evalInfo);
} else {
evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut);
logLines.add(evalInfo);
}
for (int i = 0; i < metricsOut.length; i++) {
metrics[i][iter] = metricsOut[i];
Expand All @@ -208,13 +234,48 @@ public static Booster train(
"early stopping after %d decreasing rounds", earlyStoppingRound));
break;
}

if (Rabit.getRank() == 0) {
Rabit.trackerPrint(evalInfo + '\n');
}
}
booster.saveRabitCheckpoint();
}
return booster;
BoosterResults results = new BoosterResults(booster,
logLines.toArray(new String[logLines.size()]));
return results;
}

/**
* Train a booster with given parameters, optionally using several customized
* evaluation functions, and return only the trained booster.
*
* This is a thin wrapper around trainWithResults that drops the collected log lines.
*
* @param dtrain Data to be trained.
* @param params Booster params.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
* @param metrics array that receives the evaluation metric of each watched matrix for each
* iteration, written as metrics[matrixIndex][iteration]
* @param obj customized objective (set to null if not used)
* @param evals customized evaluations (set to null or an empty array if not used, in which
* case the built-in evaluation is run instead)
* @param earlyStoppingRound if non-zero, training would be stopped
* after a specified number of consecutive
* increases in any evaluation metric.
* @param booster existing booster handed through to trainWithResults; presumably null means
* train from scratch (TODO confirm against trainWithResults)
* @return trained booster
* @throws XGBoostError native error
*/
public static Booster trainWithMultipleEvals(
DMatrix dtrain,
Map<String, Object> params,
int round,
Map<String, DMatrix> watches,
float[][] metrics,
IObjective obj,
IEvaluation[] evals,
int earlyStoppingRound,
Booster booster) throws XGBoostError {

// Delegate to trainWithResults and keep only the booster from the returned pair.
BoosterResults results = trainWithResults(dtrain, params, round, watches, metrics,
obj, evals, earlyStoppingRound, booster);
return results.getBooster();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,57 @@ package ml.dmlc.xgboost4j.scala

import java.io.InputStream

import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError}
import ml.dmlc.xgboost4j.java.{
Booster => JBooster,
XGBoost => JXGBoost,
XGBoostError,
BoosterResults,
IEvaluation
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unnecessary multi-lining

import scala.collection.JavaConverters._

/**
* XGBoost Scala Training function.
*/
object XGBoost {

@throws(classOf[XGBoostError])
def trainWithResults(
dtrain: DMatrix,
params: Map[String, Any],
round: Int,
watches: Map[String, DMatrix] = Map[String, DMatrix](),
metrics: Array[Array[Float]] = null,
obj: ObjectiveTrait = null,
evals: Array[IEvaluation] = null,
earlyStoppingRound: Int = 0,
booster: Booster = null
): BoosterResults = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consolidate line 45 and 46


val jWatches = watches.mapValues(_.jDMatrix).asJava
val jBooster = if (booster == null) {
null
} else {
booster.booster
}
val xgboostResults = JXGBoost.trainWithResults(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params
.filter(_._2 != null)
.mapValues(_.toString.asInstanceOf[AnyRef])
.asJava,
Helw150 marked this conversation as resolved.
Show resolved Hide resolved
round,
jWatches,
metrics,
obj,
evals,
earlyStoppingRound,
jBooster
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need to make one parameter per line

)
xgboostResults
}

/**
* Train a booster given parameters.
*
Expand All @@ -52,22 +95,80 @@ object XGBoost {
watches: Map[String, DMatrix] = Map(),
metrics: Array[Array[Float]] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
eval: IEvaluation = null,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EvalTrait is not used at all?

earlyStoppingRound: Int = 0,
booster: Booster = null): Booster = {
val jWatches = watches.mapValues(_.jDMatrix).asJava
val jBooster = if (booster == null) {
null
booster: Booster = null
): Booster = {
val evals: Array[IEvaluation] = {
if (eval != null) {
Array(eval)
} else {
null
}
}

val xgboostResults = trainWithResults(
dtrain,
params,
round,
watches,
metrics,
obj,
evals,
earlyStoppingRound,
booster
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unnecessary multi-lining

)
if (booster == null) {
new Booster(xgboostResults.getBooster())
} else {
booster.booster
// Avoid creating a new SBooster with the same JBooster
booster
}
val xgboostInJava = JXGBoost.train(
dtrain.jDMatrix,
// we have to filter null value for customized obj and eval
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster)
}

/**
* Train a booster given parameters.
*
* @param dtrain Data to be trained.
* @param params Parameters.
* @param round Number of boosting iterations.
* @param watches a group of items to be evaluated during training, this allows user to watch
* performance on the validation set.
* @param metrics array containing the evaluation metrics for each matrix in watches for each
* iteration
* @param earlyStoppingRound if non-zero, training would be stopped
* after a specified number of consecutive
* increases in any evaluation metric.
* @param obj customized objective
* @param evals customized evaluations
* @param booster train from scratch if set to null; train from an existing booster if not null.
* @return The trained booster.
*/
@throws(classOf[XGBoostError])
def trainWithMultipleEvals(
dtrain: DMatrix,
params: Map[String, Any],
round: Int,
watches: Map[String, DMatrix] = Map[String, DMatrix](),
metrics: Array[Array[Float]] = null,
obj: ObjectiveTrait = null,
evals: Array[IEvaluation] = null,
earlyStoppingRound: Int = 0,
booster: Booster = null
): Booster = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consolidate 157 and 158

val xgboostResults = trainWithResults(
dtrain,
params,
round,
watches,
metrics,
obj,
evals,
earlyStoppingRound,
booster
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unnecessary multi-lining

)
if (booster == null) {
new Booster(xgboostInJava)
new Booster(xgboostResults.getBooster())
} else {
// Avoid creating a new SBooster with the same JBooster
booster
Expand All @@ -94,11 +195,20 @@ object XGBoost {
nfold: Int = 5,
metrics: Array[String] = null,
obj: ObjectiveTrait = null,
eval: EvalTrait = null): Array[String] = {
eval: EvalTrait = null
): Array[String] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consolidate 198 and 199

JXGBoost.crossValidation(
data.jDMatrix, params.map{ case (key: String, value) => (key, value.toString)}.
toMap[String, AnyRef].asJava,
round, nfold, metrics, obj, eval)
data.jDMatrix,
params
.map { case (key: String, value) => (key, value.toString) }
.toMap[String, AnyRef]
.asJava,
round,
nfold,
metrics,
obj,
eval
)
}

/**
Expand Down