Skip to content

Commit

Permalink
Merge pull request #509 from WenjieDu/(refactor)black_format
Browse files Browse the repository at this point in the history
Apply line-length=120 to black format, and update pre-commit config
  • Loading branch information
WenjieDu authored Sep 12, 2024
2 parents fdd3d32 + 4885f62 commit e20ede8
Show file tree
Hide file tree
Showing 165 changed files with 696 additions and 2,103 deletions.
11 changes: 7 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
repos:
# hooks for checking files
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml

# hooks for linting code
- repo: https://github.com/psf/black
rev: 22.10.0
rev: 24.8.0
hooks:
- id: black
args: [
--line-length=120, # refer to pyproject.toml
]

- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
rev: 7.1.1
hooks:
- id: flake8
args: [
--max-line-length=120, # refer to pyproject.toml
--extend-ignore=E203, # why ignore E203? Refer to https://github.com/PyCQA/pycodestyle/issues/373
--extend-ignore=E203,E231
]
3 changes: 1 addition & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@
html_context["READTHEDOCS"] = True

html_favicon = (
"https://raw.githubusercontent.com/"
"PyPOTS/pypots.github.io/main/static/figs/pypots_logos/PyPOTS/logo_FFBG.svg"
"https://raw.githubusercontent.com/PyPOTS/pypots.github.io/main/static/figs/pypots_logos/PyPOTS/logo_FFBG.svg"
)

