-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[jvm-packages]Adds the ability to pass multiple custom evaluation #3544
Changes from 8 commits
45d5a84
170a6c4
40fc3d1
257ba63
3b3c2b1
51071c2
3a72203
bc2620b
484d6db
3eaf440
d67e8cf
2c07ec0
5daced5
4762c3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package ml.dmlc.xgboost4j.java; | ||
|
||
import java.io.*; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import com.esotericsoftware.kryo.Kryo; | ||
import com.esotericsoftware.kryo.KryoSerializable; | ||
import com.esotericsoftware.kryo.io.Input; | ||
import com.esotericsoftware.kryo.io.Output; | ||
import org.apache.commons.logging.Log; | ||
import org.apache.commons.logging.LogFactory; | ||
|
||
public class BoosterResults implements Serializable, KryoSerializable { | ||
private Booster booster; | ||
private String[] logInfos; | ||
|
||
public BoosterResults(Booster booster, String[] logInfos) { | ||
this.booster = booster; | ||
this.logInfos = logInfos; | ||
} | ||
|
||
public Booster getBooster() { | ||
return this.booster; | ||
} | ||
|
||
public String[] getLogInfos() { | ||
return this.logInfos; | ||
} | ||
|
||
@Override | ||
public void write(Kryo kryo, Output output) { | ||
kryo.writeObject(output, booster); | ||
kryo.writeObject(output, logInfos); | ||
} | ||
|
||
@Override | ||
public void read(Kryo kryo, Input input) { | ||
booster = kryo.readObject(input, Booster.class); | ||
logInfos = kryo.readObject(input, String[].class); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,14 +18,57 @@ package ml.dmlc.xgboost4j.scala | |
|
||
import java.io.InputStream | ||
|
||
import ml.dmlc.xgboost4j.java.{Booster => JBooster, XGBoost => JXGBoost, XGBoostError} | ||
import ml.dmlc.xgboost4j.java.{ | ||
Booster => JBooster, | ||
XGBoost => JXGBoost, | ||
XGBoostError, | ||
BoosterResults, | ||
IEvaluation | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. unnecessary multi-lining |
||
import scala.collection.JavaConverters._ | ||
|
||
/** | ||
* XGBoost Scala Training function. | ||
*/ | ||
object XGBoost { | ||
|
||
/**
  * Train a booster and return the full result object produced by the Java-side
  * `trainWithResults` call, without unwrapping it.
  *
  * @param dtrain             Data to be trained.
  * @param params             Parameters. Entries whose value is null are dropped; every
  *                           remaining value is passed on as its `toString` form.
  * @param round              Number of boosting iterations.
  * @param watches            a group of items to be evaluated during training; defaults to empty.
  * @param metrics            array handed to the Java trainer unchanged; defaults to null.
  * @param obj                customized objective; defaults to null.
  * @param evals              customized evaluations; defaults to null.
  * @param earlyStoppingRound forwarded unchanged to the Java trainer; defaults to 0.
  * @param booster            train from scratch if set to null; train from an existing booster
  *                           if not null.
  * @return the `BoosterResults` returned by the Java trainer.
  */
@throws(classOf[XGBoostError])
def trainWithResults(
    dtrain: DMatrix,
    params: Map[String, Any],
    round: Int,
    watches: Map[String, DMatrix] = Map[String, DMatrix](),
    metrics: Array[Array[Float]] = null,
    obj: ObjectiveTrait = null,
    evals: Array[IEvaluation] = null,
    earlyStoppingRound: Int = 0,
    booster: Booster = null): BoosterResults = {
  // Unwrap the Scala DMatrix wrappers so the Java API receives its own type.
  val javaWatches = watches.mapValues(_.jDMatrix).asJava
  // A null Scala booster means "no starting point", which the Java API also expresses as null.
  val startingBooster = Option(booster).map(_.booster).orNull

  JXGBoost.trainWithResults(
    dtrain.jDMatrix,
    // we have to filter null value for customized obj and eval
    params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava,
    round, javaWatches, metrics, obj, evals, earlyStoppingRound, startingBooster)
}
|
||
/** | ||
* Train a booster given parameters. | ||
* | ||
|
@@ -52,22 +95,80 @@ object XGBoost { | |
watches: Map[String, DMatrix] = Map(), | ||
metrics: Array[Array[Float]] = null, | ||
obj: ObjectiveTrait = null, | ||
eval: EvalTrait = null, | ||
eval: IEvaluation = null, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. EvalTrait is not used at all? |
||
earlyStoppingRound: Int = 0, | ||
booster: Booster = null): Booster = { | ||
val jWatches = watches.mapValues(_.jDMatrix).asJava | ||
val jBooster = if (booster == null) { | ||
null | ||
booster: Booster = null | ||
): Booster = { | ||
val evals: Array[IEvaluation] = { | ||
if (eval != null) { | ||
Array(eval) | ||
} else { | ||
null | ||
} | ||
} | ||
|
||
val xgboostResults = trainWithResults( | ||
dtrain, | ||
params, | ||
round, | ||
watches, | ||
metrics, | ||
obj, | ||
evals, | ||
earlyStoppingRound, | ||
booster | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unnecessary multi-lining |
||
) | ||
if (booster == null) { | ||
new Booster(xgboostResults.getBooster()) | ||
} else { | ||
booster.booster | ||
// Avoid creating a new SBooster with the same JBooster | ||
booster | ||
} | ||
val xgboostInJava = JXGBoost.train( | ||
dtrain.jDMatrix, | ||
// we have to filter null value for customized obj and eval | ||
params.filter(_._2 != null).mapValues(_.toString.asInstanceOf[AnyRef]).asJava, | ||
round, jWatches, metrics, obj, eval, earlyStoppingRound, jBooster) | ||
} | ||
|
||
/** | ||
* Train a booster given parameters. | ||
* | ||
* @param dtrain Data to be trained. | ||
* @param params Parameters. | ||
* @param round Number of boosting iterations. | ||
* @param watches a group of items to be evaluated during training; this allows the user to watch | ||
* performance on the validation set. | ||
* @param metrics array containing the evaluation metrics for each matrix in watches for each | ||
* iteration | ||
* @param earlyStoppingRound if non-zero, training is stopped | ||
* after the specified number of consecutive | ||
* increases in any evaluation metric. | ||
* @param obj customized objective | ||
* @param evals customized evaluations | ||
* @param booster train from scratch if set to null; train from an existing booster if not null. | ||
* @return The trained booster. | ||
*/ | ||
@throws(classOf[XGBoostError]) | ||
def trainWithMultipleEvals( | ||
dtrain: DMatrix, | ||
params: Map[String, Any], | ||
round: Int, | ||
watches: Map[String, DMatrix] = Map[String, DMatrix](), | ||
metrics: Array[Array[Float]] = null, | ||
obj: ObjectiveTrait = null, | ||
evals: Array[IEvaluation] = null, | ||
earlyStoppingRound: Int = 0, | ||
booster: Booster = null | ||
): Booster = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consolidate 157 and 158 |
||
val xgboostResults = trainWithResults( | ||
dtrain, | ||
params, | ||
round, | ||
watches, | ||
metrics, | ||
obj, | ||
evals, | ||
earlyStoppingRound, | ||
booster | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unnecessary multi-lining |
||
) | ||
if (booster == null) { | ||
new Booster(xgboostInJava) | ||
new Booster(xgboostResults.getBooster()) | ||
} else { | ||
// Avoid creating a new SBooster with the same JBooster | ||
booster | ||
|
@@ -94,11 +195,20 @@ object XGBoost { | |
nfold: Int = 5, | ||
metrics: Array[String] = null, | ||
obj: ObjectiveTrait = null, | ||
eval: EvalTrait = null): Array[String] = { | ||
eval: EvalTrait = null | ||
): Array[String] = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consolidate 198 and 199 |
||
JXGBoost.crossValidation( | ||
data.jDMatrix, params.map{ case (key: String, value) => (key, value.toString)}. | ||
toMap[String, AnyRef].asJava, | ||
round, nfold, metrics, obj, eval) | ||
data.jDMatrix, | ||
params | ||
.map { case (key: String, value) => (key, value.toString) } | ||
.toMap[String, AnyRef] | ||
.asJava, | ||
round, | ||
nfold, | ||
metrics, | ||
obj, | ||
eval | ||
) | ||
} | ||
|
||
/** | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a space after `if`.