From 37704d3926df0ab8a9742d826fd2259e9a8b1b5d Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 27 Aug 2021 12:40:09 +0200 Subject: [PATCH 1/4] test_reclayer_optimize_out_dot_consistent_axes Test for #569 --- tests/test_TFNetworkRecLayer.py | 92 +++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/tests/test_TFNetworkRecLayer.py b/tests/test_TFNetworkRecLayer.py index 0489db566..8c3d4966d 100644 --- a/tests/test_TFNetworkRecLayer.py +++ b/tests/test_TFNetworkRecLayer.py @@ -3460,6 +3460,98 @@ def test_reclayer_optimize_out_dot(): rtol=1e-3) +def test_reclayer_optimize_out_dot_consistent_axes(): + # https://github.com/rwth-i6/returnn/issues/569 + # Used for multi-head dot-attention. + n_heads = 4 + n_key = 5 + n_value = 7 + n_key_total = n_heads * n_key + n_value_total = n_heads * n_value + check_reclayer_optimize_out( + {"class": "linear", "activation": None, "from": "att"}, + other_subnet_layers={ + "s": {"class": "linear", "activation": None, "with_bias": False, "from": "data:source", + "n_out": n_key_total}, # (B, D) -- Q (query). D should be same as enc_ctx + "att_query": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key), "from": "s"}, # (B, H, D/H) + # att_query is (T, B, H, D/H) outside the loop. + # Here is the main test, the dot-layer: + "energy": { + "class": "dot", + "red1": "static:-1", "red2": "static:-1", + "var1": "T", "var2": None, "add_var2_if_empty": False, + "from": ["base:enc_ctx", "att_query"]}, + # energy inside the loop will be (B, H, T). + # energy outside the loop should be (B, H, T, T). I.e. T is still the first time axis. + # The logic should be that the dot layer would add any extra axes (due to this optimization, moving layer out) + # to either common shared axes (which is implicit) or if it is only added to one input, + # then to the respective var axes. + # Note that in this test, there is no different encoder or decoder time dim. + # It still works due to time-dim-axis being set to the first source. + "att_weights": {"class": "softmax_over_spatial", "from": "energy"}, # (B, T, H) + "att0": {"class": "generic_attention", "weights": "att_weights", "base": "base:enc_value"}, # (B, H, V) + "att": {"class": "merge_dims", "axes": "static", "from": "att0"}, # (B, H*V); Use "static" here. + }, + shared_base_net={ + "encoder": {"class": "copy", "from": "data"}, + "enc_ctx0": {"class": "linear", "activation": None, "with_bias": False, "from": "encoder", + "n_out": n_key_total}, # (B, enc-T, D) + "enc_ctx": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key), + "from": "enc_ctx0", "is_output_layer": True}, # (B, enc-T, H, D/H) + "enc_value0": {"class": "linear", "activation": None, "with_bias": False, "from": "encoder", + "n_out": n_value_total}, + "enc_value": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_value), + "from": "enc_value0", "is_output_layer": True}, # (B, enc-T, H, D/H) + }, + rtol=1e-3) + + +def test_reclayer_optimize_out_dot_consistent_axes_enc_dec(): + # https://github.com/rwth-i6/returnn/issues/569 + # Used for multi-head dot-attention. + n_heads = 4 + n_key = 5 + n_value = 7 + n_key_total = n_heads * n_key + n_value_total = n_heads * n_value + check_reclayer_optimize_out( + {"class": "linear", "activation": None, "from": "att"}, + other_subnet_layers={ + "s": {"class": "linear", "activation": None, "with_bias": False, "from": "data:source", + "n_out": n_key_total}, # (B, D) -- Q (query). 
D should be same as enc_ctx + "att_query": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key), "from": "s"}, # (B, H, D/H) + # att_query is (dec-T, B, H, D/H) outside the loop. + # Here is the main test, the dot-layer: + "energy": { + "class": "dot", + "red1": "static:-1", "red2": "static:-1", + "var1": "T", "var2": None, "add_var2_if_empty": False, + "from": ["base:enc_ctx", "att_query"]}, + # energy inside the loop will be (B, H, enc-T). + # energy outside the loop should be (B, H, enc-T, dec-T). I.e. enc-T is still the first time axis. + # The logic should be that the dot layer would add any extra axes (due to this optimization, moving layer out) + # to either common shared axes (which is implicit) or if it is only added to one input, + # then to the respective var axes. + "att_weights": {"class": "softmax_over_spatial", "from": "energy"}, # (B, enc-T, H) + "att0": {"class": "generic_attention", "weights": "att_weights", "base": "base:enc_value"}, # (B, H, V) + "att": {"class": "merge_dims", "axes": "static", "from": "att0"}, # (B, H*V); Use "static" here. + }, + shared_base_net={ + # Use conv with padding valid to make sure we get another time dim, + # such that the rec part above will not confuse this time dim with the rec time dim. + "encoder": {"class": "conv", "from": "data", "filter_size": [3], "padding": "valid", "n_out": 5}, + "enc_ctx0": {"class": "linear", "activation": None, "with_bias": False, "from": "encoder", + "n_out": n_key_total}, # (B, enc-T, D) + "enc_ctx": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key), + "from": "enc_ctx0", "is_output_layer": True}, # (B, enc-T, H, D/H) + "enc_value0": {"class": "linear", "activation": None, "with_bias": False, "from": "encoder", + "n_out": n_value_total}, + "enc_value": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_value), + "from": "enc_value0", "is_output_layer": True}, # (B, enc-T, H, D/H) + }, + rtol=1e-3) + + def test_reclayer_optimize_out_dot_kv_in_rec(): # Same as test_reclayer_optimize_out_dot, but with the att key/value layers declared INSIDE the rec layer. 
AttNumHeads = 4 From 89ee2dab2ae4ec6a0381637eb488303252325c70 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 3 Sep 2021 23:48:33 +0200 Subject: [PATCH 2/4] Data.get_axes_from_description small fix for dim tag --- returnn/tf/util/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 636062eaa..7e3d933f8 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -2902,7 +2902,7 @@ def get_axes_from_description(self, axes, allow_int=True): if isinstance(i, int): flat_axes += [i] else: - assert isinstance(i, (str, tuple, list)) + assert isinstance(i, (str, tuple, list, DimensionTag)) flat_axes += self.get_axes_from_description(i) flat_axes = [i % self.batch_ndim for i in flat_axes] res = [] From 8e569779bef2d115fd234e8a8a182e3eaef33c54 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Sat, 4 Sep 2021 00:40:58 +0200 Subject: [PATCH 3/4] network get_all_rec_time_dims --- returnn/tf/layers/rec.py | 4 ++++ returnn/tf/network.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 1f0b29b69..a6cc5980a 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -3094,6 +3094,8 @@ def _construct_input_layers_moved_out(self): extern_data=ExternData(), train_flag=self.parent_net.train_flag, search_flag=self.parent_net.search_flag, + over_rec_time_dim=self.time_dim_tag, + over_rec_time_dim_subs=self._time_dim_tags, parent_layer=self.parent_rec_layer, parent_net=self.parent_net) self.input_layers_net.is_root_in_ctx = True @@ -3182,6 +3184,8 @@ def _construct_output_layers_moved_out(self, loop_accumulated, seq_len, extra_ou extern_data=ExternData(), train_flag=self.parent_net.train_flag, search_flag=self.parent_net.search_flag, + over_rec_time_dim=self.time_dim_tag, + over_rec_time_dim_subs=self._time_dim_tags, parent_layer=self.parent_rec_layer, parent_net=self.parent_net) self.output_layers_net.is_root_in_ctx = True diff --git a/returnn/tf/network.py b/returnn/tf/network.py index acb1e3881..f8c2b9f2f 100644 --- a/returnn/tf/network.py +++ b/returnn/tf/network.py @@ -353,7 +353,7 @@ class TFNetwork(object): def __init__(self, config=None, extern_data=None, rnd_seed=None, train_flag=None, eval_flag=None, search_flag=None, parent_layer=None, parent_net=None, extra_parent_net=None, extra_name_prefix=None, - inside_rec_time_dim=None, + inside_rec_time_dim=None, over_rec_time_dim=None, over_rec_time_dim_subs=None, absolute_name_prefix=None, name=None): """ :param returnn.config.Config config: only needed to init extern_data if not specified explicitly @@ -367,7 +367,9 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None, :param TFNetwork|None extra_parent_net: we are on the same level (not really a child), but an "extra" net of extra_parent_net :param str|None extra_name_prefix: - :param DimensionTag|None inside_rec_time_dim: + :param DimensionTag|None inside_rec_time_dim: dim tag of outer rec layer, when run inside the loop (not optimized) + :param DimensionTag|None over_rec_time_dim: dim tag of outer rec layer, when optimized out of the loop + :param set[DimensionTag]|None over_rec_time_dim_subs: outer rec layer, out of loop, potential shorter :param str|None absolute_name_prefix: :param str name: only for debugging """ @@ -428,6 +430,8 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None, self.parent_layer = parent_layer self.parent_net = parent_net 
self._inside_rec_time_dim = inside_rec_time_dim + self._over_rec_time_dim = over_rec_time_dim + self._over_rec_time_dim_subs = over_rec_time_dim_subs self.extra_parent_net = extra_parent_net self.extra_name_prefix = extra_name_prefix self.extra_deps_in_extra = False @@ -2173,6 +2177,10 @@ def get_inside_rec_time_dim(self, inside_loop=True): """ if self._inside_rec_time_dim: return self._inside_rec_time_dim + if self._over_rec_time_dim: + if inside_loop: + return None + return self._over_rec_time_dim from returnn.tf.layers.rec import RecLayer if isinstance(self.parent_layer, RecLayer): # When we get here (and not in the if-branch above on _inside_rec_time_dim), @@ -2186,6 +2194,23 @@ def get_inside_rec_time_dim(self, inside_loop=True): return self.parent_net.get_inside_rec_time_dim(inside_loop=inside_loop) return None + def get_all_rec_time_dims(self): + """ + :return: all rec time dims, moved out or not, including all parents + :rtype: set[DimensionTag] + """ + coll = set() + net = self + while net: + if net._inside_rec_time_dim: + coll.add(net._inside_rec_time_dim) + if net._over_rec_time_dim: + coll.add(net._over_rec_time_dim) + if net._over_rec_time_dim_subs: + coll.update(net._over_rec_time_dim_subs) + net = net.parent_net + return coll + def _is_rec_layer_inside_net(self): """ :rtype: bool From 6136a9e400e4b175152bebe6f56134d16467d5b0 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Sat, 4 Sep 2021 00:42:19 +0200 Subject: [PATCH 4/4] DotLayer transform maybe var1/var2 when optimized out of loop Fix #569. --- returnn/tf/layers/basic.py | 44 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index d3c2e2191..184f56531 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -5478,6 +5478,50 @@ def _axis2_to_output(axis, b_rem_axes, a_var_axes, b_var_axes): return None return out_axes.index(axis) + @classmethod + def transform_config_dict(cls, d, network, get_layer): + """ + :param dict[str] d: will modify inplace + :param returnn.tf.network.TFNetwork network: + :param ((str) -> LayerBase) get_layer: function to get or construct another layer + """ + super(DotLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer) + rec_time_dims = network.get_all_rec_time_dims() + if rec_time_dims: + assert len(d["sources"]) == 2, "dot-layer %r: needs exactly two sources" % (d["name"],) + src1, src2 = d["sources"] + assert isinstance(src1, LayerBase) and isinstance(src2, LayerBase) + # Maybe we want to add some of the outer rec layer dims to the var1/var2 list, + # or use those rec layer dims as further common dims (implicitly). + dims1 = set(tag for tag in rec_time_dims if tag in src1.output.dim_tags) + dims2 = set(tag for tag in rec_time_dims if tag in src2.output.dim_tags) + # If the rec layer dim is the same as some other dim, + # and was already explicitly specified in var1/var2 before, + # skip it. + var1 = d.get("var1", -2) # the default should really not be used... + var2 = d.get("var2", -1) # the default should really not be used... + var1_ = set([src1.output.dim_tags[i] for i in src1.output.get_axes_from_description(var1)]) + var2_ = set([src2.output.dim_tags[i] for i in src2.output.get_axes_from_description(var2)]) + dims1.difference_update(var1_) + dims2.difference_update(var2_) + # The common dims should be shared. The shared common dims are implicit, so nothing to do about them. 
+ dims_common = dims1.intersection(dims2) + # Those are dims which should be added to var1/var2. + dims1.difference_update(dims_common) + dims2.difference_update(dims_common) + + def _add(dims, val, d_key): + if not dims: + return + if val is None or val == "": + val = [] + elif not isinstance(val, (tuple, list)): + val = [val] + d[d_key] = val + type(val)(dims) + + _add(dims1, var1, "var1") + _add(dims2, var2, "var2") + @classmethod def get_out_data_from_opts(cls, name, sources, red1=-1, red2=-2, var1=-2, var2=-1, add_var2_if_empty=True, **kwargs): """
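
For reference, the two new tests in patch 1 pin down the axis layout that patch 4 is meant to produce. Below is a plain numpy sketch of that expectation; the einsum formulation and the variable names are only illustrative assumptions, not how RETURNN's DotLayer is implemented.

import numpy as np

n_batch, enc_time, dec_time, n_heads, n_key = 3, 11, 6, 4, 5
enc_ctx = np.random.randn(n_batch, enc_time, n_heads, n_key)    # (B, enc-T, H, D/H)

# Inside the loop, the query of one decoder step has no time axis,
# so reducing over D/H gives energy of shape (B, H, enc-T).
query_step = np.random.randn(n_batch, n_heads, n_key)           # (B, H, D/H)
energy_step = np.einsum("bthk,bhk->bht", enc_ctx, query_step)
assert energy_step.shape == (n_batch, n_heads, enc_time)

# With the layer moved out of the loop, the query gains the rec time axis dec-T.
# That extra axis must become a further var axis of the second source, so the
# result is (B, H, enc-T, dec-T), i.e. enc-T stays the first time axis.
query_all = np.random.randn(dec_time, n_batch, n_heads, n_key)  # (dec-T, B, H, D/H)
energy_all = np.einsum("bthk,ubhk->bhtu", enc_ctx, query_all)
assert energy_all.shape == (n_batch, n_heads, enc_time, dec_time)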
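
Patch 3 threads the outer rec time dim into the subnets used for layers moved out of the loop (input_layers_net / output_layers_net) as over_rec_time_dim, and get_all_rec_time_dims collects every such dim up the parent chain. The toy traversal below mirrors that logic with a stand-in Net class; the class and its field names are hypothetical stand-ins, not the real TFNetwork API.

class Net:
  """Toy stand-in for TFNetwork; only the fields read by the traversal below."""
  def __init__(self, parent=None, inside_rec_time_dim=None,
               over_rec_time_dim=None, over_rec_time_dim_subs=None):
    self.parent_net = parent
    self.inside_rec_time_dim = inside_rec_time_dim    # running inside the loop
    self.over_rec_time_dim = over_rec_time_dim        # moved out of the loop
    self.over_rec_time_dim_subs = over_rec_time_dim_subs or set()

def get_all_rec_time_dims(net):
  """All rec time dims of net and its parents, whether moved out or not."""
  coll = set()
  while net:
    if net.inside_rec_time_dim:
      coll.add(net.inside_rec_time_dim)
    if net.over_rec_time_dim:
      coll.add(net.over_rec_time_dim)
    coll.update(net.over_rec_time_dim_subs)
    net = net.parent_net
  return coll

root = Net()
out_of_loop_net = Net(parent=root, over_rec_time_dim="dec-T")
assert get_all_rec_time_dims(out_of_loop_net) == {"dec-T"}
assert get_all_rec_time_dims(root) == set()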
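
The actual fix in patch 4 then boils down to set arithmetic on dim tags: rec time dims that occur in only one source get appended to that source's var axes, while dims shared by both sources stay implicit common axes. A minimal standalone sketch of that logic follows, using strings as stand-in dim tags and a hypothetical helper name; it only mirrors the patch and is not part of the RETURNN API.

def adjust_var_axes(rec_time_dims, dims_src1, dims_src2, var1, var2):
  """
  :param set rec_time_dims: dim tags of enclosing rec layers (layer moved out of the loop)
  :param set dims_src1: dim tags present in source 1
  :param set dims_src2: dim tags present in source 2
  :param set var1: dim tags already listed as var axes of source 1
  :param set var2: dim tags already listed as var axes of source 2
  :return: (extra_var1, extra_var2) dim tags to append to var1/var2
  """
  # Rec time dims which actually occur in each source, minus what is already var.
  dims1 = (rec_time_dims & dims_src1) - var1
  dims2 = (rec_time_dims & dims_src2) - var2
  # Dims occurring in both sources act as (implicit) shared/common axes.
  common = dims1 & dims2
  # Only the dims unique to one source become additional var axes of that source.
  return dims1 - common, dims2 - common

# Example matching the enc/dec attention test: enc-T lives in enc_ctx (source 1),
# dec-T is added to att_query (source 2) once the layer is moved out of the loop.
extra1, extra2 = adjust_var_axes(
  rec_time_dims={"dec-T"},
  dims_src1={"B", "enc-T", "H", "K"}, dims_src2={"B", "dec-T", "H", "K"},
  var1={"enc-T"}, var2=set())
assert extra1 == set() and extra2 == {"dec-T"}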