Skip to content

Commit

Permalink
minor changes, added documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
sidps committed Jul 6, 2015
1 parent 6978ab3 commit a10827c
Showing 1 changed file with 30 additions and 15 deletions.
45 changes: 30 additions & 15 deletions utils/network_repr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import lasagne
from lasagne.layers import get_all_layers
from copy import deepcopy
from collections import deque, defaultdict


Expand All @@ -14,7 +13,7 @@ def get_network_str(layer, get_network=True, incomings=False, outgoings=False):
into it, or a list of :class:`Layer` instances.
get_network : boolean
if True, calls `get_all_layers` on `layer` with param `treat_as_input`
if True, calls `get_all_layers` on `layer`
if False, assumes `layer` already contains all `Layer` instances intended for representation
incomings : boolean
Expand All @@ -26,48 +25,60 @@ def get_network_str(layer, get_network=True, incomings=False, outgoings=False):
Returns
-------
str
A string representation of `layer`
A string representation of `layer`. Each layer is assigned an ID which is it's corresponding index
in the list obtained from `get_all_layers`.
"""

# `layer` can either be a single `Layer` instance or a list of `Layer` instances.
# If list, it can already be the result from `get_all_layers` or not, indicated by the `get_network` flag
# Get network using get_all_layers if required:
if get_network:
network = get_all_layers(layer)
else:
network = deepcopy(layer)
network = layer

# Initialize a list of lists to (temporarily) hold the str representation of each component, insert header
network_str = deque([])
network_str = _insert_header(network_str, incomings=incomings, outgoings=outgoings)

# The representation can optionally display incoming and outgoing layers for each layer, similar to adjacency lists.
# If requested (using the incomings and outgoings flags), build the adjacency lists.
# The numbers/ids in the adjacency lists correspond to the layer's index in `network`
if incomings or outgoings:
ins, outs = _get_adjacency_lists(network)

# For each layer in the network, build a representation and append to `network_str`
for i, current_layer in enumerate(network):

# Initialize list to (temporarily) hold str of layer
layer_str = deque([])

# First column for incomings, second for the layer itself, third for outgoings, fourth for layer description
if incomings:
layer_str.append(ins[i])
layer_str.append(i)
if outgoings:
layer_str.append(outs[i])
if type(current_layer).__str__ is current_layer.__str__:
layer_str.append(str(current_layer))
else:
layer_str.append(type(current_layer).__name__)
layer_str.append(str(current_layer)) # default representation can be changed by overriding __str__
network_str.append(layer_str)
return _get_table_str(network_str)


def _insert_header(network_str, incomings, outgoings):
""" Insert the first two lines in the representation."""
""" Insert the header (first two lines) in the representation."""
line_1 = deque([])
if incomings:
line_1.append('Incomings -->')
line_1.append('In -->')
line_1.append('Layer')
if outgoings:
line_1.append('--> Outgoings')
line_1.append('--> Out')
line_1.append('Description')
line_2 = deque([])
if incomings:
line_2.append('--------- ')
line_2.append('-------')
line_2.append('-----')
if outgoings:
line_2.append(' ---------')
line_2.append('-------')
line_2.append('-----------')
network_str.appendleft(line_2)
network_str.appendleft(line_1)
Expand All @@ -77,6 +88,8 @@ def _insert_header(network_str, incomings, outgoings):
def _get_adjacency_lists(network):
""" Returns adjacency lists for each layer (node) in network.
Warning: Assumes repr is unique to a layer instance, else this entire approach WILL fail."""
# ins is a dict, keys are layer indices and values are lists of incoming layer indices
# outs is a dict, keys are layer indices and values are lists of outgoing layer indices
ins = defaultdict(list)
outs = defaultdict(list)
lookup = {repr(layer): index for index, layer in enumerate(network)}
Expand All @@ -97,12 +110,12 @@ def _get_adjacency_lists(network):


def _get_table_str(table):
""" Pretty print a table provided as a list of rows."""
""" Pretty print a table provided as a list of lists."""
table_str = ''
col_size = [max(len(str(val)) for val in column) for column in zip(*table)]
for line in table:
table_str += '\n'
table_str += ' '.join('{0:^{1}}'.format(val, col_size[i]) for i, val in enumerate(line))
table_str += ' '.join('{0:<{1}}'.format(val, col_size[i]) for i, val in enumerate(line))
return table_str


Expand Down Expand Up @@ -140,9 +153,11 @@ def example2():


def main():
print('===========================================================')
example1()
print('===========================================================')
example2()
print('===========================================================')
return None

if __name__ == '__main__':
Expand Down

0 comments on commit a10827c

Please sign in to comment.