Skip to content

Commit

Permalink
dot nicer API
Browse files Browse the repository at this point in the history
Fix #67
  • Loading branch information
albertz committed Dec 15, 2021
1 parent 310e6ed commit 0cba8d2
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 12 deletions.
48 changes: 43 additions & 5 deletions nn/_generate_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,18 @@ def setup():

if layer_class.layer_class:
print("", file=f)
if sig.has_source_param() or sig.has_recurrent_state() or sig.has_module_call_args():
if any([
sig.has_source_param(),
sig.explicit_source_list(),
sig.has_recurrent_state(),
sig.has_module_call_args()]):
print(" # noinspection PyShadowingBuiltins,PyShadowingNames", file=f)
print(" def make_layer_dict(self,", file=f)
if sig.has_source_param():
print(f" {sig.get_module_call_source_param_code_str()},", file=f)
elif sig.explicit_source_list():
for i in range(sig.explicit_source_list()):
print(f" {sig.get_module_call_source_param_code_str(explicit_idx=i)},", file=f)
if sig.has_module_call_args() or sig.has_recurrent_state():
print(" *,", file=f)
if sig.has_recurrent_state():
Expand All @@ -214,6 +221,9 @@ def setup():
file=f)
else:
print(" assert isinstance(source, LayerRef)", file=f)
elif sig.explicit_source_list():
for i in range(sig.explicit_source_list()):
print(f" assert isinstance(source{i + 1}, LayerRef)", file=f)
if sig.has_module_call_args() or sig.has_recurrent_state():
print(" args = {", file=f)
if sig.has_recurrent_state():
Expand All @@ -227,6 +237,8 @@ def setup():
print(f" 'class': {layer_class.layer_class!r},", file=f)
if sig.has_source_param():
print(" 'from': source,", file=f)
elif sig.explicit_source_list():
print(f" 'from': [{', '.join('source' + str(i + 1) for i in range(sig.explicit_source_list()))}],", file=f)
if sig.has_module_call_args() or sig.has_recurrent_state():
print(" **args,", file=f)
print(" **self.get_opts()}", file=f)
Expand All @@ -253,6 +265,10 @@ def setup():
if sig.has_source_param():
print(f"{prefix}{sig.get_module_call_source_param_code_str()},", file=f)
args.append("source")
elif sig.explicit_source_list():
for i in range(sig.explicit_source_list()):
print(f"{prefix}{sig.get_module_call_source_param_code_str(explicit_idx=i)},", file=f)
args.append(f"source{i + 1}")
print(f"{prefix}*,", file=f)
if sig.has_recurrent_state():
print(f"{prefix}{sig.get_module_call_state_param_code_str('state')},", file=f)
Expand Down Expand Up @@ -280,6 +296,9 @@ def setup():
print("", file=f)
if sig.has_source_param():
print(f" {sig.get_module_call_source_docstring()}", file=f)
elif sig.explicit_source_list():
for i in range(sig.explicit_source_list()):
print(f" {sig.get_module_call_source_docstring(explicit_idx=i)}", file=f)
if sig.has_recurrent_state():
print(f" {sig.get_module_call_state_docstring('state')}", file=f)
print(f" {sig.get_module_call_state_docstring('initial_state')}", file=f)
Expand Down Expand Up @@ -308,6 +327,9 @@ def setup():
print(f" return mod(", file=f)
if sig.has_source_param():
print(" source,", file=f)
elif sig.explicit_source_list():
for i in range(sig.explicit_source_list()):
print(f" source{i + 1},", file=f)
if sig.has_recurrent_state():
print(" state=state,", file=f)
print(" initial_state=initial_state,", file=f)
Expand Down Expand Up @@ -344,6 +366,8 @@ def has_source_param(self) -> bool:
self.layer_class,
(SourceLayer, ConstantLayer, VariableLayer, CondLayer, SwitchLayer, GenericAttentionLayer)):
return False
if self.explicit_source_list():
return False
return True

