diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala index 00cef158db18..36ad18bb8ef9 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/rabit/RabitTracker.scala @@ -20,7 +20,7 @@ import java.net.{InetAddress, InetSocketAddress} import akka.actor.ActorSystem import akka.pattern.ask -import ml.dmlc.xgboost4j.java.IRabitTracker +import ml.dmlc.xgboost4j.java.{IRabitTracker, TrackerProperties} import ml.dmlc.xgboost4j.scala.rabit.handler.RabitTrackerHandler import scala.concurrent.duration._ @@ -93,8 +93,11 @@ private[scala] class RabitTracker(numWorkers: Int, port: Option[Int] = None, * @return Boolean flag indicating if the Rabit tracker starts successfully. */ private def start(timeout: Duration): Boolean = { + val hostAddress = Option(TrackerProperties.getInstance().getHostIp) + .map(InetAddress.getByName).getOrElse(InetAddress.getLocalHost) + handler ? RabitTrackerHandler.StartTracker( - new InetSocketAddress(InetAddress.getLocalHost, port.getOrElse(0)), maxPortTrials, timeout) + new InetSocketAddress(hostAddress, port.getOrElse(0)), maxPortTrials, timeout) // block by waiting for the actor to bind to a port Try(Await.result(handler ? RabitTrackerHandler.RequestBoundFuture, askTimeout.duration)