Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lugimzzz committed Feb 24, 2023
1 parent ff1d3c4 commit f0da8a1
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions tests/experimental/autonlp/test_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,21 @@ def test_multilabel(self, custom_model_candidate, hp_overrides):
)

# test predict
copy_test_ds = copy.deepcopy(self.multi_label_dev_ds)
test_output = auto_trainer.predict(test_dataset=copy_test_ds)
dev_output = auto_trainer.predict(test_dataset=copy_dev_ds)
self.assertEqual(
eval_metrics1[auto_trainer.metric_for_best_model],
test_output.metrics[auto_trainer.metric_for_best_model.replace("eval", "test")],
dev_output.metrics[auto_trainer.metric_for_best_model.replace("eval", "test")],
)
self.assertEqual(len(copy_dev_ds), len(dev_output.label_ids))
self.assertEqual(len(copy_dev_ds), len(dev_output.predictions))
self.assertEqual(len(auto_trainer.id2label), len(dev_output.predictions[0]))

copy_test_ds = copy.deepcopy(self.test_ds)
test_output = auto_trainer.predict(test_dataset=copy_test_ds)
self.assertFalse(auto_trainer.metric_for_best_model.replace("eval", "test") in test_output.metrics)
self.assertEqual(None, test_output.label_ids)
self.assertEqual(len(copy_test_ds), len(test_output.predictions))
self.assertEqual(len(auto_trainer.id2label), len(test_output.predictions[0]))
self.assertEqual(len(copy_test_ds), len(test_output.label_ids))

# test taskflow
taskflow = auto_trainer.to_taskflow()
Expand Down Expand Up @@ -371,12 +377,21 @@ def test_default_model_candidate(self, language, hp_overrides):
)

# test predict
copy_test_ds = copy.deepcopy(self.multi_class_dev_ds)
eval_metrics3 = auto_trainer.predict(test_dataset=copy_test_ds).metrics
dev_output = auto_trainer.predict(test_dataset=copy_dev_ds)
self.assertEqual(
eval_metrics1[auto_trainer.metric_for_best_model],
eval_metrics3[auto_trainer.metric_for_best_model.replace("eval", "test")],
dev_output.metrics[auto_trainer.metric_for_best_model.replace("eval", "test")],
)
self.assertEqual(len(copy_dev_ds), len(dev_output.label_ids))
self.assertEqual(len(copy_dev_ds), len(dev_output.predictions))
self.assertEqual(len(auto_trainer.id2label), len(dev_output.predictions[0]))

copy_test_ds = copy.deepcopy(self.test_ds)
test_output = auto_trainer.predict(test_dataset=copy_test_ds)
self.assertFalse(auto_trainer.metric_for_best_model.replace("eval", "test") in test_output.metrics)
self.assertEqual(None, test_output.label_ids)
self.assertEqual(len(copy_test_ds), len(test_output.predictions))
self.assertEqual(len(auto_trainer.id2label), len(test_output.predictions[0]))

# test export
temp_export_path = os.path.join(temp_dir_path, "test_export")
Expand Down

0 comments on commit f0da8a1

Please sign in to comment.