Skip to content

Commit

Permalink
LayerBase.fixup_out_data, prepare for other layer dict args
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 22, 2021
1 parent aba4232 commit bff669c
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion returnn/tf/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _post_init_output(cls, output, network, target=None, size_target=None, _targ
output.available_for_inference = False

@classmethod
def fixup_out_data(cls, output, network):
def fixup_out_data(cls, output, network, **_kwargs):
"""
This is called after get_out_data_from_opts, to fixup incomplete information.
E.g. we can patch batch or beam information here
Expand Down
6 changes: 3 additions & 3 deletions returnn/tf/layers/rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,13 +1250,13 @@ def add_templated_layer(lself, name, layer_class, **layer_desc):
layer_desc["network"] = self.net
old_layer_kwargs = layer_.kwargs
layer_.kwargs = layer_desc.copy() # set it now already for better debugging
if "output" not in layer_desc:
if "output" not in layer_.kwargs:
if old_layer_kwargs and "output" in old_layer_kwargs:
# First copy old output. Maybe the get_out_data_from_opts raises an exception,
# and we don't want this to be unset.
layer_.kwargs["output"] = old_layer_kwargs["output"]
layer_.kwargs["output"] = layer_class.get_out_data_from_opts(**layer_desc)
layer_.kwargs["output"] = layer_class.fixup_out_data(layer_.kwargs["output"], network=self.net)
layer_.kwargs["output"] = layer_class.fixup_out_data(**layer_.kwargs)
layer_.kwargs["output"].sanity_check(ignore_placeholder=True) # placeholder might be overwritten later
layer_.init(layer_class=layer_class, **layer_.kwargs)
if layer_.need_last:
Expand Down Expand Up @@ -1529,7 +1529,7 @@ def _add_template_layer(self, layer_name, layer_dict):
layer_class.transform_config_dict(
layer_dict, network=self.net, get_layer=lambda _name: self.layer_data_templates[_name])
out = layer_class.get_out_data_from_opts(name=layer_name, network=self.net, **layer_dict)
out = layer_class.fixup_out_data(output=out, network=self.net)
out = layer_class.fixup_out_data(output=out, network=self.net, **layer_dict)
layer.init(output=out, layer_class=layer_class, **layer_dict)
self.layer_data_templates[layer_name] = layer
return layer
Expand Down
2 changes: 1 addition & 1 deletion returnn/tf/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,7 +983,7 @@ def _create_layer(self, name, layer_class, **layer_desc):
output_template = layer_desc["output"]
assert isinstance(output_template, Data), "%s %r layer_desc %r ['output'] is not a Data instance" % (
layer_class.__name__, name, layer_desc)
output_template = layer_class.fixup_out_data(output_template, network=self)
output_template = layer_class.fixup_out_data(**layer_desc)
layer_desc["output"] = output_template
print(
"layer %s/%r output: %r" % (self.name, name, output_template),
Expand Down

0 comments on commit bff669c

Please sign in to comment.