Performance Degradation When Upgrading to v1.60 #8063

Open
tristers-at-square opened this issue Jul 12, 2022 · 6 comments

@tristers-at-square

tristers-at-square commented Jul 12, 2022

I recently tried upgrading XGBoost from 0.90 to 1.6. However, my distributed training job (using the approx method) is now ~8 times slower. On version 0.90, each boosting round took about 25 seconds. On version 1.6, each boosting round takes around 3 minutes.

Even the hist method on 1.6 is slower than approx on 0.90.

My dataset has ~5000 features and 500K rows. The exact same parameters and exact same data are being used in my training runs for both versions. The only difference is the version. I cannot share the dataset since it is a work dataset. Roughly 20% of the values in the data are null.
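Since the real data can't be shared, a synthetic stand-in with a similar shape might help others reproduce the slowdown. The sketch below is an assumption, not the reporter's data: `make_wide_dataset` is a hypothetical helper that draws random float32 features, sets about 20% of the entries to NaN (XGBoost's default missing value), and derives a binary label from a few columns. At the full size described above (500K rows × 5000 float32 features) the feature matrix alone is roughly 10 GB, so the usage example is scaled down.

```python
import numpy as np


def make_wide_dataset(n_rows, n_features, missing_frac=0.2, seed=0):
    """Hypothetical synthetic stand-in for a wide tabular dataset with nulls.

    Returns (X, y): X is float32 with roughly `missing_frac` of its entries
    replaced by NaN, and y is a 0/1 label.
    """
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_rows, n_features), dtype=np.float32)
    # Label depends on the first few columns so trees have some signal;
    # it is computed before masking so every row gets a label.
    k = min(5, n_features)
    logits = X[:, :k].sum(axis=1)
    y = (logits + rng.standard_normal(n_rows) > 0).astype(np.int32)
    # Knock out ~missing_frac of the values to mimic the ~20% nulls.
    mask = rng.random((n_rows, n_features), dtype=np.float32) < missing_frac
    X[mask] = np.nan
    return X, y


if __name__ == "__main__":
    X, y = make_wide_dataset(n_rows=2_000, n_features=500)
    print(X.shape, float(np.isnan(X).mean()), float(y.mean()))
```

Scaling `n_rows` and `n_features` up toward the real shape is then a matter of memory rather than code changes.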

One thing I've noticed is that on version 0.90, increasing the nthread parameter reduces the time per boosting round, and decreasing it increases the time. This makes sense.

However, on version 1.6, changing the nthread parameter doesn't seem to have any effect. I'm wondering whether this is related to the slowdown.

Here's the relevant code snippet:

```python
import xgboost

# get_dmatrix, num_boost_round, num_early_stopping_rounds and callbacks
# are defined elsewhere in the training script (not shown here).
dtrain = get_dmatrix("my_training_data_path.csv", "csv")
dval = get_dmatrix("my_validation_data_path.csv", "csv")
watchlist = [(dtrain, "train"), (dval, "validation")]

hyperparameters = {
    "alpha": 0,
    "gamma": 0.15,
    "learning_rate": 0.1,
    "max_depth": 8,
    "num_round": 120,  # ignored by xgboost.train; rounds come from num_boost_round
    "objective": "binary:logistic",
    "scale_pos_weight": 1,
    "subsample": 0.8,
    "tree_method": "approx",
}

feval = None
maximize = True
progress_metrics = {}  # filled by xgboost.train with per-round eval metrics
booster = xgboost.train(
    params=hyperparameters,
    feval=feval,
    dtrain=dtrain,
    evals=watchlist,
    evals_result=progress_metrics,
    maximize=maximize,
    num_boost_round=num_boost_round,
    early_stopping_rounds=num_early_stopping_rounds,
    callbacks=callbacks,
)
```

Each training job uses 2 ml.m5.12xlarge instances (48 vCPUs, 192 GiB RAM each) on AWS SageMaker, with Python 3.7.

@trivialfis
Member

Hi, thank you for raising the issue. From your code snippet, it seems you are not using distributed training?

@tristers-at-square
Author

tristers-at-square commented Jul 12, 2022

> Hi, thank you for raising the issue. From your code snippet, it seems you are not using distributed training?

I'm using Rabit to synchronize across an AWS cluster for distributed training on AWS SageMaker, as shown in their official example:

https://sagemaker-examples.readthedocs.io/en/latest/introduction_to_amazon_algorithms/xgboost_abalone/xgboost_abalone_dist_script_mode.html#Create-an-XGBoost-training-script

@tristers-at-square
Author

From the logging output:

[Screenshot: logging output, Jul 12, 2022]

@tristers-at-square
Author

tristers-at-square commented Jul 13, 2022

Things I've tried:

  • running with a different eval metric
  • running without callbacks
  • running on a single node without Rabit
  • running with different configurations of nthread (1, 4, 8, 12, 48, and -1)

Training was slow in every case.

@trivialfis
Member

Thank you for running these experiments. It's probably due to the number of features. The approx tree method was recently rewritten, and the new implementation might be less efficient for wide datasets: #7214 (comment). Also, the sketch_eps parameter has been replaced by max_bin to align approx with hist. The old default, translated from sketch_eps into an equivalent max_bin, was around 63, whereas the rewritten implementation uses a default of 256, which means it builds larger histograms.
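To make the size difference concrete: if the per-feature bin count is the only change, the number of histogram bins built per tree node scales as features × bins. The sketch below does that arithmetic for the 5000-feature dataset in this issue. It counts bins only and deliberately ignores bytes per bin and how many nodes are expanded per level, which depend on implementation details not covered in this thread; `bins_per_node` is a hypothetical helper.

```python
def bins_per_node(n_features, max_bin):
    """Upper bound on histogram bins per tree node: up to max_bin per feature."""
    return n_features * max_bin


n_features = 5000  # from the issue
old = bins_per_node(n_features, 63)   # ~63 bins: old default derived from sketch_eps
new = bins_per_node(n_features, 256)  # 256 bins: default max_bin after the rewrite
print(old, new, round(new / old, 2))  # 315000 1280000 4.06
```

A roughly 4x larger histogram per node is smaller than the ~8x slowdown reported above, so on its own it is probably not the whole story.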

@eugeneyarovoi

Do you have any suggestions for recovering the old level of performance on wide datasets like this one, with 5K to 10K features? We can try setting max_bin = 63; if I understand correctly, accuracy should then be on par with what we had before. Are there any other settings that would help?

The main reason we use the "approx" method rather than "hist" for many of our workloads is that "approx" used far less memory than "hist" in version 0.9. Is that still expected to hold in 1.6?

We want to upgrade to 1.6 because of all the great new features since 0.9, like early stopping, categorical feature support, etc. However, an 8x increase in running time is prohibitive in our case.