def support_multiple_sources(self) -> bool:
Expand All @@ -364,10 +388,17 @@ def need_multiple_sources(self) -> bool:
return False
if issubclass(self.layer_class, (CombineLayer, CompareLayer, StackLayer)):
return True
if self.layer_class.layer_class in {"dot"}:
return True
return False

def explicit_source_list(self) -> Optional[int]:
"""
If returned value is given, it means that instead of source: list[Layers],
we have source1, source2 etc, the number returned here.
"""
if self.layer_class.layer_class == "dot":
return 2
return None

# noinspection PyMethodMayBeStatic
def default_source(self) -> Optional[str]:
"""
Expand All @@ -377,10 +408,12 @@ def default_source(self) -> Optional[str]:
return "()"
return None

def get_module_call_source_param_code_str(self):
def get_module_call_source_param_code_str(self, explicit_idx: Optional[int] = None):
"""
Code for `source` param
"""
if explicit_idx is not None:
return f"source{explicit_idx + 1}: LayerRef"
assert self.has_source_param()
s = "source: "
if self.need_multiple_sources():
Expand All @@ -394,10 +427,12 @@ def get_module_call_source_param_code_str(self):
s += " = " + default
return s

def get_module_call_source_docstring(self):
def get_module_call_source_docstring(self, explicit_idx: Optional[int] = None):
"""
Code for docstring of `source` param
"""
if explicit_idx is not None:
return f":param LayerRef source{explicit_idx + 1}:"
s = ":param "
if self.need_multiple_sources():
s += "list[LayerRef]|tuple[LayerRef]"
Expand Down Expand Up @@ -722,6 +757,9 @@ def __repr__(self):
args = []
if self.has_source_param():
args.append("source")
elif self.explicit_source_list():
for i in range(self.explicit_source_list()):
args.append(f"source{i + 1}")
if self.has_recurrent_state():
args.append("state")
args += [arg.get_module_param_name() for arg in self.get_all_derived_args()]
Expand Down
21 changes: 14 additions & 7 deletions nn/_generated_layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
This file is auto-generated by _generate_layers.py.
RETURNN: 1.20211209.173823+git.b6d0c47
RETURNN: 1.20211209.162912+git.5697dbb
These are the RETURNN layers directly wrapped.
Note that we intentionally exclude some layers or options for more consistency.
Expand Down Expand Up @@ -4399,21 +4399,24 @@ def get_opts(self):

# noinspection PyShadowingBuiltins,PyShadowingNames
def make_layer_dict(self,
source: Union[List[LayerRef], Tuple[LayerRef]],
source1: LayerRef,
source2: LayerRef,
) -> LayerDictRaw:
"""
Make layer dict
"""
assert isinstance(source, (tuple, list)) and all(isinstance(s, LayerRef) for s in source)
assert isinstance(source1, LayerRef)
assert isinstance(source2, LayerRef)
return {
'class': 'dot',
'from': source,
'from': [source1, source2],
**self.get_opts()}


# noinspection PyShadowingBuiltins,PyShadowingNames
def dot(
source: Union[List[LayerRef], Tuple[LayerRef]],
source1: LayerRef,
source2: LayerRef,
*,
reduce: Union[Dim, Tuple[Dim, ...], List[Dim]] = NotSpecified,
red1: Union[Dim, Tuple[Dim, ...], List[Dim]] = NotSpecified,
Expand All @@ -4436,7 +4439,8 @@ def dot(
However, these are bad, for multiple reasons, like using integers, but also in general.
See https://github.com/rwth-i6/returnn/issues/627 for details.
:param list[LayerRef]|tuple[LayerRef] source:
:param LayerRef source1:
:param LayerRef source2:
:param Dim|tuple[Dim]|list[Dim] reduce: reduce axes of both sources
:param Dim|tuple[Dim]|list[Dim] red1: reduce axes of first source
:param Dim|tuple[Dim]|list[Dim] red2: reduce axes of second source
Expand All @@ -4453,7 +4457,10 @@ def dot(
var2=var2,
debug=debug,
)
return mod(source, name=name)
return mod(
source1,
source2,
name=name)


class _ShiftAxis(_Base):
Expand Down

0 comments on commit 0cba8d2

Please sign in to comment.