diff --git a/axlearn/common/inference_pipeline.py b/axlearn/common/inference_pipeline.py index 36c5abe0..dff14c06 100644 --- a/axlearn/common/inference_pipeline.py +++ b/axlearn/common/inference_pipeline.py @@ -44,8 +44,10 @@ def prune_non_tf_str_tensor_leaves(_: str, subtree: NestedTensor): return subtree.dtype != tf.string - batch_without_str_tensors = utils.prune_tree(batch, prune_tf_str_tensor_leaves) - str_tensors = utils.prune_tree(batch, prune_non_tf_str_tensor_leaves) + batch_without_str_tensors = utils.prune_tree( + batch, prune_tf_str_tensor_leaves, prune_root=False + ) + str_tensors = utils.prune_tree(batch, prune_non_tf_str_tensor_leaves, prune_root=False) return batch_without_str_tensors, str_tensors diff --git a/axlearn/common/learner.py b/axlearn/common/learner.py index 1d38bfbd..4c38d00f 100644 --- a/axlearn/common/learner.py +++ b/axlearn/common/learner.py @@ -46,7 +46,7 @@ Tensor, flatten_items, match_regex_rules, - prune_tree, + prune_empty, register_per_param_settings, tree_paths, ) @@ -90,22 +90,6 @@ def should_apply_state_updates(update_type: UpdateType) -> bool: return update_type in (UpdateType.STATE_UPDATES, UpdateType.ALL_UPDATES) -def _prune_empty(in_tree: Nested[Tensor]) -> Nested[Tensor]: - """Returns a shallow copy of the input tree with empty subtrees pruned. - - If a tree would be made empty by removal of its subtrees, it will also be pruned. - This is a shallow copy because leaf nodes (non-dict values) are not deep-copied. - - Args: - in_tree: the input tree to be pruned. - - Returns: - The pruned copy of the input tree. - """ - # Note that falsey values or empty Tensors are not considered empty. - return prune_tree(in_tree, lambda _, v: isinstance(v, dict) and not v) - - class BaseLearner(LearnerModule): """The base class of a learner.""" @@ -310,7 +294,7 @@ def _compute_updated_params( updated_model_params = optax.apply_updates( jax.tree_util.tree_map(lambda op: op.value, opt_params), parameter_updates ) - state_updates = _prune_empty(state_updates) + state_updates = prune_empty(state_updates, fill_value=optax.MaskedNode()) apply_state_updates = jax.tree_util.tree_map( should_apply_state_updates, self._update_types(state_updates), diff --git a/axlearn/common/learner_test.py b/axlearn/common/learner_test.py index 4cbcb8d0..df7fbc33 100644 --- a/axlearn/common/learner_test.py +++ b/axlearn/common/learner_test.py @@ -22,7 +22,6 @@ Learner, UpdateType, _apply_updates, - _prune_empty, _split_gradients, _value_and_grad, should_update_with_optimizers, @@ -55,6 +54,7 @@ VDict, flatten_items, match_regex_rules, + prune_empty, tree_paths, ) @@ -86,7 +86,7 @@ def test_prune_empty_state(self): }, }, } - actual = _prune_empty(state) + actual = prune_empty(state) self.assertNestedAllClose(expected, actual) @parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward")) diff --git a/axlearn/common/module.py b/axlearn/common/module.py index fb1ac4f0..c066a1d8 100644 --- a/axlearn/common/module.py +++ b/axlearn/common/module.py @@ -44,7 +44,19 @@ def do_foo(self, ...): import re import threading from dataclasses import dataclass -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + NamedTuple, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) import jax import numpy as np @@ -59,9 +71,11 @@ def do_foo(self, ...): Nested, NestedTensor, Tensor, + get_recursively, partial_with_fn_metadata, prune_tree, raise_for_cycles, + same_pruned_structure, ) @@ -540,6 +554,8 @@ def _should_wrap_method(method: str) -> bool: fn_sig = inspect.signature(fn) if "self" not in fn_sig.parameters: return False + if getattr(fn, _NO_SIDE_EFFECTS_ATTRIBUTE_NAME, False): + return False return True return { @@ -990,14 +1006,12 @@ def scan_fn(carry_i: NestedTensor, scan_i: NestedTensor): # Filter output collection. if drop_output is not None: - pruned_collection_i = new_output_collection()._asdict() - pruned_collection_i.update( - prune_tree( - output_collection_i._asdict(), - lambda path, _: drop_output(path), - ) + pruned_output_collection_i = prune_tree( + output_collection_i._asdict(), lambda path, _: drop_output(path), prune_root=False ) - output_collection_i = OutputCollection(**pruned_collection_i) + output_collection_i = new_output_collection()._asdict() + output_collection_i.update(pruned_output_collection_i) + output_collection_i = OutputCollection(**output_collection_i) return carry_i, dict(y_i=y_i, output_collection=output_collection_i) @@ -1009,3 +1023,124 @@ def scan_fn(carry_i: NestedTensor, scan_i: NestedTensor): ) return carry, scan_ys["y_i"] + + +def get_state_or_update(module: Module, key: str) -> Optional[Nested[Tensor]]: + """Return `module.state[key]`, inclusive of any state update. + + Unlike `module.state[key]`, this works even if in a parent context. + + Args: + module: The module to get the state of. + key: The key to lookup in the state dict. + state: The state tree to traverse, rooted at the `current_context().module`. + If None, use `current_context().module`. + + Returns + The requested state. + + Raises: + KeyError: If it is not found. + """ + ctx = current_context() + not_found = object() + state = get_state(module, state=ctx.get_state_updates(), default=not_found) + if state is not not_found and key in state: + return state[key] + return get_state(module)[key] + + +# A sentinel that can be passed as the `default` parameter of `get_module_state()`. +RAISE = object() +_Default = TypeVar("_Default") + + +def get_state( + module: Module, *, state: Optional[Nested[Tensor]] = None, default: _Default = RAISE +) -> Union[Nested[Tensor], _Default]: + """Return `module.state`. + + Unlike `module.state`, this works even if in a parent context. + + Args: + module: The module to get the state of. + state: The state tree to traverse, rooted at the `current_context().module`. + If None, use `current_context().module`. + default: The default value to return if it is not found. If set to the sentinel RAISE, + a KeyError is raised instead of returning. + + Returns + The requested state. + + Raises: + KeyError: If it is not found. + """ + ctx = current_context() + if state is None: + state = ctx.state + path = ctx.module.path_to_descendant_module(module) + try: + state = get_recursively(state, path) + except KeyError: + if default is RAISE: + raise + state = default + return state + + +_NO_SIDE_EFFECTS_ATTRIBUTE_NAME = "_no_side_effects" + + +def no_side_effects( + fn: Callable, *, raise_for_side_effects: Literal["never", "structural"] = "structural" +) -> Callable: + """A decorator for a function that does not alter the InvocationContext. + + Module methods wrapped with this decorator will not wrapped in an auto-child context. + + This allows calling them even if other methods in the same module have already been called, + without the caller needing to explicitly create a new InvocationContext. + + Any side effects of `fn` on the invocation context will be reverted after `fn()` returns because + `fn()` is called on a copy of the real context. This is true regardless of the value of + `raise_for_side_effects`. However, one may control whether an error is generated + by setting `raise_for_side_effects`. + + Args: + fn: The function to wrap. + raise_for_side_effects: What to do if `fn` is detected as having side effects. + * "structural" (Default). This raises if the side + effects change the pruned structure of the OutputCollection. + * "never" Does not raise. Any side effects will therefore be silently + dropped. + + Returns: + The wrapped function. + + Raises: + ValueError: If `fn` has side effects on the InvocationContext. + """ + if raise_for_side_effects not in ["never", "structural"]: + raise ValueError(f"Invalid value of `raise_for_side_effects`: {raise_for_side_effects}.") + + @no_stack_summary + @functools.wraps(fn) + def wrapper(*args, **kwargs): + ctx = current_context() + if ctx is None: + return fn(*args, **kwargs) + + result, output_collection = ctx.functional(fn)(*args, **kwargs) + if raise_for_side_effects == "structural" and not same_pruned_structure( + output_collection, current_context().output_collection + ): + raise ValueError( + f"Function '{fn.__name__}' wrapped with `no_side_effects` has side effects.\n" + "Expected no change to output collection other than the possible creation of\n" + "empty subtrees.\n" + "Did you accidentally call `add_summary()` or `add_module_output()`?" + ) + return result + + setattr(wrapper, _NO_SIDE_EFFECTS_ATTRIBUTE_NAME, True) + return wrapper diff --git a/axlearn/common/module_test.py b/axlearn/common/module_test.py index ffde4ab3..b7483549 100644 --- a/axlearn/common/module_test.py +++ b/axlearn/common/module_test.py @@ -4,6 +4,7 @@ # pylint: disable=protected-access # type: ignore[attribute-error] import contextlib +import functools import threading from typing import List, Optional, Union @@ -29,8 +30,11 @@ ) from axlearn.common.module import functional as F from axlearn.common.module import ( + get_state, + get_state_or_update, install_context_stack, new_output_collection, + no_side_effects, scan_in_context, set_current_context, ) @@ -729,6 +733,119 @@ def test_get_shared_module(self): inputs=dict(shared_module_name="outer_shared"), ) + def test_no_side_effects(self): + class MyModule(Module): + """Module for testing `no_side_effects` decorator.""" + + @no_side_effects + def has_outputs(self) -> bool: + return "asdf" in get_state( + self, state=current_context().get_module_outputs(), default={} + ) + + def _get_info(self, x: Tensor, *, y: Tensor): + z = 1 + if self.has_outputs(): + z = get_state(self, state=current_context().get_module_outputs())["asdf"] + return 3 * x + y * z + + @no_side_effects + def get_info(self, x: Tensor, *, y: Tensor): + return self._get_info(x, y=y) + + def get_info_without_decorator(self, x: Tensor, *, y: Tensor): + return self._get_info(x, y=y) + + @no_side_effects + def does_not_need_context(self): + return 5 + + @no_side_effects + def actually_has_side_effects(self): + self.add_output() + + def add_output(self): + self.add_module_output("asdf", 5) + + @no_side_effects + def set_state_update(self): + current_context().set_state_update(0) + + @functools.partial(no_side_effects, raise_for_side_effects="never") + def has_silently_reverted_side_effects(self): + self.add_output() + return 5 + + class Parent(Module): + def __init__(self, cfg: Module.Config, *, parent: Optional["Module"]): + super().__init__(cfg, parent=parent) + self._add_child("child", MyModule.default_config()) + + with test_utils.bind_module(Parent.default_config(), state={}) as parent: + mod = parent.child + self.assertEqual(mod.has_outputs(), False) + # Check that repeated calling works. + # This would cause an output conflict error if we didn't use functional(). + self.assertEqual(mod.get_info(2, y=5), 11) + self.assertEqual(mod.get_info(2, y=7), 13) + mod.get_info_without_decorator(2, y=5) + with self.assertRaises(OutputConflictError): + mod.get_info_without_decorator(2, y=7) + + with test_utils.bind_module(parent, state={}): + # Test that calling a side-effecting function fails. + with self.assertRaisesRegex(ValueError, "has side effects"): + mod.actually_has_side_effects() + # Check that side effects from the side-effecting function were prevented. + self.assertEqual(mod.get_info(2, y=5), 11) + self.assertEqual(current_context().output_collection, new_output_collection()) + # Check @no_side_effects methods work even if there were side effects prior to the + # no_side_effects() call. + mod.add_output() + self.assertEqual(mod.get_info(2, y=5), 31) + + with test_utils.bind_module(parent, state={}): + # Check that replacing a leaf in the OutputCollection without changing the + # tree structure results in the change to the leaf not affecting the original + # context. + current_context().set_state_update(1) + mod.set_state_update() # Tries to set it to 0, but silently fails. + self.assertEqual(current_context().get_state_updates(), 1) + + with test_utils.bind_module(parent, state={}): + # Check that raise_for_side_effects = "never" works. + self.assertEqual(mod.has_silently_reverted_side_effects(), 5) + self.assertFalse(mod.has_outputs()) + + # Test calling without a context works. + self.assertEqual(mod.does_not_need_context(), 5) + with self.assertRaises(AttributeError): + # Fails since we are not in an InvocationContext anymore. + mod.get_info_without_decorator(2, y=5) + + def test_get_state_and_or_update(self): + """Tests `get_state()` and `get_state_or_update()`.""" + module1 = new_test_module("root") + module1._add_child("child1", TestModule.default_config()) + module1._add_child("child2", TestModule.default_config()) + state = {"child1": {"x": 1}, "child2": {"x": 2}} + state_update = {"child1": {"x": 3}, "child2": {"x": 4}} + with test_utils.bind_module(module1, state=state): + self.assertEqual(get_state(module1.child1)["x"], 1) + self.assertEqual(get_state(module1.child2)["x"], 2) + self.assertEqual(get_state_or_update(module1.child1, "x"), 1) + self.assertEqual(get_state_or_update(module1.child2, "x"), 2) + current_context().set_state_update(state_update) + self.assertEqual(get_state(module1.child1)["x"], 1) + self.assertEqual(get_state(module1.child2)["x"], 2) + self.assertEqual(get_state_or_update(module1.child1, "x"), 3) + self.assertEqual(get_state_or_update(module1.child2, "x"), 4) + + with self.assertRaises(KeyError): + get_state(module1.child1)["asdf"] # pylint: disable=expression-not-assigned + with self.assertRaises(KeyError): + get_state_or_update(module1.child1, "asdf") + class ScanInContextTest(TestWithTemporaryCWD): """Tests scan_in_context.""" diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index d9c6a8c0..4cbb800b 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -66,6 +66,8 @@ # The set of supported floating point dtypes. _supported_float_dtypes = [jnp.bfloat16, jnp.float32] +T = TypeVar("T") + @dataclasses.dataclass class HybridMeshShape: @@ -158,6 +160,29 @@ def _concat(*, prefix: str, suffix: str, separator: str): return f"{prefix}{separator}{suffix}" if prefix else f"{suffix}" +def _key_entry_to_str(key_entry: KeyEntry) -> str: + """Convert a pytree child key to a string suitable for building AXLearn paths. + + See `tree_paths()` for an example. + """ + # Although (e.g.) DictKey does have its own __str__ implementation, calling + # str(DictKey('a')) produces "['a']" instead of just "a". + if isinstance(key_entry, jax.tree_util.DictKey): + key = key_entry.key + elif isinstance(key_entry, jax.tree_util.GetAttrKey): + key = key_entry.name + elif isinstance(key_entry, jax.tree_util.SequenceKey): + key = key_entry.idx + elif isinstance(key_entry, jax.tree_util.FlattenedIndexKey): + key = key_entry.key + else: + raise RuntimeError(f"Unknown key entry type {type(key_entry)}: {key_entry}.") + + # Use f-string instead of calling str() because it matches the behavior of the previous + # implementation and differs from str() for (e.g.) enums. + return f"{key}" + + def tree_paths( tree: NestedTree, separator: str = "/", is_leaf: Optional[Callable[[Any], bool]] = None ) -> NestedTree: @@ -179,26 +204,8 @@ def tree_paths( tree_paths. """ - def key_entry_to_str(key_entry: KeyEntry) -> str: - # Although (e.g.) DictKey does have its own __str__ implementation, calling - # str(DictKey('a')) produces "['a']" instead of just "a". - if isinstance(key_entry, jax.tree_util.DictKey): - key = key_entry.key - elif isinstance(key_entry, jax.tree_util.GetAttrKey): - key = key_entry.name - elif isinstance(key_entry, jax.tree_util.SequenceKey): - key = key_entry.idx - elif isinstance(key_entry, jax.tree_util.FlattenedIndexKey): - key = key_entry.key - else: - raise RuntimeError(f"Unknown key entry type {type(key_entry)}: {key_entry}.") - - # Use f-string instead of calling str() because it matches the behavior of the previous - # implementation and differs from str() for (e.g.) enums. - return f"{key}" - return jax.tree_util.tree_map_with_path( - lambda kp, _: separator.join(key_entry_to_str(k) for k in kp), tree, is_leaf=is_leaf + lambda kp, _: separator.join(_key_entry_to_str(k) for k in kp), tree, is_leaf=is_leaf ) @@ -971,37 +978,165 @@ def partial_with_fn_metadata(fn, *args, **kwargs): return functools.update_wrapper(partial_fn, fn) +F = TypeVar("F") + + def prune_tree( - in_tree: NestedTensor, - should_prune: Callable[[str, NestedTensor], bool], + tree: T, + should_prune: Callable[[str, Any], bool], *, - prefix: str = "", + fill_value: F = None, + prune_root: bool = True, separator: str = "/", -): - """Returns a shallow copy of the input tree with subtrees pruned based on `should_prune`. + prefix: str = "", +) -> Union[T, F]: + """Returns a copy of the input pytree with subtrees pruned based on `should_prune`. - This is a shallow copy because leaf nodes (non-dict values) are not deep-copied. + The tree structure is deep copied but the leaf values are shallow copied. + + Pruned nodes have their value replaced with `fill_value`. For nodes that are children of a + dict-like object (nodes that have a DictKey as their key), we also prune the key itself from + the parent. + + Example: + ``` + assert prune_tree({}, should_prune=lambda path, v: True) is None + assert prune_tree(dict(a=3, b=4), should_prune=lambda path, v: v == 3) == dict(b=4) + # Assuming MyNode uses (e.g.) a GetAttrKey for child 'a'. + assert prune_tree(MyNode(a=3), should_prune=lambda path, v: v == 3) == MyNode(a=None) + ``` Args: - in_tree: The input tree to be pruned. + tree: The pytree to be pruned. should_prune: A callable which takes (path, subtree) as input and returns a boolean. The - subtree provided will have already been pruned. If the callable returns True, the - subtree itself will be dropped. - prefix: Path prefix. + subtree provided will have already had its children pruned. + If the callable returns True, the subtree itself will be pruned. + fill_value: The value to replace pruned subtrees with. + prune_root: Whether the root may be pruned. separator: Separator used to join path parts. + prefix: Path prefix. + + Returns: + The pruned copy of the input tree. + """ + + def unflatten_one_level(treedef: jax.tree_util.PyTreeDef, children: Sequence) -> Any: + """Create a pytree from a non-leaf `treedef` whose immediate children are `children`.""" + placeholder_child_treedefs = len(children) * (jax.tree_util.tree_structure(object()),) + treedef = treedef.make_from_node_data_and_children( + registry=jax.tree_util.default_registry, + node_data=treedef.node_data(), + children=placeholder_child_treedefs, + ) + return treedef.unflatten(children) + + def maybe_remove_keys(tree: Any, keys: set[KeyEntry]): + """Remove top-level DictKey keys from tree. + + This mutates `tree`. + """ + for key in keys: + if isinstance(key, jax.tree_util.DictKey): + del tree[key.key] + + # Sentinel value to replace pruned trees with. + # This is replaced with fill_value later. + pruned = object() + + def _prune_tree( + tree: Any, *, treedef: jax.tree_util.PyTreeDef, prefix: str, is_root: bool + ) -> Any: + children = [] + pruned_keys = set() + for (k, v), v_treedef in jax.util.safe_zip(pytree_children(tree), treedef.children()): + v_path = _concat(prefix=prefix, suffix=_key_entry_to_str(k), separator=separator) + v = _prune_tree(v, treedef=v_treedef, prefix=v_path, is_root=False) + if v is pruned: + pruned_keys.add(k) + v = fill_value + children.append(v) + + # Don't try to unflatten (nonstrict) pytree leaves. + if children: + tree = unflatten_one_level(treedef, children) + maybe_remove_keys(tree, pruned_keys) + + if is_root and not prune_root: + return tree + if should_prune(prefix, tree): + return pruned + return tree + + tree = jax.tree_util.tree_map(lambda x: x, tree) + treedef = jax.tree_util.tree_structure(tree) + tree = _prune_tree(tree, treedef=treedef, prefix=prefix, is_root=True) + if tree is pruned: + return fill_value + return tree + + +def prune_empty(tree: T, *, fill_value: Any = None) -> Optional[T]: + """Returns a copy of the input pytree with empty subtrees pruned. + + A subtree is empty if it has no pytree leaves. + See `prune_tree()` for additional semantics. + + Args: + tree: The pytree to prune. + fill_value: The value to replace pruned subtrees with. Unlike prune_tree(), this + replacement is only done after pruning, so it cannot turn empty subtrees + into non-empty subtrees if `fill_value` is a non-empty subtree. Returns: The pruned copy of the input tree. """ - if isinstance(in_tree, dict): - out_tree = {} - for k, v in in_tree.items(): - path = _concat(prefix=prefix, suffix=k, separator=separator) - v = prune_tree(v, should_prune, prefix=path, separator=separator) - if not should_prune(path, v): - out_tree[k] = v - in_tree = out_tree - return in_tree + + @dataclasses.dataclass + class Reference: + """A hashable reference to an object that compares equal based on reference equality. + + This is safer than using `id(value)` as a dictionary key because Python can reuse + ids if the original object is garbage collected. + """ + + value: Any + + def __hash__(self): + return id(self.value) + + def __eq__(self, other): + return self.value is other.value + + # Maps Reference(tree) -> num_leaves. + leaf_counts: dict[Reference, int] = {} # pytype: disable=invalid-annotation + # Sentinel for tracking where pruning happens. + pruned = object() + + def num_leaves(tree: Any) -> int: + if Reference(tree) in leaf_counts: + return leaf_counts[Reference(tree)] + children = pytree_children(tree) + if tree is pruned: + count = 0 + elif not children: + # E.g., 0 for [], 1 for object() + count = jax.tree_util.tree_structure(tree).num_leaves + else: + # Key is guaranteed to exist because prune_tree guarantees that it + # calls `should_prune` on child nodes before their parent. + count = sum(leaf_counts[Reference(c)] for _, c in children) + leaf_counts[Reference(tree)] = count + return count + + # Pre-populate with number of leaves of `pruned` (0). + num_leaves(pruned) + + def should_prune_empty(path: str, tree: T) -> bool: + del path + return num_leaves(tree) == 0 + + tree = prune_tree(tree, should_prune=should_prune_empty, fill_value=pruned) + return jax.tree_util.tree_map(lambda x: x if x is not pruned else fill_value, tree) @dataclasses.dataclass @@ -1060,9 +1195,6 @@ def get_or_none(x: Optional[Dict], key: Any) -> Optional[Any]: return None if x is None else x.get(key) -T = TypeVar("T") - - def match_regex_rules( x: str, *, rules: Sequence[Tuple[str, T]], default_value: Optional[T] = None ) -> Optional[T]: @@ -1428,3 +1560,27 @@ def raise_for_cycles(tree: Any): f"Descendant KeyPath: {cycles['descendant']}.\n" f"Ancestor KeyPath: {cycles['ancestor']}." ) + + +def same_pruned_structure(tree_1: Any, tree_2: Any) -> bool: + """Returns whether the two input trees have the same structure after pruning empty subtrees. + + A subtree is considered empty if it has no pytree leaves. + + Example: + ``` + assert same_pruned_structure(dict(a=1), dict(a=2, dict(b=[]) ) + ``` + + Args: + tree_1: The first tree to compare. + tree_2: The second tree to compare. + + Returns: + Whether the trees are the same after pruning. + """ + + def pruned_structure(tree: Any): + return jax.tree_util.tree_structure(prune_empty(tree)) + + return pruned_structure(tree_1) == pruned_structure(tree_2) diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py index 2fa0f1bb..b44f0152 100644 --- a/axlearn/common/utils_test.py +++ b/axlearn/common/utils_test.py @@ -8,6 +8,8 @@ from collections import OrderedDict from typing import Any, Iterable, NamedTuple, Optional, Sequence, Type, Union +import chex + # pylint: disable=no-self-use import jax import jaxlib @@ -65,9 +67,11 @@ infer_mesh_shape, input_partition_spec, match_regex_rules, + prune_empty, prune_tree, pytree_children, runtime_checks, + same_pruned_structure, set_data_dir, set_recursively, split_prng_key, @@ -306,10 +310,8 @@ class TestUnstructured: ) # No children - self.assertSequenceEqual(pytree_children([]), []) - - # No children - self.assertSequenceEqual(pytree_children(3), []) + for val in None, 3, [], {}, (), set(), "", "asdf", True, False, object(): + self.assertSequenceEqual(pytree_children(val), [], msg=val) def test_find_cycles(self): x = {} @@ -1331,6 +1333,94 @@ def test(self): prune_tree(in_tree, lambda _, v: isinstance(v, int)), ) + @parameterized.product(tree=[{}, [], None, 3, dict(a=object()), object()]) + def test_root(self, tree: Any): + """Tests that root can be pruned.""" + self.assertIs(prune_tree(tree, should_prune=lambda path, v: True), None, msg=tree) + + def test_prune_empty(self): + """Test empty tree pruning with custom pytree nodes.""" + + # Dict-like node. (Uses DictKey). + # Use chex assert to automatically check type is correct. + # E.g., self.assertEqual(VDict(), dict()) passes but we don't want it to. + chex.assert_trees_all_equal(prune_empty(VDict(a=[], b=3)), VDict(b=3)) + + # Non-dict-like nodes. (Does not use DictKey.) + @struct.dataclass + class MyNodeDataclass: + a: Optional[int] + b: Optional[int] = None + + class MyNodeTuple(NamedTuple): + a: int + b: Optional[int] = None + + for cls in MyNodeTuple, MyNodeDataclass: + msg = str(cls) + # pytype: disable=wrong-arg-types + self.assertEqual(prune_empty(cls(a=cls(a=None))), None, msg=msg) + self.assertEqual(prune_empty(cls(a=[])), None, msg=msg) + self.assertEqual(prune_empty(cls(a=None)), None, msg=msg) + self.assertEqual(prune_empty(cls(a=[], b=3)), cls(a=None, b=3), msg=msg) + # pytype: enable=wrong-arg-types + + def test_fill_value(self): + """Tests the `fill_value` argument of `prune_empty()` and `prune_tree()`.""" + fill_value = object() + self.assertEqual(prune_empty(dict(a=None, b=2), fill_value=fill_value), dict(b=2)) + self.assertEqual(prune_empty(dict(a=fill_value), fill_value=fill_value), dict(a=fill_value)) + self.assertEqual( + prune_empty(Combo(head=3, tail=None), fill_value=fill_value), + Combo(head=3, tail=fill_value), + ) + + # Test behavior difference between prune_empty and prune_tree when fill_value makes tree + # non-empty during pruning. + should_prune_empty = lambda _, x: len(jax.tree_util.tree_leaves(x)) == 0 + + # The prune_tree case is the same as prune_empty because pruned subtrees with a DictKey key + # are always removed from teh parent regardless of fill_value. + self.assertEqual(prune_empty(dict(a=[]), fill_value=fill_value), fill_value) + self.assertEqual( + prune_tree(dict(a=[]), fill_value=fill_value, should_prune=should_prune_empty), + fill_value, + ) + # The prune_tree case is different from prune_empty because we don't (can't) remove keys + # from named tuples, so the fill_value stays in the tree and makes the should_prune + # call on the parent return false. + self.assertEqual( + prune_empty(Combo(head=None, tail=None), fill_value=fill_value), fill_value + ) + self.assertEqual( + prune_tree( + Combo(head=None, tail=None), fill_value=fill_value, should_prune=should_prune_empty + ), + Combo(head=fill_value, tail=fill_value), + ) + + def test_prune_root(self): + """Tests the `prune_root` argument of `prune_tree()`.""" + self.assertEqual( + prune_tree(dict(a=3, b=4), should_prune=lambda _, v: v != 3, prune_root=True), None + ) + self.assertEqual( + prune_tree(dict(a=3, b=4), should_prune=lambda _, v: v != 3, prune_root=False), + dict(a=3), + ) + + @parameterized.parameters( + [(), VDict(), True], + [dict(a=3), dict(a=4, b=dict(c=None)), True], + [dict(a=3), dict(a=4, b=dict(c=[])), True], + [VDict(a=3), VDict(a=4, b=VDict(c=None)), True], + [dict(a=3), dict(b=3), False], + [VDict(a=3), VDict(b=3), False], + [VDict(a=3), dict(a=3), False], + ) + def test_same_pruned_structure(self, tree_1: Any, tree_2: Any, expected: bool): + self.assertEqual(same_pruned_structure(tree_1, tree_2), expected) + @dataclasses.dataclass(frozen=True) class DummyDevice: