Skip to content

Commit

Permalink
Data.dim_tags_set, Data.verify_out_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 22, 2021
1 parent 6228964 commit 81525df
Showing 1 changed file with 145 additions and 0 deletions.
145 changes: 145 additions & 0 deletions returnn/tf/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,59 @@ def vocab(self, vocab):
self.get_same_base()._vocab = vocab


# Global dim tag placeholders.
# BatchDim is one module-level dim tag of kind Batch, described as "global batch".
BatchDim = DimensionTag(kind=DimensionTag.Types.Batch, description="global batch")


class _ImplicitDim:
"""
Represents an implicit dim (dim tag) in :class:`Data`.
https://github.com/rwth-i6/returnn/issues/706
"""
def __init__(self, tag):
"""
:param DimensionTag tag:
"""
self.tag = tag

def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self.tag)

def _eq_tuple(self):
return self.__class__, self.tag

def __hash__(self):
return hash(self._eq_tuple())

def __eq__(self, other):
if isinstance(other, _ImplicitDim):
return self._eq_tuple() == other._eq_tuple()
return False

def __ne__(self, other):
return not (self == other)


class ImplicitSparseDim(_ImplicitDim):
  """
  Represents an implicit dim via Data.sparse_dim.

  :func:`Data.dim_tags_set_implicit_only_wrapped` creates this for the sparse dim
  when that dim is not among the explicit dim tags.
  """


class ImplicitDynSizeDim(_ImplicitDim):
  """
  Represents an implicit dim via dynamic dim sizes.
  https://github.com/rwth-i6/returnn/issues/706
  (For example via :class:`CumConcatLayer`.)

  :func:`Data.dim_tags_set_implicit_only_wrapped` creates this for every dim tag of the ``dyn_size_ext``
  of an explicit dim, when that dim tag is not itself among the explicit dim tags.
  """


class VerifyOutShapeException(Exception):
  """
  Exception via :func:`Data.verify_out_shape`.

  Raised when the given ``out_shape`` does not match the dim tags of the :class:`Data`:
  a dim is missing, unknown, specified multiple times, or wrongly marked as implicit.
  """


