Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix invalid hub #317

Merged
merged 5 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions nlptest/nlptest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import pickle
from collections import defaultdict
from typing import Optional, Union
from typing import Optional, Union, Any

import pandas as pd
import yaml
Expand Down Expand Up @@ -35,7 +35,7 @@ class Harness:

def __init__(
self,
model: Union[str],
model: Union[str, Any],
task: Optional[str] = "ner",
hub: Optional[str] = None,
data: Optional[str] = None,
Expand All @@ -58,7 +58,14 @@ def __init__(
super().__init__()
self.task = task

if data is None and (task, model, hub) in self.DEFAULTS_DATASET.keys():
if isinstance(model, str) and hub is None:
raise ValueError(f"When passing a string argument to the 'model' parameter, you must provide an argument "
f"for the 'hub' parameter as well.")

if hub is not None and hub not in self.SUPPORTED_HUBS:
raise ValueError(f"Provided hub is not supported. Please choose one of the supported hubs: {self.SUPPORTED_HUBS}")

if data is None and (task, model, hub) in self.DEFAULTS_DATASET:
data_path = os.path.join("data", self.DEFAULTS_DATASET[(task, model, hub)])
data = resource_filename("nlptest", data_path)
self.data = DataFactory(data, task=self.task).load()
Expand All @@ -77,17 +84,14 @@ def __init__(
self.data = DataFactory(data, task=self.task).load() if data is not None else None

if isinstance(model, str):
if hub is None:
raise OSError(f"You need to pass the 'hub' parameter when passing a string as 'model'.")

self.model = ModelFactory.load_model(path=model, task=task, hub=hub)
else:
self.model = ModelFactory(task=task, model=model)

if config is not None:
self._config = self.configure(config)
else:
logging.info(f"No configuration file was provided, loading default config.")
logging.info("No configuration file was provided, loading default config.")
self._config = self.configure(resource_filename("nlptest", "data/config.yml"))

self._testcases = None
Expand Down
2 changes: 1 addition & 1 deletion tests/test_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_Harness(self):

def test_missing_parameter(self):
""""""
with self.assertRaises(OSError) as _:
with self.assertRaises(ValueError) as _:
Harness(task='ner', model='dslim/bert-base-NER',
data=self.data_path, config=self.config_path)

Expand Down