Skip to content

Commit

Permalink
fixed some issues and added test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
manujosephv committed Nov 24, 2024
1 parent f2c2780 commit be9c563
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 6 deletions.
35 changes: 30 additions & 5 deletions src/pytorch_tabular/tabular_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1832,7 +1832,20 @@ def ret_summary(self, model=None, max_depth: int = -1) -> str:

def __str__(self) -> str:
    """Return a short, human-readable summary of the TabularModel object.

    The summary has the form ``ClassName(model=<name>)``. When a model has
    been built, ``<name>`` is the class name of ``self.model``; otherwise it
    is the model name from the config followed by ``(Not Initialized)``.
    """
    if self.has_model:
        model_name = self.model.__class__.__name__
    else:
        # No model object exists yet, so report the configured model name
        # and flag that it has not been instantiated.
        model_name = self.config._model_name + "(Not Initialized)"
    return f"{self.__class__.__name__}(model={model_name})"

def __repr__(self) -> str:
    """Return a detailed, multi-line representation of the TabularModel object.

    The output lists the model (its class name when a model has been built,
    otherwise the configured model name marked ``(Not Initialized)``) followed
    by the full config rendered as indented JSON, wrapped in
    ``ClassName( ... )``.
    """
    # Resolve interpolations and convert the config to plain containers so
    # that it can be serialised with ``json.dumps``.
    config_str = json.dumps(
        OmegaConf.to_container(self.config, resolve=True), indent=4
    )
    if self.has_model:
        model_str = self.model.__class__.__name__
    else:
        model_str = f"{self.config._model_name} (Not Initialized)"
    return (
        f"{self.__class__.__name__}(\n"
        f" model={model_str},\n"
        f" config={config_str},\n"
        # Close the parenthesis opened on the first line so the
        # representation is balanced.
        ")"
    )

def _repr_html_(self):
"""Generate an HTML representation for Jupyter Notebook."""
Expand Down Expand Up @@ -1912,10 +1925,18 @@ def _repr_html_(self):
header_html = f"<div class='header'>{html.escape(self.model.__class__.__name__ if self.has_model else self.config._model_name)}{model_status}</div>"

# Config Section
config_html = self._generate_collapsible_section("Model Config", self.config, uid=uid, is_dict=True)
config_html = self._generate_collapsible_section(
"Model Config", self.config, uid=uid, is_dict=True
)

# Summary Section
summary_html = "" if not self.has_model else self._generate_collapsible_section("Model Summary", self._generate_model_summary_table(), uid=uid)
summary_html = (
""
if not self.has_model
else self._generate_collapsible_section(
"Model Summary", self._generate_model_summary_table(), uid=uid
)
)

# Combine sections
return f"""
Expand All @@ -1930,7 +1951,9 @@ def _repr_html_(self):
def _generate_collapsible_section(self, title, content, uid, is_dict=False):
container_id = title.lower().replace(" ", "_") + uid
if is_dict:
content = self._generate_nested_collapsible_sections(OmegaConf.to_container(content, resolve=True), container_id)
content = self._generate_nested_collapsible_sections(
OmegaConf.to_container(content, resolve=True), container_id
)
return f"""
<div>
<span class="toggle-button" onclick="toggleVisibility('{container_id}')">&#9654;</span>
Expand All @@ -1947,7 +1970,9 @@ def _generate_nested_collapsible_sections(self, content, parent_id):
if isinstance(value, dict):
nested_id = f"{parent_id}_{key}".replace(" ", "_")
nested_id = nested_id + str(uuid.uuid4())
nested_content = self._generate_nested_collapsible_sections(value, nested_id)
nested_content = self._generate_nested_collapsible_sections(
value, nested_id
)
html_content += f"""
<div>
<span class="toggle-button" onclick="toggleVisibility('{nested_id}')">&#9654;</span>
Expand Down
62 changes: 61 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

# Parametrization cases: each entry is a (model config class, constructor kwargs) pair.
# Consumers unpack the pair and build the config with `config_class(**kwargs)`.
# NOTE(review): the dicts are shared across test cases; treat them as read-only.
MODEL_CONFIG_SAVE_TEST = [
(CategoryEmbeddingModelConfig, {"layers": "10-20"}),
(AutoIntConfig, {"num_heads": 1, "num_attn_blocks": 1}),
(GANDALFConfig, {}),
(NodeConfig, {"num_trees": 100, "depth": 2}),
(TabNetModelConfig, {"n_a": 2, "n_d": 2}),
]
Expand Down Expand Up @@ -1247,3 +1247,63 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
# # there may be multiple models with the same score
# best_models = comp_df.loc[comp_df[f"test_{rank_metric[0]}"] == best_score, "model"].values.tolist()
# assert best_model.model._get_name() in best_models

@pytest.mark.parametrize("model_config_class", MODEL_CONFIG_SAVE_TEST)
@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])
@pytest.mark.parametrize("custom_metrics", [None, [fake_metric]])
@pytest.mark.parametrize("custom_loss", [None, torch.nn.L1Loss()])
@pytest.mark.parametrize("custom_optimizer", [None, torch.optim.Adagrad, "SGD", "torch_optimizer.AdaBound"])
def test_str_repr(
    regression_data,
    model_config_class,
    continuous_cols,
    categorical_cols,
    custom_metrics,
    custom_loss,
    custom_optimizer,
):
    """Check ``str``, ``repr`` and ``_repr_html_`` of a TabularModel before and after ``fit``.

    Before fitting, the text representations must flag the model as
    "Not Initialized" and the HTML must omit the "Model Summary" section.
    After fitting, all representations must mention the model name and the
    HTML must contain both the config and the summary sections.
    """
    train, _, target = regression_data
    data_config = DataConfig(
        target=target,
        continuous_cols=continuous_cols,
        categorical_cols=categorical_cols,
    )
    model_config_class, model_config_params = model_config_class
    # Build a fresh kwargs dict instead of writing "task" into the dict that
    # lives in the shared, module-level parametrization list.
    model_config = model_config_class(**{**model_config_params, "task": "regression"})
    trainer_config = TrainerConfig(
        max_epochs=3,
        checkpoints=None,
        early_stopping=None,
        accelerator="cpu",
        fast_dev_run=True,
    )
    optimizer_config = OptimizerConfig()

    tabular_model = TabularModel(
        data_config=data_config,
        model_config=model_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
    )

    # --- Before fit: no model object exists yet. ---
    html_before = tabular_model._repr_html_()
    assert "Not Initialized" in str(tabular_model)
    assert "Not Initialized" in repr(tabular_model)
    assert "Model Summary" not in html_before
    assert "Model Config" in html_before
    assert "config" in repr(tabular_model)
    assert "config" not in str(tabular_model)

    tabular_model.fit(
        train=train,
        metrics=custom_metrics,
        metrics_prob_inputs=None if custom_metrics is None else [False],
        loss=custom_loss,
        optimizer=custom_optimizer,
        optimizer_params={},
    )

    # --- After fit: the model is built and should show up everywhere. ---
    html_after = tabular_model._repr_html_()
    assert model_config_class._model_name in str(tabular_model)
    assert model_config_class._model_name in repr(tabular_model)
    assert "Model Summary" in html_after
    assert "Model Config" in html_after
    assert "config" in repr(tabular_model)
    assert model_config_class._model_name in html_after

0 comments on commit be9c563

Please sign in to comment.