DotLayer transform maybe var1/var2 when optimized out of loop #628

Merged: 4 commits, Sep 3, 2021
44 changes: 44 additions & 0 deletions returnn/tf/layers/basic.py
@@ -5478,6 +5478,50 @@ def _axis2_to_output(axis, b_rem_axes, a_var_axes, b_var_axes):
      return None
    return out_axes.index(axis)

  @classmethod
  def transform_config_dict(cls, d, network, get_layer):
    """
    :param dict[str] d: will modify inplace
    :param returnn.tf.network.TFNetwork network:
    :param ((str) -> LayerBase) get_layer: function to get or construct another layer
    """
    super(DotLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
    rec_time_dims = network.get_all_rec_time_dims()
    if rec_time_dims:
      assert len(d["sources"]) == 2, "dot-layer %r: needs exactly two sources" % (d["name"],)
      src1, src2 = d["sources"]
      assert isinstance(src1, LayerBase) and isinstance(src2, LayerBase)
      # Maybe we want to add some of the outer rec layer dims to the var1/var2 list,
      # or use those rec layer dims as further common dims (implicitly).
      dims1 = set(tag for tag in rec_time_dims if tag in src1.output.dim_tags)
      dims2 = set(tag for tag in rec_time_dims if tag in src2.output.dim_tags)
      # If the rec layer dim is the same as some other dim,
      # and was already explicitly specified in var1/var2 before,
      # skip it.
      var1 = d.get("var1", -2)  # the default should really not be used...
      var2 = d.get("var2", -1)  # the default should really not be used...
      var1_ = set([src1.output.dim_tags[i] for i in src1.output.get_axes_from_description(var1)])
      var2_ = set([src2.output.dim_tags[i] for i in src2.output.get_axes_from_description(var2)])
      dims1.difference_update(var1_)
      dims2.difference_update(var2_)
      # The common dims should be shared. The shared common dims are implicit, so nothing to do about them.
      dims_common = dims1.intersection(dims2)
      # Those are dims which should be added to var1/var2.
      dims1.difference_update(dims_common)
      dims2.difference_update(dims_common)

      def _add(dims, val, d_key):
        if not dims:
          return
        if val is None or val == "":
          val = []
        elif not isinstance(val, (tuple, list)):
          val = [val]
        d[d_key] = val + type(val)(dims)

      _add(dims1, var1, "var1")
      _add(dims2, var2, "var2")

  @classmethod
  def get_out_data_from_opts(cls, name, sources, red1=-1, red2=-2, var1=-2, var2=-1, add_var2_if_empty=True, **kwargs):
    """
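The effect of the new transform_config_dict logic can be seen in a small standalone sketch. This is not RETURNN code: dim tags are plain strings, var1/var2 are already lists, and the string/None handling of the real _add helper is omitted; only the set arithmetic from the diff above is reproduced.

# Illustrative sketch only: mirrors the set arithmetic of DotLayer.transform_config_dict.
# Dim "tags" are plain strings here; in RETURNN they are DimensionTag objects.

def adjust_vars(rec_time_dims, src1_dims, src2_dims, var1, var2):
  """Return possibly extended (var1, var2) lists for the moved-out dot layer."""
  dims1 = {d for d in rec_time_dims if d in src1_dims}
  dims2 = {d for d in rec_time_dims if d in src2_dims}
  # Rec dims that were already explicitly listed in var1/var2 are skipped.
  dims1 -= set(var1)
  dims2 -= set(var2)
  # Rec dims present in both sources stay implicit common (shared) dims.
  common = dims1 & dims2
  dims1 -= common
  dims2 -= common
  # Rec dims present in only one source become extra var axes of that source.
  return (list(var1) + sorted(dims1) if dims1 else list(var1),
          list(var2) + sorted(dims2) if dims2 else list(var2))

# Example: attention energy moved out of the rec loop.
# src1 = enc_ctx with dims (B, enc-T, H, K); src2 = att_query, which gained dec-T.
print(adjust_vars(
  rec_time_dims={"dec-T"},
  src1_dims=["B", "enc-T", "H", "K"],
  src2_dims=["dec-T", "B", "H", "K"],
  var1=["enc-T"], var2=[]))
# -> (['enc-T'], ['dec-T'])  i.e. dec-T is added as a var axis of the query only.

With the layer moved out, the energy thus gets the var axes enc-T and dec-T on top of the implicit common dims (B, H), which is what the new tests below assert for the shapes inside vs. outside the loop.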
4 changes: 4 additions & 0 deletions returnn/tf/layers/rec.py
@@ -3094,6 +3094,8 @@ def _construct_input_layers_moved_out(self):
extern_data=ExternData(),
train_flag=self.parent_net.train_flag,
search_flag=self.parent_net.search_flag,
over_rec_time_dim=self.time_dim_tag,
over_rec_time_dim_subs=self._time_dim_tags,
parent_layer=self.parent_rec_layer,
parent_net=self.parent_net)
self.input_layers_net.is_root_in_ctx = True
@@ -3182,6 +3184,8 @@ def _construct_output_layers_moved_out(self, loop_accumulated, seq_len, extra_ou
extern_data=ExternData(),
train_flag=self.parent_net.train_flag,
search_flag=self.parent_net.search_flag,
over_rec_time_dim=self.time_dim_tag,
over_rec_time_dim_subs=self._time_dim_tags,
parent_layer=self.parent_rec_layer,
parent_net=self.parent_net)
self.output_layers_net.is_root_in_ctx = True
29 changes: 27 additions & 2 deletions returnn/tf/network.py
@@ -353,7 +353,7 @@ class TFNetwork(object):
def __init__(self, config=None, extern_data=None, rnd_seed=None,
train_flag=None, eval_flag=None, search_flag=None,
parent_layer=None, parent_net=None, extra_parent_net=None, extra_name_prefix=None,
- inside_rec_time_dim=None,
+ inside_rec_time_dim=None, over_rec_time_dim=None, over_rec_time_dim_subs=None,
absolute_name_prefix=None, name=None):
"""
:param returnn.config.Config config: only needed to init extern_data if not specified explicitly
@@ -367,7 +367,9 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
:param TFNetwork|None extra_parent_net: we are on the same level (not really a child),
but an "extra" net of extra_parent_net
:param str|None extra_name_prefix:
- :param DimensionTag|None inside_rec_time_dim:
+ :param DimensionTag|None inside_rec_time_dim: dim tag of the outer rec layer, when run inside the loop (not optimized)
+ :param DimensionTag|None over_rec_time_dim: dim tag of the outer rec layer, when optimized out of the loop
+ :param set[DimensionTag]|None over_rec_time_dim_subs: dims of the outer rec layer, out of the loop, potentially shorter variants
:param str|None absolute_name_prefix:
:param str name: only for debugging
"""
@@ -428,6 +430,8 @@ def __init__(self, config=None, extern_data=None, rnd_seed=None,
self.parent_layer = parent_layer
self.parent_net = parent_net
self._inside_rec_time_dim = inside_rec_time_dim
self._over_rec_time_dim = over_rec_time_dim
self._over_rec_time_dim_subs = over_rec_time_dim_subs
self.extra_parent_net = extra_parent_net
self.extra_name_prefix = extra_name_prefix
self.extra_deps_in_extra = False
Expand Down Expand Up @@ -2173,6 +2177,10 @@ def get_inside_rec_time_dim(self, inside_loop=True):
"""
if self._inside_rec_time_dim:
return self._inside_rec_time_dim
if self._over_rec_time_dim:
if inside_loop:
return None
return self._over_rec_time_dim
from returnn.tf.layers.rec import RecLayer
if isinstance(self.parent_layer, RecLayer):
# When we get here (and not in the if-branch above on _inside_rec_time_dim),
@@ -2186,6 +2194,23 @@
      return self.parent_net.get_inside_rec_time_dim(inside_loop=inside_loop)
    return None

  def get_all_rec_time_dims(self):
    """
    :return: all rec time dims, moved out or not, including all parents
    :rtype: set[DimensionTag]
    """
    coll = set()
    net = self
    while net:
      if net._inside_rec_time_dim:
        coll.add(net._inside_rec_time_dim)
      if net._over_rec_time_dim:
        coll.add(net._over_rec_time_dim)
      if net._over_rec_time_dim_subs:
        coll.update(net._over_rec_time_dim_subs)
      net = net.parent_net
    return coll

  def _is_rec_layer_inside_net(self):
    """
    :rtype: bool
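For context, here is a rough hypothetical stand-in (not the real TFNetwork, with strings instead of DimensionTag objects) illustrating the intended semantics: a net built for the moved-out computation reports no rec dim for inside_loop=True but still exposes it for inside_loop=False, while get_all_rec_time_dims collects everything along the parent_net chain.

# Hypothetical stand-in, only to illustrate the inside-the-loop vs. moved-out distinction.

class _FakeNet:
  def __init__(self, inside_rec_time_dim=None, over_rec_time_dim=None, parent_net=None):
    self._inside_rec_time_dim = inside_rec_time_dim
    self._over_rec_time_dim = over_rec_time_dim
    self._over_rec_time_dim_subs = None
    self.parent_net = parent_net

  def get_inside_rec_time_dim(self, inside_loop=True):
    # Mirrors the logic added above (RecLayer/parent handling omitted).
    if self._inside_rec_time_dim:
      return self._inside_rec_time_dim
    if self._over_rec_time_dim:
      if inside_loop:
        return None  # the layer was optimized out, so we are not really inside the loop
      return self._over_rec_time_dim
    return None

  def get_all_rec_time_dims(self):
    coll, net = set(), self
    while net:
      for dim in (net._inside_rec_time_dim, net._over_rec_time_dim):
        if dim:
          coll.add(dim)
      if net._over_rec_time_dim_subs:
        coll.update(net._over_rec_time_dim_subs)
      net = net.parent_net
    return coll

in_loop = _FakeNet(inside_rec_time_dim="dec-T")
moved_out = _FakeNet(over_rec_time_dim="dec-T", parent_net=_FakeNet())
assert in_loop.get_inside_rec_time_dim(inside_loop=True) == "dec-T"
assert moved_out.get_inside_rec_time_dim(inside_loop=True) is None
assert moved_out.get_inside_rec_time_dim(inside_loop=False) == "dec-T"
assert moved_out.get_all_rec_time_dims() == {"dec-T"}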
2 changes: 1 addition & 1 deletion returnn/tf/util/data.py
@@ -2902,7 +2902,7 @@ def get_axes_from_description(self, axes, allow_int=True):
        if isinstance(i, int):
          flat_axes += [i]
        else:
-         assert isinstance(i, (str, tuple, list))
+         assert isinstance(i, (str, tuple, list, DimensionTag))
          flat_axes += self.get_axes_from_description(i)
      flat_axes = [i % self.batch_ndim for i in flat_axes]
      res = []
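The one-line change in Data.get_axes_from_description simply allows a DimensionTag inside a list/tuple axes description, which is what DotLayer.transform_config_dict now produces when it appends rec dim tags to var1/var2. A hedged usage sketch (the constructor arguments and the printed result are illustrative, not taken from the PR):

from returnn.tf.util.data import Data

data = Data(name="data", shape=(None, 5))   # batch-major: [B, T, F=5]
time_tag = data.get_time_dim_tag()          # DimensionTag of the T axis

# Mixing a plain axis description with a DimensionTag in one list;
# before this change, the DimensionTag entry would have failed the assert.
axes = data.get_axes_from_description(["B", time_tag])
print(axes)  # expected something like [0, 1]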
92 changes: 92 additions & 0 deletions tests/test_TFNetworkRecLayer.py
@@ -3460,6 +3460,98 @@ def test_reclayer_optimize_out_dot():
rtol=1e-3)


def test_reclayer_optimize_out_dot_consistent_axes():
# https://github.com/rwth-i6/returnn/issues/569
# Used for multi-head dot-attention.
n_heads = 4
n_key = 5
n_value = 7
n_key_total = n_heads * n_key
n_value_total = n_heads * n_value
check_reclayer_optimize_out(
{"class": "linear", "activation": None, "from": "att"},
other_subnet_layers={
"s": {"class": "linear", "activation": None, "with_bias": False, "from": "data:source",
"n_out": n_key_total}, # (B, D) -- Q (query). D should be same as enc_ctx
"att_query": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key), "from": "s"}, # (B, H, D/H)
# att_query is (T, B, H, D/H) outside the loop.
# Here is the main test, the dot-layer:
"energy": {
"class": "dot",
"red1": "static:-1", "red2": "static:-1",
"var1": "T", "var2": None, "add_var2_if_empty": False,
"from": ["base:enc_ctx", "att_query"]},
# energy inside the loop will be (B, H, T).
# energy outside the loop should be (B, H, T, T). I.e. T is still the first time axis.
# The logic should be that the dot layer would add any extra axes (due to this optimization, moving layer out)
# to either common shared axes (which is implicit) or if it is only added to one input,
# then to the respective var axes.
# Note that in this test, there is no different encoder or decoder time dim.
# It still works due to time-dim-axis being set to the first source.
"att_weights": {"class": "softmax_over_spatial", "from": "energy"}, # (B, T, H)
"att0": {"class": "generic_attention", "weights": "att_weights", "base": "base:enc_value"}, # (B, H, V)
"att": {"class": "merge_dims", "axes": "static", "from": "att0"}, # (B, H*V); Use "static" here.
},
shared_base_net={
"encoder": {"class": "copy", "from": "data"},
"enc_ctx0": {"class": "linear", "activation": None, "with_bias": False, "from": "encoder",
"n_out": n_key_total}, # (B, enc-T, D)
"enc_ctx": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key),
"from": "enc_ctx0", "is_output_layer": True}, # (B, enc-T, H, D/H)
"enc_value0": {"class": "linear", "activation": None, "with_bias": False, "from": "encoder",
"n_out": n_value_total},
"enc_value": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_value),
"from": "enc_value0", "is_output_layer": True}, # (B, enc-T, H, D/H)
},
rtol=1e-3)


def test_reclayer_optimize_out_dot_consistent_axes_enc_dec():
# https://github.com/rwth-i6/returnn/issues/569
# Used for multi-head dot-attention.
n_heads = 4
n_key = 5
n_value = 7
n_key_total = n_heads * n_key
n_value_total = n_heads * n_value
check_reclayer_optimize_out(
{"class": "linear", "activation": None, "from": "att"},
other_subnet_layers={
"s": {"class": "linear", "activation": None, "with_bias": False, "from": "data:source",
"n_out": n_key_total}, # (B, D) -- Q (query). D should be same as enc_ctx
"att_query": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key), "from": "s"}, # (B, H, D/H)
# att_query is (dec-T, B, H, D/H) outside the loop.
# Here is the main test, the dot-layer:
"energy": {
"class": "dot",
"red1": "static:-1", "red2": "static:-1",
"var1": "T", "var2": None, "add_var2_if_empty": False,
"from": ["base:enc_ctx", "att_query"]},
# energy inside the loop will be (B, H, enc-T).
# energy outside the loop should be (B, H, enc-T, dec-T). I.e. enc-T is still the first time axis.
# The logic should be that the dot layer would add any extra axes (due to this optimization, moving layer out)
# to either common shared axes (which is implicit) or if it is only added to one input,
# then to the respective var axes.
"att_weights": {"class": "softmax_over_spatial", "from": "energy"}, # (B, enc-T, H)
"att0": {"class": "generic_attention", "weights": "att_weights", "base": "base:enc_value"}, # (B, H, V)
"att": {"class": "merge_dims", "axes": "static", "from": "att0"}, # (B, H*V); Use "static" here.
},
shared_base_net={
# Use conv with padding valid to make sure we get another time dim,
# such that the rec part above will not confuse this time dim with the rec time dim.
"encoder": {"class": "conv", "from": "data", "filter_size": [3], "padding": "valid", "n_out": 5},
"enc_ctx0": {"class": "linear", "activation": None, "with_bias": False, "from": "encoder",
"n_out": n_key_total}, # (B, enc-T, D)
"enc_ctx": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_key),
"from": "enc_ctx0", "is_output_layer": True}, # (B, enc-T, H, D/H)
"enc_value0": {"class": "linear", "activation": None, "with_bias": False, "from": "encoder",
"n_out": n_value_total},
"enc_value": {"class": "split_dims", "axis": "F", "dims": (n_heads, n_value),
"from": "enc_value0", "is_output_layer": True}, # (B, enc-T, H, D/H)
},
rtol=1e-3)


def test_reclayer_optimize_out_dot_kv_in_rec():
# Same as test_reclayer_optimize_out_dot, but with the att key/value layers declared INSIDE the rec layer.
AttNumHeads = 4