CumConcatLayer
This is for generalized self attention (#391).

Co-authored-by: Frithjof <[email protected]>
albertz and Zettelkasten committed Sep 1, 2021
1 parent bd2771d commit 09c4cb8
Showing 1 changed file with 175 additions and 0 deletions.
returnn/tf/layers/rec.py: 175 additions & 0 deletions
@@ -8498,3 +8498,178 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs):
      kind=DimensionTag.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None)
    data = data.copy_template_new_dim_tags((dummy_dim_tag, time_dim_tag, feature_dim_tag))
    return data


class CumConcatLayer(_ConcatInputLayer):
  """
  Concatenates all previous frames of a time-axis.
  Where :class:`CumsumLayer` uses `sum`, this layer uses `concat`.
  This layer can be used as a base for auto-regressive self-attention.
  This layer expects to be inside a :class:`RecLayer`.

  Inside a rec loop (not optimized out),
  this will concatenate the current input
  to the previously accumulated inputs.
  For an input of shape `input_shape`,
  it will output a tensor of shape `[new_dim] + input_shape`.
  `new_dim` is a special dimension, usually of length `i`,
  where `i` is the current loop frame,
  i.e. the length increases in every loop frame.
  `new_dim` is specified by its own separate dim tag.
  For example, in the first frame,
  this will be of shape `[1] + input_shape`,
  in the second frame of shape `[2] + input_shape`,
  and so on,
  and in the last frame of shape `[T] + input_shape`.

  Outside the rec loop (optimized out),
  this layer expects an input with the time dim of the rec layer,
  and returns the input as-is,
  but replaces the time dim tag with the dim tag `new_dim`
  converted as outside the loop.
  Normally the optimization should not matter for the user,
  i.e. for the user, the logical behavior is always as if inside the rec loop.
  Outside the loop,
  the output represents a tensor of shape `[T, new_dim] + input_shape`,
  although we actually have another `new_dim` outside the loop,
  and the `T` dim is not actually there;
  but no information is lost,
  because the last frame contains all previous frames.
  This `new_dim` outside the loop stores all the dynamic seq lengths
  per frame of the loop, i.e. the dyn seq lengths are extended to shape [B,T] or [T]
  (instead of the usual [B]).
  This way, following layers use different seq lengths of `new_dim` for different loop frames,
  just as if the `T` dim actually existed.
  """
  layer_class = "cum_concat"
  recurrent = True  # order matters

  def __init__(self, new_dim, **kwargs):
    """
    :param DimensionTag new_dim:
    """
    super(CumConcatLayer, self).__init__(**kwargs)
    rec_layer = self.network.get_rec_parent_layer(inside_loop=False)
    assert rec_layer, "%r must be used inside a RecLayer" % self
    out_axis = self.output.get_axis_from_description(new_dim)
    new_dim_ = self.output.dim_tags[out_axis]

    if not self.input_data.has_axis(rec_layer.time_dim_tag):  # inside loop
      current_data = self.input_data.copy_compatible_to(self.output, unbroadcast=False)
      current_frame = current_data.placeholder  # [B, 1, ..., D]
      last_frames = self._rec_previous_layer.rec_vars_outputs["state"]  # [B, t, ..., D]
      concat_frames = tf.concat([last_frames, current_frame], axis=out_axis)  # [B, t+1, ..., D]
      self.rec_vars_outputs["state"] = concat_frames
      self.output.placeholder = concat_frames

      if not new_dim_.dyn_size_ext:
        # Unbroadcasting to [B] is not needed because any layers operating on this
        # should be able to handle extended dyn sizes.
        # Clipping it to the max length for sequences in the loop which have already ended
        # (i.e. considering the end flag)
        # is also not needed because any calculations after the end are irrelevant.
        # Note: In case we have some initial state/output, this can be extended.
        dyn_size = self.network.get_rec_step_index() + 1  # scalar
        new_dim_.dyn_size_ext = Data(
          name="%s:cum-concat:size-inside" % self.name,
          dim_tags=[],  # scalar
          placeholder=dyn_size, dtype="int32")

    else:  # outside loop
      # If not inside a rec loop, this layer is a no-op on the tensor.
      self.output.placeholder = self.input_data.placeholder

      # However, we used new dim tags, which were already prepared.
      # We now must fill in the extended dynamic size information.
      if not new_dim_.dyn_size_ext:
        # This must match the logic above for inside the loop.
        # Note: In case we have some initial state/output, this can be extended.
        dyn_size = tf.range(tf.math.reduce_max(rec_layer.time_dim_tag.dyn_size)) + 1  # [T]
        new_dim_.dyn_size_ext = Data(
          name="%s:cum-concat:size-outside" % self.name,
          dim_tags=[rec_layer.time_dim_tag],
          placeholder=dyn_size, dtype="int32")

  @classmethod
  def get_out_data_from_opts(cls, name, network, sources, new_dim, **kwargs):
    """
    :param str name:
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param DimensionTag new_dim:
    :rtype: Data
    """
    assert network.is_inside_rec_layer(inside_loop=False), "CumConcatLayer %r must be used inside a RecLayer" % name
    rec_time_dim = network.get_inside_rec_time_dim(inside_loop=False)
    assert rec_time_dim
    new_dim_base = new_dim.get_same_base()
    if new_dim_base.per_spatial_frame is None:
      new_dim_base.per_spatial_frame = rec_time_dim
    else:
      assert new_dim_base.per_spatial_frame == rec_time_dim

    input_data = get_concat_sources_data_template(sources, name="%s_output" % name)
    if not input_data.has_axis(rec_time_dim):  # inside loop
      # Currently SelectSearchSourcesLayer assumes that all rec_vars_outputs are batch-major.
      # Therefore we here copy the input as batch-major, and then add the new axis at axis 1.
      # In the future, when SelectSearchSourcesLayer has support for this, we can change this to operate on axis 0,
      # which should be more efficient.
      out = input_data.copy_as_batch_major()
      out = out.copy_add_dim_by_tag(new_dim_base, unbroadcast=True, axis=1)
      return out

    else:  # outside loop
      if not new_dim_base.per_spatial_frame_accumulated:
        new_dim_accum = DimensionTag(
          kind=new_dim_base.kind, description="%s:accumulated" % name)
        new_dim_accum.declare_same_as(new_dim_base)
        new_dim_base.per_spatial_frame_accumulated = new_dim_accum
      else:
        new_dim_accum = new_dim_base.per_spatial_frame_accumulated
      # Assume that the input has the time dim from the rec layer.
      axis = input_data.get_axis_from_description(rec_time_dim)
      return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=new_dim_accum)

  # noinspection PyMethodOverriding
  @classmethod
  def get_rec_initial_extra_outputs(cls, network, batch_dim, rec_layer, sources, output, new_dim, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param tf.Tensor batch_dim:
    :param returnn.tf.layers.rec.RecLayer|LayerBase rec_layer:
    :param list[LayerBase] sources:
    :param Data output:
    :param DimensionTag new_dim:
    :rtype: dict[str,tf.Tensor]
    """
    if network.is_inside_rec_layer():
      shape = []
      for tag in output.dim_tags:
        if tag.is_batch_dim():
          shape.append(batch_dim)
        elif tag == new_dim:
          shape.append(0)
        elif tag.dimension is not None:
          shape.append(tag.dimension)
        else:
          assert tag.dyn_size is not None
          shape.append(tf.math.reduce_max(tag.dyn_size))
      return {"state": tf.zeros(shape, dtype=output.dtype)}
    else:
      return {}

  @classmethod
  def get_rec_initial_extra_outputs_shape_invariants(cls, network, sources, output, **kwargs):
    """
    :param returnn.tf.network.TFNetwork network:
    :param list[LayerBase] sources:
    :param Data output:
    :rtype: dict[str, tf.TensorShape]
    """
    if network.is_inside_rec_layer():
      return {"state": tf.TensorShape(output.batch_shape)}
    else:
      return {}
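
To make the docstring's inside-the-loop behavior concrete, here is a small plain-TF sketch (not part of this commit; variable names are made up) of what the accumulation does per loop frame, for a single sequence with input_shape = [D]:

import tensorflow as tf

D = 3
frames = tf.random.normal([5, D])  # 5 loop frames, each frame of shape [D]

state = tf.zeros([0, D])  # initial accumulated state; new_dim starts with length 0
for i in range(5):
  current = frames[i:i + 1]                    # [1] + input_shape, the current frame
  state = tf.concat([state, current], axis=0)  # [i + 1] + input_shape, new_dim grows by one
  # Inside the loop, new_dim's dyn size is the scalar i + 1.
  # Outside the loop (optimized out), the same information is kept per loop frame,
  # i.e. as the vector [1, 2, ..., T], which is what the extended dyn_size_ext of shape [T] stores.
print(state.shape)  # (5, 3), i.e. [T] + input_shape in the last frame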
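For the generalized self-attention use case from #391, a minimal net-dict sketch could look as follows. Only the `cum_concat` options reflect the layer added in this commit; the surrounding layer names, dimensions, and the `linear`/`dot`/`softmax_over_spatial` options are illustrative assumptions and may need adjustment.

from returnn.tf.util.data import DimensionTag

# Dim tag for the accumulated-frames axis (the `new_dim` of cum_concat); the name is illustrative.
accum_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="self-att-accum")

network = {
  "output": {"class": "rec", "from": "data", "unit": {
    # Per-frame query/key/value projections (sizes are arbitrary here).
    "q": {"class": "linear", "from": "data:source", "activation": None, "with_bias": False, "n_out": 64},
    "k": {"class": "linear", "from": "data:source", "activation": None, "with_bias": False, "n_out": 64},
    "v": {"class": "linear", "from": "data:source", "activation": None, "with_bias": False, "n_out": 64},
    # Accumulate the current and all previous key/value frames along accum_dim.
    "k_accum": {"class": "cum_concat", "from": "k", "new_dim": accum_dim},
    "v_accum": {"class": "cum_concat", "from": "v", "new_dim": accum_dim},
    # Attention energies over the accumulated frames; softmax over accum_dim.
    "energy": {"class": "dot", "from": ["q", "k_accum"], "red1": "F", "red2": "F",
               "var1": None, "var2": accum_dim},
    "att_weights": {"class": "softmax_over_spatial", "from": "energy", "axis": accum_dim},
    "att": {"class": "dot", "from": ["att_weights", "v_accum"], "red1": accum_dim, "red2": accum_dim,
            "var1": None, "var2": "F"},
    "output": {"class": "copy", "from": "att"},
  }},
}

The point of the pattern is that `k_accum`/`v_accum` grow by one frame per loop step inside the rec loop, while outside the loop (optimized out) the extended seq lengths of `accum_dim` give `softmax_over_spatial` the per-frame masking, just as if the full time axis existed.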
