diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 5dec156e1..1ef76aa24 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -792,6 +792,59 @@ def vocab(self, vocab): self.get_same_base()._vocab = vocab +# Global dim tag placeholders. +BatchDim = DimensionTag(kind=DimensionTag.Types.Batch, description="global batch") + + +class _ImplicitDim: + """ + Represents an implicit dim (dim tag) in :class:`Data`. + https://github.com/rwth-i6/returnn/issues/706 + """ + def __init__(self, tag): + """ + :param DimensionTag tag: + """ + self.tag = tag + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self.tag) + + def _eq_tuple(self): + return self.__class__, self.tag + + def __hash__(self): + return hash(self._eq_tuple()) + + def __eq__(self, other): + if isinstance(other, _ImplicitDim): + return self._eq_tuple() == other._eq_tuple() + return False + + def __ne__(self, other): + return not (self == other) + + +class ImplicitSparseDim(_ImplicitDim): + """ + Represents an implicit dim via Data.sparse_dim. + """ + + +class ImplicitDynSizeDim(_ImplicitDim): + """ + Represents an implicit dim via dynamic dim sizes. + https://github.com/rwth-i6/returnn/issues/706 + (For example via :class:`CumConcatLayer`.) + """ + + +class VerifyOutShapeException(Exception): + """ + Exception via :func:`Data.verify_out_shape`. + """ + + class BatchInfo: """ A batched tensor is a tensor with batch dimension, @@ -1793,6 +1846,51 @@ def get_runtime_sanity_check_op(self): checks += [dyn_size_ext.get_runtime_sanity_check_op()] return tf.group(*checks) + def verify_out_shape(self, out_shape): + """ + Verifies that ``out_shape`` matches our shape, i.e. specifically the dim tags. + https://github.com/rwth-i6/returnn/issues/706 + Throws an exception if this is not the case. + + :param set[DimensionTag|_ImplicitDim]|tuple|list out_shape: + It must be a set, with the only exception when it is empty (then it doesn't matter). + See :func:`dim_tags_set`. + """ + self_dim_tags = self.dim_tags_set_implicit + self_dim_tags_implicit_only = self.dim_tags_set_implicit_only_wrapped + if not out_shape: + if self_dim_tags: + raise VerifyOutShapeException( + "%s verify_out_shape, with dims %s, does not match empty out_shape %r" % (self, self_dim_tags, out_shape)) + return + if not isinstance(out_shape, set): + raise TypeError("%s verify_out_shape: expects a set but got %s" % (self, type(out_shape))) + remaining = set(self_dim_tags) + for dim in out_shape: + if isinstance(dim, DimensionTag): + dim_tag = dim + elif isinstance(dim, _ImplicitDim): + dim_tag = dim.tag + if dim not in self_dim_tags_implicit_only: + raise VerifyOutShapeException( + "%s verify_out_shape, with dims %s, with out_shape %s, %s is not an implicit dim in self" % ( + self, self_dim_tags, out_shape, dim)) + else: + raise TypeError("%s verify_out_shape with out_shape %s: expect dim tags but got %s" % ( + self, out_shape, type(dim))) + if dim_tag not in remaining: + if dim_tag in self_dim_tags: # can happen e.g. if specified once as implicit dim and then also as explicit + raise VerifyOutShapeException( + "%s verify_out_shape, with dims %s, does not match out_shape %r, dim %s multiple times in out_shape" % ( + self, self_dim_tags, out_shape, dim)) + raise VerifyOutShapeException( + "%s verify_out_shape, with dims %s, does not match out_shape %r, %s not in self" % ( + self, self_dim_tags, out_shape, dim)) + remaining.discard(dim_tag) + if remaining: + raise VerifyOutShapeException( + "%s verify_out_shape, dims %s are not specified in out_shape %s" % (self, remaining, out_shape)) + def get_placeholder_kwargs(self, with_batch=True): """ :param bool with_batch: @@ -2860,6 +2958,53 @@ def dim_tags_sparse(self): return self.dim_tags return self.dim_tags[:self.feature_dim_axis] + self.dim_tags[self.feature_dim_axis + 1:] + @property + def dim_tags_set_implicit_only_wrapped(self): + """ + :return: Dim tags implicit by sparse dim, or dynamic sizes, and not present as explicit dims. + Also see :func:`dim_tags_set`. + :rtype: set[_ImplicitDim] + """ + self_dim_tags = set(self.dim_tags) + dims = set() + if self.sparse_dim and self.sparse_dim not in self_dim_tags: + dims.add(ImplicitSparseDim(self.sparse_dim)) + for dim in self.dim_tags: + if dim.dyn_size_ext: + for dim_ in dim.dyn_size_ext.dim_tags: + if dim_ not in self_dim_tags: + dims.add(ImplicitDynSizeDim(dim_)) + return dims + + @property + def dim_tags_set_implicit_only(self): + """ + :return: Dim tags implicit by sparse dim, or dynamic sizes, and not present as explicit dims. + Also see :func:`dim_tags_set`. + :rtype: set[DimensionTag] + """ + return set(dim.tag for dim in self.dim_tags_set_implicit_only_wrapped) + + @property + def dim_tags_set_implicit(self): + """ + This is mostly intended to be used for verification, such as ``out_shape`` in a layer. + https://github.com/rwth-i6/returnn/issues/706 + + We return a set because when dim tags (dimensions, and the shape) are checked, + we never want that the order plays any role. + https://github.com/rwth-i6/returnn/wiki/RETURNN-principles + Further, dimension tags should ideally be unique. + https://github.com/rwth-i6/returnn/issues/632 + (This is not enforced currently, but we should not treat this specially now.) + + :return: set of dim tags + :rtype: set[DimensionTag] + """ + dims = set(self.dim_tags) + dims.update(self.dim_tags_set_implicit_only) + return dims + @property def ndim(self): """