Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add @no_side_effects. #647

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions axlearn/common/inference_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def prune_non_tf_str_tensor_leaves(_: str, subtree: NestedTensor):

return subtree.dtype != tf.string

batch_without_str_tensors = utils.prune_tree(batch, prune_tf_str_tensor_leaves)
str_tensors = utils.prune_tree(batch, prune_non_tf_str_tensor_leaves)
batch_without_str_tensors = utils.prune_tree(
batch, prune_tf_str_tensor_leaves, prune_root=False
)
str_tensors = utils.prune_tree(batch, prune_non_tf_str_tensor_leaves, prune_root=False)
return batch_without_str_tensors, str_tensors


Expand Down
20 changes: 2 additions & 18 deletions axlearn/common/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
Tensor,
flatten_items,
match_regex_rules,
prune_tree,
prune_empty,
register_per_param_settings,
tree_paths,
)
Expand Down Expand Up @@ -90,22 +90,6 @@ def should_apply_state_updates(update_type: UpdateType) -> bool:
return update_type in (UpdateType.STATE_UPDATES, UpdateType.ALL_UPDATES)


def _prune_empty(in_tree: Nested[Tensor]) -> Nested[Tensor]:
"""Returns a shallow copy of the input tree with empty subtrees pruned.

If a tree would be made empty by removal of its subtrees, it will also be pruned.
This is a shallow copy because leaf nodes (non-dict values) are not deep-copied.

Args:
in_tree: the input tree to be pruned.

Returns:
The pruned copy of the input tree.
"""
# Note that falsey values or empty Tensors are not considered empty.
return prune_tree(in_tree, lambda _, v: isinstance(v, dict) and not v)


class BaseLearner(LearnerModule):
"""The base class of a learner."""

Expand Down Expand Up @@ -310,7 +294,7 @@ def _compute_updated_params(
updated_model_params = optax.apply_updates(
jax.tree_util.tree_map(lambda op: op.value, opt_params), parameter_updates
)
state_updates = _prune_empty(state_updates)
state_updates = prune_empty(state_updates, fill_value=optax.MaskedNode())
apply_state_updates = jax.tree_util.tree_map(
should_apply_state_updates,
self._update_types(state_updates),
Expand Down
4 changes: 2 additions & 2 deletions axlearn/common/learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
Learner,
UpdateType,
_apply_updates,
_prune_empty,
_split_gradients,
_value_and_grad,
should_update_with_optimizers,
Expand Down Expand Up @@ -55,6 +54,7 @@
VDict,
flatten_items,
match_regex_rules,
prune_empty,
tree_paths,
)

Expand Down Expand Up @@ -86,7 +86,7 @@ def test_prune_empty_state(self):
},
},
}
actual = _prune_empty(state)
actual = prune_empty(state)
self.assertNestedAllClose(expected, actual)

@parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward"))
Expand Down
151 changes: 143 additions & 8 deletions axlearn/common/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,19 @@ def do_foo(self, ...):
import re
import threading
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)

import jax
import numpy as np
Expand All @@ -59,9 +71,11 @@ def do_foo(self, ...):
Nested,
NestedTensor,
Tensor,
get_recursively,
partial_with_fn_metadata,
prune_tree,
raise_for_cycles,
same_pruned_structure,
)


Expand Down Expand Up @@ -540,6 +554,8 @@ def _should_wrap_method(method: str) -> bool:
fn_sig = inspect.signature(fn)
if "self" not in fn_sig.parameters:
return False
if getattr(fn, _NO_SIDE_EFFECTS_ATTRIBUTE_NAME, False):
return False
return True