class BatchInfo:
"""
A batched tensor is a tensor with batch dimension,
Expand Down Expand Up @@ -1793,6 +1846,51 @@ def get_runtime_sanity_check_op(self):
checks += [dyn_size_ext.get_runtime_sanity_check_op()]
return tf.group(*checks)

def verify_out_shape(self, out_shape):
  """
  Checks that ``out_shape`` describes exactly the dim tags we have (explicit and implicit ones),
  and raises an exception otherwise.
  https://github.com/rwth-i6/returnn/issues/706

  :param set[DimensionTag|_ImplicitDim]|tuple|list out_shape:
    It must be a set, with the only exception when it is empty (then it doesn't matter).
    Entries wrapped as :class:`_ImplicitDim` must be implicit dims of self.
    See :func:`dim_tags_set`.
  :raises VerifyOutShapeException: if ``out_shape`` does not match our dim tags
  :raises TypeError: if a non-empty ``out_shape`` is not a set, or contains something which is not a dim tag
  """
  own_dims = self.dim_tags_set_implicit
  own_implicit_wrapped = self.dim_tags_set_implicit_only_wrapped

  # Empty out_shape: any container type is accepted, but then we must not have any dims either.
  if not out_shape:
    if not own_dims:
      return
    raise VerifyOutShapeException(
      "%s verify_out_shape, with dims %s, does not match empty out_shape %r" % (self, own_dims, out_shape))

  if not isinstance(out_shape, set):
    raise TypeError("%s verify_out_shape: expects a set but got %s" % (self, type(out_shape)))

  # Every entry of out_shape ticks off one of our dims. In the end, nothing must be left over.
  unmatched = set(own_dims)
  for entry in out_shape:
    if isinstance(entry, DimensionTag):
      tag = entry
    elif isinstance(entry, _ImplicitDim):
      tag = entry.tag
      # The wrapper (including its concrete class) must match one of our implicit dims.
      if entry not in own_implicit_wrapped:
        raise VerifyOutShapeException(
          "%s verify_out_shape, with dims %s, with out_shape %s, %s is not an implicit dim in self" % (
            self, own_dims, out_shape, entry))
    else:
      raise TypeError("%s verify_out_shape with out_shape %s: expect dim tags but got %s" % (
        self, out_shape, type(entry)))

    if tag in unmatched:
      unmatched.discard(tag)
      continue

    # Not available anymore. Either it was already consumed, or we never had it.
    if tag in own_dims:  # can happen e.g. if specified once as implicit dim and then also as explicit
      raise VerifyOutShapeException(
        "%s verify_out_shape, with dims %s, does not match out_shape %r, dim %s multiple times in out_shape" % (
          self, own_dims, out_shape, entry))
    raise VerifyOutShapeException(
      "%s verify_out_shape, with dims %s, does not match out_shape %r, %s not in self" % (
        self, own_dims, out_shape, entry))

  if unmatched:
    raise VerifyOutShapeException(
      "%s verify_out_shape, dims %s are not specified in out_shape %s" % (self, unmatched, out_shape))

def get_placeholder_kwargs(self, with_batch=True):
"""
:param bool with_batch:
Expand Down Expand Up @@ -2860,6 +2958,53 @@ def dim_tags_sparse(self):
return self.dim_tags
return self.dim_tags[:self.feature_dim_axis] + self.dim_tags[self.feature_dim_axis + 1:]

@property
def dim_tags_set_implicit_only_wrapped(self):
  """
  Collects the dims which we only have implicitly, each wrapped to tell where it comes from:
  the sparse dim as :class:`ImplicitSparseDim`,
  and dims of the dynamic sizes of our dims as :class:`ImplicitDynSizeDim`.
  Dims which are also present as explicit dims are left out.
  Also see :func:`dim_tags_set`.

  :return: Dim tags implicit by sparse dim, or dynamic sizes, and not present as explicit dims.
  :rtype: set[_ImplicitDim]
  """
  explicit = set(self.dim_tags)
  implicit = set()
  if self.sparse_dim and self.sparse_dim not in explicit:
    implicit.add(ImplicitSparseDim(self.sparse_dim))
  for tag in self.dim_tags:
    if not tag.dyn_size_ext:
      continue  # no dynamic size information on this dim
    implicit.update(
      ImplicitDynSizeDim(size_tag)
      for size_tag in tag.dyn_size_ext.dim_tags
      if size_tag not in explicit)
  return implicit

@property
def dim_tags_set_implicit_only(self):
  """
  Same as :func:`dim_tags_set_implicit_only_wrapped` but with the plain dim tags,
  i.e. without the :class:`_ImplicitDim` wrappers.
  Also see :func:`dim_tags_set`.

  :return: Dim tags implicit by sparse dim, or dynamic sizes, and not present as explicit dims.
  :rtype: set[DimensionTag]
  """
  return {wrapped.tag for wrapped in self.dim_tags_set_implicit_only_wrapped}

@property
def dim_tags_set_implicit(self):
  """
  All our dim tags, i.e. the explicit ones together with the implicit-only ones
  (see :func:`dim_tags_set_implicit_only`).

  This is mostly intended to be used for verification, such as ``out_shape`` in a layer.
  https://github.com/rwth-i6/returnn/issues/706

  The result is a set because the order should never play any role
  when dim tags (dimensions, and the shape) are checked.
  https://github.com/rwth-i6/returnn/wiki/RETURNN-principles

  Further, dimension tags should ideally be unique.
  https://github.com/rwth-i6/returnn/issues/632
  (This is not enforced currently, but we should not treat this specially now.)

  :return: set of dim tags
  :rtype: set[DimensionTag]
  """
  return set(self.dim_tags).union(self.dim_tags_set_implicit_only)

@property
def ndim(self):
"""
Expand Down

0 comments on commit 81525df

Please sign in to comment.