Skip to content

Commit

Permalink
get_returnn_config_serialized
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Dec 22, 2021
1 parent 6b22c42 commit be78e0b
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 4 deletions.
129 changes: 127 additions & 2 deletions nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,16 +1347,39 @@ def get_returnn_config(self) -> Dict[str, Any]:
assert not self.parent, f"{self} get_returnn_config only makes sense in the root name ctx"
net_dict = self.make_net().make_net_dict_raw()
return {
"network": net_dict,
"behavior_version": _min_returnn_behavior_version,
"extern_data": {
data_key: {
key: getattr(data, key)
for key in [*data.get_kwargs(include_special_axes=False).keys(), "available_for_inference"]
if key not in {"name"}}
for (data_key, data) in self.extern_data.items()},
"behavior_version": _min_returnn_behavior_version,
"network": net_dict,
}

def get_returnn_config_serialized(self) -> str:
"""
:return: serialized config, i.e. Python code
"""
from ..utils.pprint import pformat
config = self.get_returnn_config()
dim_tags_proxy = ReturnnDimTagsProxy()
config = dim_tags_proxy.transform_config(config)

code_lines = [
"from returnn.tf.util.data import Dim, batch_dim, SpatialDim, FeatureDim\n\n",
"use_tensorflow = True\n",
f"behavior_version = {config.pop('behavior_version')}\n\n",
f"dim_tags = {config.pop('dim_tags')!r}\n\n",
f"extern_data = {pformat(config.pop('extern_data'))}\n",
f"network = {pformat(config.pop('network'))}\n",
]
if config:
for key, value in config.items():
code_lines.append(f"{key} = {pformat(value)}\n")
code_lines.append("\n")
return "".join(code_lines)

def make_net(self) -> Net:
"""
Create new (sub) net, an instance of :class:`Net`.
Expand Down Expand Up @@ -1696,3 +1719,105 @@ def _map_layer_dict_elem(value):
f" {layer_desc!r}") from exc

return out_data


class ReturnnDimTagsProxy:
"""
When serialized via __repr__, this represents a dict unique_name -> dim tag.
All usages in the network and extern_data will also get proxies when serialized point to this dict.
"""

def __init__(self):
self.dim_tags_by_name = {} # type: Dict[str, Dim]
self.dim_tags_to_ref = {} # type: Dict[Dim, ReturnnDimTagsProxy.DimRefProxy]

def __repr__(self):
def _dim_repr(dim: Dim) -> str:
# We assume batch_dim, FeatureDim, SpatialDim and Dim are imported.
if dim.kind == Dim.Types.Batch:
return "batch_dim"
if dim.kind == Dim.Types.Feature:
return f"FeatureDim({dim.description!r}, {dim.dimension})"
if dim.kind == Dim.Types.Spatial:
if dim.dimension is not None:
return f"SpatialDim({dim.description!r}, {dim.dimension})"
else:
return f"SpatialDim({dim.description!r})"
# generic fallback
return f"Dim(kind={dim.kind}, description={dim.description!r}, dimension={dim.dimension})"

return "\n".join([
"{",
*(f" {key!r}: {_dim_repr(value)}," for key, value in self.dim_tags_by_name.items()),
"}"])

class DimRefProxy:
"""
This will be a reference to the global dim_tags __repr__.
"""
def __init__(self, *, dim: Dim, name: str, path: Tuple[Any, ...], parent: ReturnnDimTagsProxy):
self.dim = dim
self.name = name
self.path = path
self.parent = parent

def __repr__(self):
return f"dim_tags[{self.name!r}]"

class SetProxy:
"""
This represents a set but with a predefined order.
"""
def __init__(self, values: Sequence[Any]):
self.values = values

def __repr__(self):
return f"{{{', '.join(map(repr, self.values))}}}"

def transform_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
"""
Go through the config and collect all dim tags, replace them by proxies.
:return: new config
"""
# Cannot use nest because nest does not support sets. Also nest requires them to be sorted.

def _sort_key(value):
if isinstance(value, ReturnnDimTagsProxy.DimRefProxy):
return value.name
return value

def _map(path, value):
if isinstance(value, Dim):
name = '.'.join(str(key) for key in path + (value.description,))
assert name not in self.dim_tags_by_name
if value in self.dim_tags_to_ref:
ref = self.dim_tags_to_ref[value]
if "out_shape" in ref.path and "out_shape" not in path:
# Prefer path without "out_shape". Use new name.
del self.dim_tags_by_name[ref.name]
self.dim_tags_by_name[name] = value
ref.name = name
ref.path = path
return ref
self.dim_tags_by_name[name] = value
ref = ReturnnDimTagsProxy.DimRefProxy(dim=value, name=name, path=path, parent=self)
self.dim_tags_to_ref[value] = ref
return ref
if isinstance(value, dict):
return {key: _map(path + (key,), value_) for key, value_ in value.items()}
if isinstance(value, list):
return [_map(path + (i,), value_) for i, value_ in enumerate(value)]
if isinstance(value, tuple):
return tuple(_map(path + (i,), value_) for i, value_ in enumerate(value))
if nest.is_namedtuple(value):
# noinspection PyProtectedMember
return type(value)(*(_map(path + (key,), getattr(value, key)) for key in value._fields))
if isinstance(value, set):
values = [_map(path + ('_',), value_) for value_ in value]
values.sort(key=_sort_key) # this should be possible now because it would be some sortable proxies
return ReturnnDimTagsProxy.SetProxy(values)
return value

config = _map((), config)
return {"dim_tags": self, **config}
9 changes: 7 additions & 2 deletions tests/test_models_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,14 @@ def _dummy_config_net_dict(net: nn.Module, *, with_axis=False):
assert isinstance(out, nn.Layer)
out.mark_as_default_output()

config = name_ctx.get_returnn_config()
config_code = name_ctx.get_returnn_config_serialized()
print(config_code)
scope = {}
exec(config_code, scope, scope)
for tmp in ["__builtins__", "Dim", "batch_dim", "FeatureDim", "SpatialDim"]:
scope.pop(tmp)
config = scope
net_dict = config["network"]
pprint(net_dict)
return config, net_dict


Expand Down

0 comments on commit be78e0b

Please sign in to comment.