Skip to content

Commit

Permalink
Feed Dask Array to MultinomialNB as it not accept Dask Cudf.
Browse files Browse the repository at this point in the history
  • Loading branch information
justinuliu committed Jun 21, 2022
1 parent e9e0bf5 commit 5cce366
Showing 1 changed file with 7 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public static enum Learner {
ElasticNet("linear_model", false, true, false, false,
"\talpha=1.0, l1_ratio=0.5, fit_intercept=True, normalize=False, max_iter=1000,\n"
+ "\ttol=0.001, selection='cyclic', handle=None, output_type=None, verbose=False"),
MultinomialNB("naive_bayes", true, false, true, false,
MultinomialNB("naive_bayes", true, false, false, false,
"\talpha=1.0, fit_prior=True, class_prior=None, output_type=None, handle=None, verbose=False"),
CD("solvers", false, true, false, false, "");

Expand Down Expand Up @@ -847,6 +847,9 @@ public void buildClassifier(Instances data) throws Exception {
learnScript.append("X, Y = cudf.from_pandas(X), cudf.from_pandas(Y)\n");
learnScript.append("X, Y = dask_cudf.from_cudf(X, npartitions=len(dask_cluster.workers))"
+ " ,dask_cudf.from_cudf(Y, npartitions=len(dask_cluster.workers))\n");
if (m_learner == Learner.MultinomialNB) {
learnScript.append("X, Y = X.to_dask_array(), Y.to_dask_array()\n");
}
if (m_learner.isClassifier()) {
learnScript.append(MODEL_ID + m_modelHash + ".fit(X.astype('float32'), Y.astype('int32'))").append("\n");
} else if (m_learner.isRegressor()) {
Expand Down Expand Up @@ -963,6 +966,9 @@ public double[][] distributionsForInstances(Instances insts)
predictScript.append("import dask_cudf\n");
predictScript.append("X = cudf.from_pandas(X)\n");
predictScript.append("X = dask_cudf.from_cudf(X, npartitions=len(dask_cluster.workers))\n");
if (m_learner == Learner.MultinomialNB) {
predictScript.append("X = X.to_dask_array()\n");
}
predictScript.append("preds = " + MODEL_ID + m_modelHash + ".predict"
+ (m_learner.producesProbabilities(m_learnerOpts) ? "_proba" : "") + "(X)\n");
predictScript.append("if type(preds) in (dask_cudf.Series, dask_cudf.DataFrame):\n");
Expand Down

0 comments on commit 5cce366

Please sign in to comment.