html_sidebars = {
Expand Down
24 changes: 6 additions & 18 deletions pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ def _setup_device(self, device: Union[None, str, torch.device, list]) -> None:
self.device = device
elif isinstance(device, list):
if len(device) == 0:
raise ValueError(
"The list of devices should have at least 1 device, but got 0."
)
raise ValueError("The list of devices should have at least 1 device, but got 0.")
elif len(device) == 1:
return self._setup_device(device[0])
# parallely training on multiple CUDA devices
Expand Down Expand Up @@ -179,18 +177,14 @@ def _setup_path(self, saving_path) -> None:
logger.info(f"Model files will be saved to {self.saving_path}")
logger.info(f"Tensorboard file will be saved to {tb_saving_path}")
else:
logger.warning(
"‼️ saving_path not given. Model files and tensorboard file will not be saved."
)
logger.warning("‼️ saving_path not given. Model files and tensorboard file will not be saved.")

def _send_model_to_given_device(self) -> None:
if isinstance(self.device, list):
# parallely training on multiple devices
self.model = torch.nn.DataParallel(self.model, device_ids=self.device)
self.model = self.model.cuda()
logger.info(
f"Model has been allocated to the given multiple devices: {self.device}"
)
logger.info(f"Model has been allocated to the given multiple devices: {self.device}")
else:
self.model = self.model.to(self.device)

Expand Down Expand Up @@ -291,9 +285,7 @@ def save(

if os.path.exists(saving_path):
if overwrite:
logger.warning(
f"‼️ File {saving_path} exists. Argument `overwrite` is True. Overwriting now..."
)
logger.warning(f"‼️ File {saving_path} exists. Argument `overwrite` is True. Overwriting now...")
else:
logger.error(
f"❌ File {saving_path} exists. Saving operation aborted. "
Expand All @@ -309,9 +301,7 @@ def save(
torch.save(self.model, saving_path)
logger.info(f"Saved the model to {saving_path}")
except Exception as e:
raise RuntimeError(
f'Failed to save the model to "{saving_path}" because of the below error! \n{e}'
)
raise RuntimeError(f'Failed to save the model to "{saving_path}" because of the below error! \n{e}')

def load(self, path: str) -> None:
"""Load the saved model from a disk file.
Expand Down Expand Up @@ -519,9 +509,7 @@ def __init__(

def _print_model_size(self) -> None:
"""Print the number of trainable parameters in the initialized NN model."""
self.num_params = sum(
p.numel() for p in self.model.parameters() if p.requires_grad
)
self.num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
logger.info(
f"{self.__class__.__name__} initialized with the given hyperparameters, "
f"the number of trainable parameters: {self.num_params:,}"
Expand Down
20 changes: 5 additions & 15 deletions pypots/classification/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,7 @@ def _train_model(
for idx, data in enumerate(val_loader):
inputs = self._assemble_input_for_validating(data)
results = self.model.forward(inputs)
epoch_val_loss_collector.append(
results["loss"].sum().item()
)
epoch_val_loss_collector.append(results["loss"].sum().item())

mean_val_loss = np.mean(epoch_val_loss_collector)

Expand All @@ -333,15 +331,11 @@ def _train_model(
)
mean_loss = mean_val_loss
else:
logger.info(
f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}"
)
logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}")
mean_loss = mean_train_loss

if np.isnan(mean_loss):
logger.warning(
f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
)
logger.warning(f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors.")

if mean_loss < self.best_loss:
self.best_epoch = epoch
Expand All @@ -363,9 +357,7 @@ def _train_model(
nni.report_final_result(self.best_loss)

if self.patience == 0:
logger.info(
"Exceeded the training patience. Terminating the training procedure..."
)
logger.info("Exceeded the training patience. Terminating the training procedure...")
break

except KeyboardInterrupt: # if keyboard interrupt, only warning
Expand All @@ -386,9 +378,7 @@ def _train_model(
if np.isnan(self.best_loss):
raise ValueError("Something is wrong. best_loss is Nan after training.")

logger.info(
f"Finished training. The best model is from epoch#{self.best_epoch}."
)
logger.info(f"Finished training. The best model is from epoch#{self.best_epoch}.")

@abstractmethod
def fit(
Expand Down
8 changes: 2 additions & 6 deletions pypots/classification/grud/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,15 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
empirical_mean = inputs["empirical_mean"]
X_filledLOCF = inputs["X_filledLOCF"]

_, hidden_state = self.model(
X, missing_mask, deltas, empirical_mean, X_filledLOCF
)
_, hidden_state = self.model(X, missing_mask, deltas, empirical_mean, X_filledLOCF)

logits = self.classifier(hidden_state)
classification_pred = torch.softmax(logits, dim=1)
results = {"classification_pred": classification_pred}

# if in training mode, return results with losses
if training:
classification_loss = F.nll_loss(
torch.log(classification_pred), inputs["label"]
)
classification_loss = F.nll_loss(torch.log(classification_pred), inputs["label"])
results["loss"] = classification_loss

return results
10 changes: 4 additions & 6 deletions pypots/classification/grud/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def __init__(
self.X_filledLOCF = locf_torch(self.X)
self.X = torch.nan_to_num(self.X)
self.deltas = _parse_delta_torch(self.missing_mask)
self.empirical_mean = torch.sum(
self.missing_mask * self.X, dim=[0, 1]
) / torch.sum(self.missing_mask, dim=[0, 1])
self.empirical_mean = torch.sum(self.missing_mask * self.X, dim=[0, 1]) / torch.sum(
self.missing_mask, dim=[0, 1]
)
# fill nan with 0, in case some features have no observations
self.empirical_mean = torch.nan_to_num(self.empirical_mean, 0)

Expand Down Expand Up @@ -134,9 +134,7 @@ def _fetch_data_from_file(self, idx: int) -> Iterable:
X_filledLOCF = locf_torch(X.unsqueeze(dim=0)).squeeze()
X = torch.nan_to_num(X)
deltas = _parse_delta_torch(missing_mask)
empirical_mean = torch.sum(missing_mask * X, dim=[0]) / torch.sum(
missing_mask, dim=[0]
)
empirical_mean = torch.sum(missing_mask * X, dim=[0]) / torch.sum(missing_mask, dim=[0])

sample = [
torch.tensor(idx),
Expand Down
21 changes: 5 additions & 16 deletions pypots/classification/raindrop/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
and takes over the forward progress of the algorithm.
"""


# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

Expand Down Expand Up @@ -84,21 +83,13 @@ def forward(self, inputs, training=True):
lengths2 = lengths.unsqueeze(1).to(device)
mask2 = mask.permute(1, 0).unsqueeze(2).long()
if self.sensor_wise_mask:
output = torch.zeros(
[batch_size, self.n_features, self.d_ob + 16], device=device
)
output = torch.zeros([batch_size, self.n_features, self.d_ob + 16], device=device)
extended_missing_mask = missing_mask.view(-1, batch_size, self.n_features)
for se in range(self.n_features):
representation = representation.view(
-1, batch_size, self.n_features, (self.d_ob + 16)
)
representation = representation.view(-1, batch_size, self.n_features, (self.d_ob + 16))
out = representation[:, :, se, :]
l_ = torch.sum(extended_missing_mask[:, :, se], dim=0).unsqueeze(
1
) # length
out_sensor = torch.sum(
out * (1 - extended_missing_mask[:, :, se].unsqueeze(-1)), dim=0
) / (l_ + 1)
l_ = torch.sum(extended_missing_mask[:, :, se], dim=0).unsqueeze(1) # length
out_sensor = torch.sum(out * (1 - extended_missing_mask[:, :, se].unsqueeze(-1)), dim=0) / (l_ + 1)
output[:, se, :] = out_sensor
output = output.view([-1, self.n_features * (self.d_ob + 16)])
elif self.aggregation == "mean":
Expand All @@ -116,9 +107,7 @@ def forward(self, inputs, training=True):

# if in training mode, return results with losses
if training:
classification_loss = F.nll_loss(
torch.log(classification_pred), inputs["label"]
)
classification_loss = F.nll_loss(torch.log(classification_pred), inputs["label"])
results["loss"] = classification_loss

return results
1 change: 0 additions & 1 deletion pypots/classification/raindrop/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""


# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

Expand Down
17 changes: 5 additions & 12 deletions pypots/cli/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,9 @@ def checkup(self):
)

if self._cleanup:
assert not self._run_tests and not self._lint_code, (
"Argument `--cleanup` should be used alone. "
"Try `pypots-cli dev --cleanup`"
)
assert (
not self._run_tests and not self._lint_code
), "Argument `--cleanup` should be used alone. Try `pypots-cli dev --cleanup`"

def run(self):
"""Execute the given command."""
Expand All @@ -149,14 +148,8 @@ def run(self):
elif self._build:
self.execute_command("python -m build")
elif self._run_tests:
pytest_command = (
f"pytest -k {self._k}" if self._k is not None else "pytest"
)
command_to_run_test = (
f"coverage run -m {pytest_command}"
if self._show_coverage
else pytest_command
)
pytest_command = f"pytest -k {self._k}" if self._k is not None else "pytest"
command_to_run_test = f"coverage run -m {pytest_command}" if self._show_coverage else pytest_command
self.execute_command(command_to_run_test)
if self._show_coverage and os.path.exists(".coverage"):
self.execute_command("coverage report -m")
Expand Down
31 changes: 9 additions & 22 deletions pypots/cli/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def doc_command_factory(args: Namespace):


def purge_temp_files():
logger.info(
f"Directories _build and {CLONED_LATEST_PYPOTS} will be deleted if exist"
)
logger.info(f"Directories _build and {CLONED_LATEST_PYPOTS} will be deleted if exist")
shutil.rmtree("docs/_build", ignore_errors=True)
shutil.rmtree(CLONED_LATEST_PYPOTS, ignore_errors=True)

Expand Down Expand Up @@ -148,10 +146,9 @@ def checkup(self):
self.check_if_under_root_dir(strict=True)

if self._cleanup:
assert not self._gene_rst and not self._gene_html and not self._view_doc, (
"Argument `--cleanup` should be used alone. "
"Try `pypots-cli doc --cleanup`"
)
assert (
not self._gene_rst and not self._gene_html and not self._view_doc
), "Argument `--cleanup` should be used alone. Try `pypots-cli doc --cleanup`"

def run(self):
"""Execute the given command."""
Expand All @@ -166,9 +163,7 @@ def run(self):

if self._gene_rst:
if os.path.exists(CLONED_LATEST_PYPOTS):
logger.info(
f"Directory {CLONED_LATEST_PYPOTS} exists, deleting it..."
)
logger.info(f"Directory {CLONED_LATEST_PYPOTS} exists, deleting it...")
shutil.rmtree(CLONED_LATEST_PYPOTS, ignore_errors=True)

# Download the latest code from GitHub
Expand All @@ -185,18 +180,12 @@ def run(self):
for f_ in files_to_move:
shutil.move(os.path.join(code_dir, f_), destination_dir)
# delete code in tests because we don't need its doc
shutil.rmtree(
f"{CLONED_LATEST_PYPOTS}/pypots/tests", ignore_errors=True
)
shutil.rmtree(f"{CLONED_LATEST_PYPOTS}/pypots/tests", ignore_errors=True)

# Generate the docs according to the cloned code
logger.info("Generating rst files...")
os.environ[
"SPHINX_APIDOC_OPTIONS"
] = "members,undoc-members,show-inheritance,inherited-members"
self.execute_command(
f"sphinx-apidoc {CLONED_LATEST_PYPOTS} -o {CLONED_LATEST_PYPOTS}/rst"
)
os.environ["SPHINX_APIDOC_OPTIONS"] = "members,undoc-members,show-inheritance,inherited-members"
self.execute_command(f"sphinx-apidoc {CLONED_LATEST_PYPOTS} -o {CLONED_LATEST_PYPOTS}/rst")

# Only save the files we need.
logger.info("Updating the old documentation...")
Expand All @@ -217,9 +206,7 @@ def run(self):
"docs/_build/html"
), "docs/_build/html does not exists, please run `pypots-cli doc --gene_html` first"
logger.info(f"Deploying HTML to http://127.0.0.1:{self._port}...")
self.execute_command(
f"python -m http.server {self._port} -d docs/_build/html -b 127.0.0.1"
)
self.execute_command(f"python -m http.server {self._port} -d docs/_build/html -b 127.0.0.1")

except ImportError:
raise ImportError(IMPORT_ERROR_MESSAGE)
Expand Down
8 changes: 2 additions & 6 deletions pypots/cli/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,14 @@ def run(self):
# run checks first
self.checkup()

logger.info(
f"Installing the dependencies in scope `{self._install}` for you..."
)
logger.info(f"Installing the dependencies in scope `{self._install}` for you...")

if self._tool == "conda":
assert (
self.execute_command("which conda").returncode == 0
), "Conda not installed, cannot set --tool=conda, please check your conda."

self.execute_command(
"conda install pyg pytorch-scatter pytorch-sparse -c pyg"
)
self.execute_command("conda install pyg pytorch-scatter pytorch-sparse -c pyg")

else: # self._tool == "pip"
torch_version = torch.__version__
Expand Down
4 changes: 1 addition & 3 deletions pypots/cli/pypots_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@


def main():
parser = ArgumentParser(
"PyPOTS Command-Line-Interface tool", usage="pypots-cli <command> [<args>]"
)
parser = ArgumentParser("PyPOTS Command-Line-Interface tool", usage="pypots-cli <command> [<args>]")
commands_parser = parser.add_subparsers(help="pypots-cli command helpers")

# Register commands here
Expand Down
Loading

0 comments on commit e20ede8

Please sign in to comment.