Skip to content

Commit

Permalink
fixed some precommit errors
Browse files Browse the repository at this point in the history
  • Loading branch information
manujosephv committed Nov 25, 2024
1 parent 872b018 commit ed602ba
Showing 1 changed file with 42 additions and 11 deletions.
53 changes: 42 additions & 11 deletions src/pytorch_tabular/tabular_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,8 +1659,9 @@ def summary(self, model=None, max_depth: int = -1) -> None:
"""Prints a summary of the model.
Args:
max_depth (int): The maximum depth to traverse the modules and displayed in the summary.
Defaults to -1, which means will display all the modules.
max_depth (int): The maximum depth to traverse the modules and
displayed in the summary. Defaults to -1, which means will
display all the modules.
"""
if model is not None:
Expand All @@ -1682,8 +1683,9 @@ def ret_summary(self, model=None, max_depth: int = -1) -> str:
"""Returns a summary of the model as a string.
Args:
max_depth (int): The maximum depth to traverse the modules and displayed in the summary.
Defaults to -1, which means will display all the modules.
max_depth (int): The maximum depth to traverse the modules and
displayed in the summary. Defaults to -1, which means will
display all the modules.
Returns:
str: The summary of the model.
Expand All @@ -1699,12 +1701,13 @@ def ret_summary(self, model=None, max_depth: int = -1) -> str:
summary_str += "Config\n"
summary_str += "-" * 100 + "\n"
summary_str += pformat(self.config.__dict__["_content"], indent=4, width=80, compact=True)
summary_str += "\nFull Model Summary once model has been initialized or passed in as an argument"
summary_str += "\nFull Model Summary once model has been " "initialized or passed in as an argument"
return summary_str

def __str__(self) -> str:
"""Returns a readable summary of the TabularModel object."""
return f"{self.__class__.__name__}(model={self.model.__class__.__name__ if self.has_model else self.config._model_name+'(Not Initialized)'})"
model_name = self.model.__class__.__name__ if self.has_model else self.config._model_name + "(Not Initialized)"
return f"{self.__class__.__name__}(model={model_name})"

def __repr__(self) -> str:
"""Returns an unambiguous representation of the TabularModel object."""
Expand Down Expand Up @@ -1792,7 +1795,8 @@ def _repr_html_(self):
# Header (Main model name)
uid = str(uuid.uuid4())
model_status = "" if self.has_model else "(Not Initialized)"
header_html = f"<div class='header'>{html.escape(self.model.__class__.__name__ if self.has_model else self.config._model_name)}{model_status}</div>"
model_name = self.model.__class__.__name__ if self.has_model else self.config._model_name
header_html = f"<div class='header'>{html.escape(model_name)}{model_status}</div>"

# Config Section
config_html = self._generate_collapsible_section("Model Config", self.config, uid=uid, is_dict=True)
Expand Down Expand Up @@ -1822,7 +1826,12 @@ def _generate_collapsible_section(self, title, content, uid, is_dict=False):
)
return f"""
<div>
<span class="toggle-button" onclick="toggleVisibility('{container_id}')">&#9654;</span>
<span
class="toggle-button"
onclick="toggleVisibility('{container_id}')"
>
&#9654;
</span>
<strong>{html.escape(title)}</strong>
<div id="{container_id}" class="hidden section">
{content}
Expand All @@ -1839,7 +1848,12 @@ def _generate_nested_collapsible_sections(self, content, parent_id):
nested_content = self._generate_nested_collapsible_sections(value, nested_id)
html_content += f"""
<div>
<span class="toggle-button" onclick="toggleVisibility('{nested_id}')">&#9654;</span>
<span
class="toggle-button"
onclick="toggleVisibility('{nested_id}')"
>
&#9654;
</span>
<strong>{html.escape(key)}</strong>
<div id="{nested_id}" class="hidden section">
{nested_content}
Expand All @@ -1852,9 +1866,26 @@ def _generate_nested_collapsible_sections(self, content, parent_id):

def _generate_model_summary_table(self):
model_summary = summarize(self.model, max_depth=1)
table_html = "<table><tr><th><b>Layer</b></th><th><b>Type</b></th><th><b>Params</b></th><th><b>In sizes</b></th><th><b>Out sizes</b></th></tr>"
table_html = """
<table>
<tr>
<th><b>Layer</b></th>
<th><b>Type</b></th>
<th><b>Params</b></th>
<th><b>In sizes</b></th>
<th><b>Out sizes</b></th>
</tr>
"""
for name, layer in model_summary._layer_summary.items():
table_html += f"<tr><td>{html.escape(name)}</td><td>{html.escape(layer.layer_type)}</td><td>{html.escape(str(layer.num_parameters))}</td><td>{html.escape(str(layer.in_size))}</td><td>{html.escape(str(layer.out_size))}</td></tr>"
table_html += f"""
<tr>
<td>{html.escape(name)}</td>
<td>{html.escape(layer.layer_type)}</td>
<td>{html.escape(str(layer.num_parameters))}</td>
<td>{html.escape(str(layer.in_size))}</td>
<td>{html.escape(str(layer.out_size))}</td>
</tr>
"""
table_html += "</table>"
return table_html

Expand Down

0 comments on commit ed602ba

Please sign in to comment.