
Add ability to specify custom evaluation set
This fixes dmlc#3231
Boris Filippov committed Apr 9, 2018
1 parent 443ff74 commit 9003a1c
Showing 3 changed files with 66 additions and 36 deletions.
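The change threads an optional evaluation set through the Spark training entry points (`XGBoost.trainWithRDD`, `XGBoost.trainDistributed`, and a new `XGBoostEstimator` constructor); when supplied, it is fed to XGBoost as the test watch instead of a random `trainTestRatio` split of the training data. As a minimal usage sketch, assuming a running SparkSession and hypothetical `trainDF`/`evalDF` DataFrames with the default `features` and `label` columns (the parameter keys are illustrative, not prescribed by this commit):

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator

// trainDF and evalDF are assumed DataFrames with "features" and "label"
// columns; they are not part of this commit.
val params = Map(
  "objective" -> "binary:logistic",
  "eta" -> 0.1,
  "max_depth" -> 6
)

// The new constructor overload takes the evaluation Dataset as an Option;
// when it is Some(...), Watches uses it directly as the eval watch instead
// of carving a test split out of the training data.
val estimator = new XGBoostEstimator(params, Some(evalDF))
val model = estimator.fit(trainDF)
```

Round count and worker count are left at their defaults in this sketch; real code would set them through the estimator's params.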
XGBoost.scala
@@ -19,20 +19,21 @@ package ml.dmlc.xgboost4j.scala.spark
import java.io.File
import java.nio.file.Files

import scala.collection.mutable
import scala.util.Random
import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.commons.io.FileUtils
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FSDataInputStream, Path}
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
import org.apache.spark.{SparkContext, SparkParallelismTracker, TaskContext}

import scala.collection.mutable
import scala.util.Random


/**
* Rabit tracker configurations.
@@ -105,6 +106,7 @@ object XGBoost extends Serializable {
round: Int,
obj: ObjectiveTrait,
eval: EvalTrait,
evalSet: Option[RDD[XGBLabeledPoint]] = None,
useExternalMemory: Boolean,
missing: Float,
prevBooster: Booster
@@ -129,8 +131,10 @@ }
}
rabitEnv.put("DMLC_TASK_ID", taskId)
Rabit.init(rabitEnv)

val watches = Watches(params,
removeMissingValues(labeledPoints, missing),
removeMissingValues(evalSet.map(_.toLocalIterator).getOrElse(Iterator.empty), missing),
fromBaseMarginsToArray(baseMargins), cacheDirName)

try {
@@ -226,9 +230,11 @@
nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
evalSet: Option[RDD[MLLabeledPoint]] = None,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN): XGBoostModel = {
trainWithRDD(trainingData, params, round, nWorkers, obj, eval, useExternalMemory, missing)
trainWithRDD(trainingData, params, round, nWorkers, obj, eval, evalSet,
useExternalMemory, missing)
}

private def overrideParamsAccordingToTaskCPUs(
@@ -282,13 +288,17 @@
nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
evalSet: Option[RDD[MLLabeledPoint]] = None,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN): XGBoostModel = {
import DataUtils._
val xgbTrainingData = trainingData.map { case MLLabeledPoint(label, features) =>
features.asXGB.copy(label = label.toFloat)
}
trainDistributed(xgbTrainingData, params, round, nWorkers, obj, eval,
val xgbEvalSet = evalSet.map(rdd => rdd.map { case MLLabeledPoint(label, features) =>
features.asXGB.copy(label = label.toFloat)
})
trainDistributed(xgbTrainingData, params, round, nWorkers, obj, eval, xgbEvalSet,
useExternalMemory, missing)
}

@@ -300,6 +310,7 @@
nWorkers: Int,
obj: ObjectiveTrait = null,
eval: EvalTrait = null,
evalSet: Option[RDD[XGBLabeledPoint]] = None,
useExternalMemory: Boolean = false,
missing: Float = Float.NaN): XGBoostModel = {
if (params.contains("tree_method")) {
@@ -340,7 +351,7 @@
val parallelismTracker = new SparkParallelismTracker(sc, timeoutRequestWorkers, nWorkers)
val overriddenParams = overrideParamsAccordingToTaskCPUs(params, sc)
val boostersAndMetrics = buildDistributedBoosters(partitionedData, overriddenParams,
tracker.getWorkerEnvs, checkpointRound, obj, eval, useExternalMemory, missing,
tracker.getWorkerEnvs, checkpointRound, obj, eval, evalSet, useExternalMemory, missing,
prevBooster)
val sparkJobThread = new Thread() {
override def run() {
@@ -492,22 +503,30 @@ private object Watches {
def apply(
params: Map[String, Any],
labeledPoints: Iterator[XGBLabeledPoint],
evalSet: Iterator[XGBLabeledPoint],
baseMarginsOpt: Option[Array[Float]],
cacheDirName: Option[String]): Watches = {
val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
val r = new Random(seed)
val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
val trainPoints = labeledPoints.filter { labeledPoint =>
val accepted = r.nextDouble() <= trainTestRatio
if (!accepted) {
testPoints += labeledPoint
}
val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)

accepted
val (trainPoints, testPoints) = if (evalSet.nonEmpty) {
(labeledPoints, evalSet)
} else {
val testPoints = mutable.ArrayBuffer.empty[XGBLabeledPoint]
val trainPoints = labeledPoints.filter { labeledPoint =>
val accepted = r.nextDouble() <= trainTestRatio
if (!accepted) {
testPoints += labeledPoint
}

accepted
}
(trainPoints, testPoints.toIterator)
}

val trainMatrix = new DMatrix(trainPoints, cacheDirName.map(_ + "/train").orNull)
val testMatrix = new DMatrix(testPoints.iterator, cacheDirName.map(_ + "/test").orNull)
val testMatrix = new DMatrix(testPoints, cacheDirName.map(_ + "/test").orNull)
r.setSeed(seed)
for (baseMargins <- baseMarginsOpt) {
val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
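The `Watches` hunk above is the heart of the change: an explicit eval iterator now short-circuits the old `trainTestRatio` split. A standalone sketch of just that selection logic (plain Scala, with `XGBLabeledPoint` stubbed as `Int` so it runs without the XGBoost types):

```scala
import scala.collection.mutable
import scala.util.Random

// Mirror of the selection logic added to Watches.apply, with points
// represented as Int for brevity.
def selectTrainTest(
    labeledPoints: Iterator[Int],
    evalSet: Iterator[Int],
    trainTestRatio: Double,
    seed: Long): (Iterator[Int], Iterator[Int]) = {
  if (evalSet.nonEmpty) {
    // An explicit eval set wins: the training data is used whole and the
    // eval iterator becomes the test watch verbatim.
    (labeledPoints, evalSet)
  } else {
    // Fall back to the original behaviour: a seeded random split that
    // keeps roughly trainTestRatio of the points for training.
    val r = new Random(seed)
    val testPoints = mutable.ArrayBuffer.empty[Int]
    val trainPoints = labeledPoints.filter { point =>
      val accepted = r.nextDouble() <= trainTestRatio
      if (!accepted) {
        testPoints += point
      }
      accepted
    }
    (trainPoints, testPoints.toIterator)
  }
}
```

As in the commit, the filtered train iterator is lazy, so the test buffer only fills while the train iterator is consumed; the `DMatrix` constructor consumes it eagerly before the test matrix is built, which is what keeps the split branch correct.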
XGBoostEstimator.scala
@@ -16,11 +16,8 @@

package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable

import ml.dmlc.xgboost4j.scala.spark.params._
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

import org.apache.spark.ml.Predictor
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.ml.param._
@@ -30,16 +27,23 @@ import org.apache.spark.sql.types.FloatType
import org.apache.spark.sql.{Dataset, Row}
import org.json4s.DefaultFormats

import scala.collection.mutable

/**
* XGBoost Estimator to produce a XGBoost model
*/
class XGBoostEstimator private[spark](
override val uid: String, xgboostParams: Map[String, Any])
override val uid: String, xgboostParams: Map[String, Any], evalSet: Option[Dataset[_]] = None)
extends Predictor[Vector, XGBoostEstimator, XGBoostModel]
with LearningTaskParams with GeneralParams with BoosterParams with MLWritable {

def this(xgboostParams: Map[String, Any]) =
this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any])
def this(xgboostParams: Map[String, Any], evalSet: Option[Dataset[_]]) = {
this(Identifiable.randomUID("XGBoostEstimator"), xgboostParams: Map[String, Any], evalSet)
}

def this(xgboostParams: Map[String, Any]) = {
this(xgboostParams: Map[String, Any], None)
}

def this(uid: String) = this(uid, Map[String, Any]())

@@ -92,6 +96,22 @@ class XGBoostEstimator private[spark](
}
}

private def toLabeledPoints(df: Dataset[_]) = {
val instances = ensureColumns(df).select(
col($(featuresCol)),
col($(labelCol)).cast(FloatType),
col($(baseMarginCol)).cast(FloatType),
col($(weightCol)).cast(FloatType)
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight)
}
instances
}

fromXGBParamMapToParams()

private[spark] def fromParamsToXGBParamMap: Map[String, Any] = {
@@ -122,23 +142,13 @@
* produce a XGBoostModel by fitting the given dataset
*/
override def train(trainingSet: Dataset[_]): XGBoostModel = {
val instances = ensureColumns(trainingSet).select(
col($(featuresCol)),
col($(labelCol)).cast(FloatType),
col($(baseMarginCol)).cast(FloatType),
col($(weightCol)).cast(FloatType)
).rdd.map { case Row(features: Vector, label: Float, baseMargin: Float, weight: Float) =>
val (indices, values) = features match {
case v: SparseVector => (v.indices, v.values.map(_.toFloat))
case v: DenseVector => (null, v.values.map(_.toFloat))
}
XGBLabeledPoint(label.toFloat, indices, values, baseMargin = baseMargin, weight = weight)
}
val instances = toLabeledPoints(trainingSet)

transformSchema(trainingSet.schema, logging = true)
val derivedXGBoosterParamMap = fromParamsToXGBParamMap
val trainedModel = XGBoost.trainDistributed(instances, derivedXGBoosterParamMap,
$(round), $(nWorkers), $(customObj), $(customEval), $(useExternalMemory),
$(missing)).setParent(this)
$(round), $(nWorkers), $(customObj), $(customEval),
evalSet.map(toLabeledPoints), $(useExternalMemory), $(missing)).setParent(this)
val returnedModel = copyValues(trainedModel, extractParamMap())
if (XGBoost.isClassificationTask(derivedXGBoosterParamMap)) {
returnedModel.asInstanceOf[XGBoostClassificationModel].numOfClasses = $(numClasses)
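The `toLabeledPoints` extraction above is what lets the estimator run both the training set and the optional eval Dataset through a single conversion path. The per-row core of that conversion, lifted into a self-contained function (names mirror the diff; the column selection and casts around it are omitted):

```scala
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector}

// Sparse vectors keep their indices; dense vectors pass null indices so
// the values array is treated as contiguous, exactly as in toLabeledPoints.
def toXGBPoint(features: Vector, label: Float,
               baseMargin: Float, weight: Float): XGBLabeledPoint = {
  val (indices, values) = features match {
    case v: SparseVector => (v.indices, v.values.map(_.toFloat))
    case v: DenseVector  => (null, v.values.map(_.toFloat))
  }
  XGBLabeledPoint(label, indices, values, baseMargin = baseMargin, weight = weight)
}
```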
GeneralParams.scala
@@ -16,9 +16,10 @@

package ml.dmlc.xgboost4j.scala.spark.params

import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.scala.spark.TrackerConf

import org.apache.spark.ml.param._
import org.apache.spark.rdd.RDD

trait GeneralParams extends Params {

