test_reclayer_optimize_out_cum_concat_gen_self_att
albertz committed Sep 1, 2021
1 parent 09c4cb8 commit 2384c1b
Showing 1 changed file with 33 additions and 2 deletions.
35 changes: 33 additions & 2 deletions tests/test_TFNetworkRecLayer.py
@@ -3345,8 +3345,7 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
     rec_layer_dict["unit"].update(other_subnet_layers)
   config = Config({
     "debug_print_layer_output_template": True,
-    "num_inputs": n_in,
-    "num_outputs": n_out
+    "extern_data": {"data": {"dim": n_in}},
   })
   from returnn.tf.layers.rec import _SubnetworkRecCell
   with make_scope() as session:
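
The config change above replaces the num_inputs/num_outputs shortcut with an explicit extern_data declaration. As a rough sketch (not part of the commit; the exact expansion is an assumption about RETURNN's config handling), the old shortcut implicitly declares both a dense input "data" and sparse targets "classes", while this test only ever feeds the input side:

from returnn.config import Config

n_in, n_out = 13, 17  # arbitrary example dims, not taken from the commit

# Old shortcut: extern_data derived implicitly from num_inputs/num_outputs (assumed expansion).
config_old = Config({
  "num_inputs": n_in,    # dense float input "data", roughly shape [B,T,n_in]
  "num_outputs": n_out,  # sparse int targets "classes", roughly shape [B,T]
})

# New style: declare only what the test actually uses, the input "data".
config_new = Config({
  "extern_data": {"data": {"dim": n_in}},
})
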
@@ -3423,6 +3422,38 @@ def test_reclayer_optimize_out_selfatt_left():
"class": "self_attention", "attention_left_only": True, "num_heads": 2, "total_key_dim": 6, "n_out": 18})


def test_reclayer_optimize_out_cum_concat_gen_self_att():
new_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="cum_concat_new_dim")
n_key = 5
n_value = 7
check_reclayer_optimize_out(
{"class": "linear", "from": "att", "activation": None},
{
# This is very much the vanilla self attention,
# implemented via the new generic way.
# See https://github.com/rwth-i6/returnn/issues/391 for a long discussion.
# Commented shapes are always for the layers inside the loop (not optimized).
"qkv": {"class": "linear", "from": "data:source", "activation": None, "n_out": n_key * 2 + n_value}, # [B,2*K+V]
"qkv_split": {"class": "split", "from": "qkv", "size_splits": [n_key, n_key, n_value]},
"q": {"class": "copy", "from": "qkv_split/0"}, # [B,K]
"k": {"class": "copy", "from": "qkv_split/1"}, # [B,K]
"v": {"class": "copy", "from": "qkv_split/2"}, # [B,V]
# cum_concat here. Note that the optimized-out shape is not as you might expect [T,max(t),B,K],
# but instead using the optimized format, with extended dyn size on the special dim tag.
"k_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "k"}, # [t,B,K]
"v_accum": {"class": "cum_concat", "new_dim": new_dim, "from": "v"}, # [t,B,V]
"energy": {
"class": "dot", "from": ["q", "k_accum"],
"red1": "static:-1", "red2": "static:-1",
"var1": None, "var2": new_dim}, # [B,t]
"att_weights": {"class": "softmax_over_spatial", "from": "energy", "axis": new_dim}, # [B,t]
"att": {
"class": "dot", "from": ["att_weights", "v_accum"],
"red1": new_dim, "red2": new_dim,
"var1": None, "var2": "static:-1"}, # [B,V]
})


def test_reclayer_optimize_out_dot():
# Used for multi-head dot-attention.
AttNumHeads = 4
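
To make the shape comments in the new test easier to follow, here is a minimal NumPy sketch (an editor's illustration, not part of the commit) of what the layer graph computes for one step inside the recurrent loop, assuming a single attention head and, like the test, no 1/sqrt(key_dim) scaling; the names step, k_hist, v_hist and W are made up for this sketch:

import numpy as np

def step(x_t, k_hist, v_hist, W, n_key=5, n_value=7):
  # "qkv": one linear projection producing query, key and value.  [B,2*K+V]
  qkv = x_t @ W
  # "qkv_split" / "q" / "k" / "v": split into [B,K], [B,K], [B,V].
  q, k, v = np.split(qkv, [n_key, 2 * n_key], axis=-1)
  # "k_accum" / "v_accum": cum_concat, i.e. append this step to the accumulated history.  [t,B,K], [t,B,V]
  k_hist = np.concatenate([k_hist, k[None]], axis=0)
  v_hist = np.concatenate([v_hist, v[None]], axis=0)
  # "energy": dot product of the current query with all keys so far.  [B,t]
  energy = np.einsum("bk,tbk->bt", q, k_hist)
  # "att_weights": softmax over the accumulated (t) axis.  [B,t]
  e = np.exp(energy - energy.max(axis=-1, keepdims=True))
  att_weights = e / e.sum(axis=-1, keepdims=True)
  # "att": weighted sum of the accumulated values.  [B,V]
  att = np.einsum("bt,tbv->bv", att_weights, v_hist)
  return att, k_hist, v_hist

# Tiny usage example: run a few steps on random data and check the shapes.
rng = np.random.default_rng(0)
n_batch, n_in, n_key, n_value = 2, 13, 5, 7
W = rng.normal(size=(n_in, 2 * n_key + n_value))
k_hist = np.zeros((0, n_batch, n_key))
v_hist = np.zeros((0, n_batch, n_value))
for t in range(3):
  att, k_hist, v_hist = step(rng.normal(size=(n_batch, n_in)), k_hist, v_hist, W)
  assert att.shape == (n_batch, n_value) and k_hist.shape == (t + 1, n_batch, n_key)

Inside the loop the attention only sees the keys and values accumulated so far; when the rec layer is optimized out, per the comment in the test, the accumulated axis keeps the special dim tag with an extended dynamic size rather than an explicit [T,max(t),B,...] layout.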