return {
Expand Down Expand Up @@ -990,14 +1006,12 @@ def scan_fn(carry_i: NestedTensor, scan_i: NestedTensor):

# Filter output collection.
if drop_output is not None:
pruned_collection_i = new_output_collection()._asdict()
pruned_collection_i.update(
prune_tree(
output_collection_i._asdict(),
lambda path, _: drop_output(path),
)
pruned_output_collection_i = prune_tree(
output_collection_i._asdict(), lambda path, _: drop_output(path), prune_root=False
)
output_collection_i = OutputCollection(**pruned_collection_i)
output_collection_i = new_output_collection()._asdict()
output_collection_i.update(pruned_output_collection_i)
output_collection_i = OutputCollection(**output_collection_i)

return carry_i, dict(y_i=y_i, output_collection=output_collection_i)

Expand All @@ -1009,3 +1023,124 @@ def scan_fn(carry_i: NestedTensor, scan_i: NestedTensor):
)

return carry, scan_ys["y_i"]


def get_state_or_update(module: Module, key: str) -> Optional[Nested[Tensor]]:
"""Return `module.state[key]`, inclusive of any state update.

Unlike `module.state[key]`, this works even if in a parent context.

Args:
module: The module to get the state of.
key: The key to lookup in the state dict.
state: The state tree to traverse, rooted at the `current_context().module`.
If None, use `current_context().module`.

Returns
The requested state.

Raises:
KeyError: If it is not found.
"""
ctx = current_context()
not_found = object()
state = get_state(module, state=ctx.get_state_updates(), default=not_found)
if state is not not_found and key in state:
return state[key]
return get_state(module)[key]


# A sentinel that can be passed as the `default` parameter of `get_module_state()`.
RAISE = object()
_Default = TypeVar("_Default")


def get_state(
module: Module, *, state: Optional[Nested[Tensor]] = None, default: _Default = RAISE
) -> Union[Nested[Tensor], _Default]:
"""Return `module.state`.

Unlike `module.state`, this works even if in a parent context.

Args:
module: The module to get the state of.
state: The state tree to traverse, rooted at the `current_context().module`.
If None, use `current_context().module`.
default: The default value to return if it is not found. If set to the sentinel RAISE,
a KeyError is raised instead of returning.

Returns
The requested state.

Raises:
KeyError: If it is not found.
"""
ctx = current_context()
if state is None:
state = ctx.state
path = ctx.module.path_to_descendant_module(module)
try:
state = get_recursively(state, path)
except KeyError:
if default is RAISE:
raise
state = default
return state


_NO_SIDE_EFFECTS_ATTRIBUTE_NAME = "_no_side_effects"


def no_side_effects(
fn: Callable, *, raise_for_side_effects: Literal["never", "structural"] = "structural"
) -> Callable:
"""A decorator for a function that does not alter the InvocationContext.

Module methods wrapped with this decorator will not wrapped in an auto-child context.

This allows calling them even if other methods in the same module have already been called,
without the caller needing to explicitly create a new InvocationContext.

Any side effects of `fn` on the invocation context will be reverted after `fn()` returns because
`fn()` is called on a copy of the real context. This is true regardless of the value of
`raise_for_side_effects`. However, one may control whether an error is generated
by setting `raise_for_side_effects`.

Args:
fn: The function to wrap.
raise_for_side_effects: What to do if `fn` is detected as having side effects.
* "structural" (Default). This raises if the side
effects change the pruned structure of the OutputCollection.
* "never" Does not raise. Any side effects will therefore be silently
dropped.

Returns:
The wrapped function.

Raises:
ValueError: If `fn` has side effects on the InvocationContext.
"""
if raise_for_side_effects not in ["never", "structural"]:
raise ValueError(f"Invalid value of `raise_for_side_effects`: {raise_for_side_effects}.")

@no_stack_summary
@functools.wraps(fn)
def wrapper(*args, **kwargs):
ctx = current_context()
if ctx is None:
return fn(*args, **kwargs)

result, output_collection = ctx.functional(fn)(*args, **kwargs)
if raise_for_side_effects == "structural" and not same_pruned_structure(
output_collection, current_context().output_collection
):
raise ValueError(
f"Function '{fn.__name__}' wrapped with `no_side_effects` has side effects.\n"
"Expected no change to output collection other than the possible creation of\n"
"empty subtrees.\n"
"Did you accidentally call `add_summary()` or `add_module_output()`?"
)
return result

setattr(wrapper, _NO_SIDE_EFFECTS_ATTRIBUTE_NAME, True)
return wrapper
Loading