Skip to content

Commit

Permalink
LayerBase base out data, fixes for Data.sparse_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 22, 2021
1 parent 208720b commit aba4232
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions returnn/tf/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,10 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
if sources_data:
out_type.setdefault("batch_dim_axis", sources_data.batch_dim_axis)
out_type.setdefault("time_dim_axis", sources_data.time_dim_axis)
if not out_type.get("sparse", False) and sources_data.feature_dim_axis_or_unspecified is not NotSpecified:
if (
not out_type.get("sparse", False) and
not out_type.get("sparse_dim", None) and
sources_data.feature_dim_axis_or_unspecified is not NotSpecified):
if sources_data.feature_dim_axis_or_unspecified is not None:
out_type.setdefault("feature_dim_axis", sources_data.feature_dim_axis_or_unspecified)
else: # None
Expand All @@ -334,7 +337,7 @@ def _base_get_out_data_from_opts(cls, network, name, out_type=None, n_out=NotSpe
out_type.setdefault("time_dim_axis", None)
if "shape" not in out_type and "dim_tags" not in out_type:
if sources_data:
if out_type.get("sparse", False):
if out_type.get("sparse", False) or out_type.get("sparse_dim", None):
out_type["dim_tags"] = sources_data.dim_tags_sparse
else: # not sparse
feature_dim_axis = out_type.get("feature_dim_axis", NotSpecified)
Expand Down

0 comments on commit aba4232

Please sign in to comment.