Fix #666, broadcast no longer matches dims #864

Merged
merged 4 commits
Dec 17, 2021
9 changes: 9 additions & 0 deletions docs/configuration_reference/behavior_version.rst
@@ -22,6 +22,15 @@ and not listing legacy/deprecated parameters.
Version History
---------------

Behavior version 11 (2021-12-16)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Broadcasting dims no longer match in :class:`CombineLayer` and others.
This was never needed; instead, RETURNN broadcasts automatically to non-existing dims.
To fix this, do not add any broadcasting dims.

See issue `#666 <https://github.com/rwth-i6/returnn/issues/666>`__.

Behavior version 10 (2021-12-07)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

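A hedged sketch of the migration this doc entry asks for, mirroring the updated test_CombineLayer_broadcast later in this diff (layer names are illustrative): under behavior version 11, do not feed a size-1 broadcasting dim into CombineLayer; drop the axis instead and let RETURNN broadcast over the missing dim.

# Old style, relying on broadcast dim matching; fails with behavior_version >= 11:
net_dict_old = {
  "lin1": {"class": "linear", "activation": "sigmoid", "n_out": 5, "from": "data:data"},
  "lin2": {"class": "linear", "activation": "sigmoid", "n_out": 1, "from": "data:data"},
  "combine": {"class": "combine", "kind": "add", "from": ["lin1", "lin2"]},
}

# New style: squeeze away the size-1 feature axis; RETURNN broadcasts automatically.
net_dict_new = {
  "lin1": {"class": "linear", "activation": "sigmoid", "n_out": 5, "from": "data:data"},
  "lin2": {"class": "linear", "activation": "sigmoid", "n_out": 1, "from": "data:data"},
  "lin2_squeeze": {"class": "squeeze", "from": "lin2", "axis": "f"},
  "combine": {"class": "combine", "kind": "add", "from": ["lin1", "lin2_squeeze"]},
}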
15 changes: 12 additions & 3 deletions returnn/tf/layers/basic.py
@@ -1282,7 +1282,10 @@ def _get_common_input_position_axes(cls, input_data, position_data, old_gather_axis):
:return: (common_axes_input, common_axes_position, specific_input_axes, specific_position_axes), all counted with
batch dim.
"""
-is_equal_opts = dict(allow_same_spatial_dim=True, broadcast_matches=True)
+from returnn.util import BehaviorVersion
+is_equal_opts = dict(allow_same_spatial_dim=True)
+if BehaviorVersion.get() < 11:
+  is_equal_opts["broadcast_matches"] = True
all_dim_tags, tags_dict = Dim.get_all_dimension_tags([input_data, position_data], is_equal_opts=is_equal_opts)
input_tags, pos_tags = tags_dict[input_data], tags_dict[position_data]
specific_input_axes = [i for i, tag in enumerate(input_tags) if tag not in pos_tags and i != old_gather_axis]
@@ -3771,7 +3774,10 @@ def get_out_data_from_opts(cls, name, axis, dim=1, sources=(), **kwargs):
data = data.copy_as_batch_major()
axis = cls._get_axis(data=data, axis=axis)

-new_dim = SpatialDim("%s_expand_dims" % name, dim)
+new_dim = Dim(
+  kind=Dim.Types.Feature if init_axis.lower() == "f" else Dim.Types.Spatial,
+  description="%s_expand_dims" % name,
+  dimension=dim)
data = data.copy_template(name="%s_output" % name)
data = data.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=axis)
if isinstance(init_axis, str):
@@ -6420,9 +6426,12 @@ def _auto_var_axes(source1, source2, red1, red2):
:return: var1 tags, var2 tags
:rtype: (list[Dim], list[Dim])
"""
+from returnn.util import BehaviorVersion
is_equal_opts = dict(
treat_feature_as_spatial=True, allow_same_spatial_dim=True,
-broadcast_matches=True, undefined_matches=True, derived_matches=True)
+undefined_matches=True, derived_matches=True)
+if BehaviorVersion.get() < 11:
+  is_equal_opts["broadcast_matches"] = True
all_dim_tags, tags_dict = Dim.get_all_dimension_tags([source1, source2], is_equal_opts=is_equal_opts)
tags1, tags2 = tags_dict[source1], tags_dict[source2]
var1 = [tag for i, tag in enumerate(tags1) if tag not in tags2 and i not in red1]
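All three call sites in this file follow the same version-gated pattern: build is_equal_opts without broadcast_matches, then re-add it only for configs pinned below version 11. A minimal sketch of that pattern in isolation (make_is_equal_opts is a hypothetical helper, not part of this diff):

from returnn.util import BehaviorVersion

def make_is_equal_opts(**base_opts):
  """Hypothetical helper: dim-tag matching options, gated on behavior version.
  Size-1 (broadcast) dims only match other dims for configs still below version 11."""
  opts = dict(base_opts)
  if BehaviorVersion.get() < 11:
    opts["broadcast_matches"] = True
  return opts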
5 changes: 4 additions & 1 deletion returnn/tf/util/data.py
@@ -5049,6 +5049,7 @@ def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_
This is always a template, and a new copy.
:rtype: Data|None
"""
+from returnn.util import BehaviorVersion
if not sources:
return None
assert sources
@@ -5067,8 +5068,10 @@
common.beam = SearchBeam.get_combined_beam(*[s.beam for s in sources])
is_equal_opts = dict(
ignore_feature_dim=ignore_feature_dim, treat_feature_as_spatial=True,
-allow_same_spatial_dim=True, broadcast_matches=True,
+allow_same_spatial_dim=True,
undefined_matches=True, derived_matches=True)
+if BehaviorVersion.get() < 11:
+  is_equal_opts["broadcast_matches"] = True
all_dim_tags, tags_dict = Dim.get_all_dimension_tags(sources, is_equal_opts=is_equal_opts)
# Check for potential undefined tags, and replace those with defined tags if possible.
for axis, dim_tag in enumerate(common.dim_tags):
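To see the effect at the Data level, a small sketch mirroring the updated test_Data_get_common_data_broadcast_multiple at the end of this diff: broadcasting is now expressed by omitting dims entirely, not by size-1 placeholder dims that would previously match anything.

from returnn.tf.util.data import Data

d1 = Data(name="d_orig", shape=(5, 5, 3), dtype="float32", batch_dim_axis=None)
d2 = Data(name="d_bc", shape=(5,), dtype="float32", batch_dim_axis=None)
common = Data.get_common_data([d1, d2])
assert common.shape == d1.shape  # d2 is broadcast over the two missing dims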
2 changes: 1 addition & 1 deletion returnn/util/basic.py
@@ -209,7 +209,7 @@ class BehaviorVersion:
The version will be set after the config is defined at __main__.init_config() or Engine.__init__()
"""

-_latest_behavior_version = 10
+_latest_behavior_version = 11
_behavior_version = None # type: typing.Optional[int]

@classmethod
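Bumping _latest_behavior_version makes version 11 available, but configs still opt in explicitly; a minimal sketch of the config side (assuming the standard behavior_version config option that behavior_version.rst documents):

# In the RETURNN config:
behavior_version = 11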
27 changes: 23 additions & 4 deletions tests/test_TFNetworkLayer.py
@@ -920,7 +920,8 @@ def test_CombineLayer_broadcast():
net_dict = {
"lin1": {"class": "linear", "activation": "sigmoid", "n_out": 5, "from": "data:data"},
"lin2": {"class": "linear", "activation": "sigmoid", "n_out": 1, "from": "data:data"},
"combine": {"class": "combine", "kind": "add", "from": ["lin1", "lin2"]},
"lin2_squeeze": {"class": "squeeze", "from": "lin2", "axis": "f"},
Zettelkasten marked this conversation as resolved.
Show resolved Hide resolved
"combine": {"class": "combine", "kind": "add", "from": ["lin1", "lin2_squeeze"]},
"output": {"class": "softmax", "loss": "ce", "from": "combine"}
}
config = Config({"debug_print_layer_output_template": True})
@@ -939,7 +940,7 @@ def test_CombineLayer_broadcast_multiple():
with make_scope() as session:
net_dict = {
"p1": {"class": "variable", "shape": (5, 5, 3), "add_batch_axis": False},
"p2": {"class": "variable", "shape": (5, 1, 1), "add_batch_axis": False},
"p2": {"class": "variable", "shape": (5,), "add_batch_axis": False},
"combine": {"class": "combine", "kind": "add", "from": ["p1", "p2"]},
"output": {"class": "softmax", "loss": "ce", "from": "combine"}
}
@@ -1275,7 +1276,7 @@ def test_CombineLayer_time_broadcast():
config = Config({
"debug_print_layer_output_template": True,
"extern_data": {
"in1": {"shape": (n_features, 1), "batch_dim_axis": None, "time_dim_axis": None, "feature_dim_axis": 0},
"in1": {"shape": (n_features,), "batch_dim_axis": None, "time_dim_axis": None, "feature_dim_axis": 0},
"in2": {"shape": (n_features, None), "batch_dim_axis": 0, "time_dim_axis": 2}
}
})
@@ -1299,7 +1300,7 @@ def test_CombineLayer_time_broadcast_swapped():
"debug_print_layer_output_template": True,
"extern_data": {
"in1": {"shape": (n_features, None), "batch_dim_axis": 0, "time_dim_axis": 2},
"in2": {"shape": (n_features, 1), "batch_dim_axis": None, "time_dim_axis": None, "feature_dim_axis": 0},
"in2": {"shape": (n_features,), "batch_dim_axis": None, "time_dim_axis": None, "feature_dim_axis": 0},
}
})
network = TFNetwork(config=config, train_flag=True)
@@ -3399,6 +3400,24 @@ def test_GatherLayer_search_beam():
"initial_output": 0}}}})


def test_GatherLayer_broadcast_dim():
from returnn.tf.util.data import batch_dim
head_dim = SpatialDim("head", 1) # previously, this dim would match all others and therefore fail.
round_dim = SpatialDim("round", 2)
chunk_dim = SpatialDim("chunk")
time_dim = SpatialDim("time")
config = Config({"extern_data": {
"source": {"dim_tags": [batch_dim, head_dim, time_dim]},
"position": {"dim_tags": [batch_dim, head_dim, round_dim, chunk_dim], "dtype": "int32"}},
"debug_print_layer_output_template": True})
net = TFNetwork(config=config)
net.construct_from_dict({
"output": {
'class': 'gather', 'from': 'data:source', 'position': 'data:position', 'axis': time_dim,
'out_shape': {batch_dim, head_dim, round_dim, chunk_dim}}
})


def test_SliceNdLayer():
n_batch = 5
n_time = 7
2 changes: 1 addition & 1 deletion tests/test_TFUtil.py
@@ -608,7 +608,7 @@ def test_Data_get_common_data_extra_static_spatial():

def test_Data_get_common_data_broadcast_multiple():
d1 = Data(name='d_orig', shape=(5, 5, 3), dtype='float32', batch_dim_axis=None)
-d2 = Data(name='d_bc', shape=(5, 1, 1), dtype='float32', batch_dim_axis=None)
+d2 = Data(name='d_bc', shape=(5,), dtype='float32', batch_dim_axis=None)
common = Data.get_common_data([d1, d2])
assert d1.shape == common.shape
