Skip to content

Commit

Permalink
Fix extra summary column when show_trainable=True (#435)
Browse files Browse the repository at this point in the history
Silly bug, we were literally just adding the trainable field twice
  • Loading branch information
mattdangerw authored Jul 10, 2023
1 parent e65d025 commit c54f806
Showing 1 changed file with 22 additions and 24 deletions.
46 changes: 22 additions & 24 deletions keras_core/utils/summary_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,26 @@ def print_summary(

table = rich.table.Table(*columns, width=line_length, show_lines=True)

def get_connections(layer):
connections = ""
for node in layer._inbound_nodes:
if relevant_nodes and node not in relevant_nodes:
# node is not part of the current network
continue
for kt in node.input_tensors:
keras_history = kt._keras_history
inbound_layer = keras_history.operation
node_index = highlight_number(keras_history.node_index)
tensor_index = highlight_number(keras_history.tensor_index)
if connections:
connections += ", "
connections += (
f"{inbound_layer.name}[{node_index}][{tensor_index}]"
)
if not connections:
connections = "-"
return connections

def get_layer_fields(layer, prefix=""):
output_shape = format_layer_shape(layer)
name = prefix + layer.name
Expand All @@ -230,6 +250,8 @@ def get_layer_fields(layer, prefix=""):
params = highlight_number(f"{layer.count_params():,}")

fields = [name, output_shape, params]
if not sequential_like:
fields.append(get_connections(layer))
if show_trainable:
fields.append(
bold_text("Y", color=34)
Expand All @@ -238,37 +260,13 @@ def get_layer_fields(layer, prefix=""):
)
return fields

def get_connections(layer):
connections = ""
for node in layer._inbound_nodes:
if relevant_nodes and node not in relevant_nodes:
# node is not part of the current network
continue
for kt in node.input_tensors:
keras_history = kt._keras_history
inbound_layer = keras_history.operation
node_index = highlight_number(keras_history.node_index)
tensor_index = highlight_number(keras_history.tensor_index)
if connections:
connections += ", "
connections += (
f"{inbound_layer.name}[{node_index}][{tensor_index}]"
)
if not connections:
connections = "-"
return connections

def print_layer(layer, nested_level=0):
if nested_level:
prefix = " " * nested_level + "└" + " "
else:
prefix = ""

fields = get_layer_fields(layer, prefix=prefix)
if not sequential_like:
fields.append(get_connections(layer))
if show_trainable:
fields.append("Y" if layer.trainable else "N")

rows = [fields]
if expand_nested and hasattr(layer, "layers") and layer.layers:
Expand Down

0 comments on commit c54f806

Please sign in to comment.