Skip to content

Commit

Permalink
Make trace dispatch purely a function of context rather than a functi…
Browse files Browse the repository at this point in the history
…on of both context and data. This lets us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind and so on.

PiperOrigin-RevId: 681582933
  • Loading branch information
dougalm authored and copybara-github committed Oct 29, 2024
1 parent e153d50 commit a536c89
Showing 1 changed file with 34 additions and 29 deletions.
63 changes: 34 additions & 29 deletions haiku/_src/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ def method_hook(mod: module.Module, method_name: str):
graph_stack.peek().subgraphs.append(subg.evolve(title=title))

with graph_stack(graph), \
module.hook_methods(method_hook), \
jax.core.new_main(DotTrace) as main:
out_flat = _interpret_subtrace(flat_fun, main).call_wrapped(*args_flat)
module.hook_methods(method_hook):
tag = jax.core.TraceTag()
out_flat = _interpret_subtrace(flat_fun, tag).call_wrapped(*args_flat)
out = jax.tree.unflatten(out_tree(), out_flat)

return graph, args, out
Expand All @@ -163,20 +163,20 @@ def method_hook(mod: module.Module, method_name: str):


@lu.transformation
def _interpret_subtrace(main, *in_vals):
trace = DotTrace(main, jax.core.cur_sublevel())
in_tracers = [DotTracer(trace, val) for val in in_vals]
outs = yield in_tracers, {}
out_tracers = map(trace.full_raise, outs)
out_vals = [t.val for t in out_tracers]
yield out_vals
def _interpret_subtrace(tag, *in_vals):
with jax.core.take_current_trace() as parent_trace:
trace = DotTrace(parent_trace, tag)
with jax.core.set_current_trace(trace):
in_tracers = [DotTracer(trace, val) for val in in_vals]
outs = yield in_tracers, {}
yield [trace.to_val(t) for t in outs]


class DotTracer(jax.core.Tracer):
"""JAX tracer used in DotTrace."""

def __init__(self, trace, val):
super().__init__(trace)
self._trace = trace
self.val = val

@property
Expand All @@ -190,61 +190,66 @@ def full_lower(self):
class DotTrace(jax.core.Trace):
"""Traces a JAX function to dot."""

def pure(self, val):
return DotTracer(self, val)
def __init__(self, parent_trace, tag):
self.parent_trace = parent_trace
self.tag = tag

def lift(self, val):
return DotTracer(self, val)

def sublift(self, val):
return DotTracer(self, val.val)
def to_val(self, val):
if isinstance(val, DotTracer) and val._trace.tag is self.tag: # pylint:disable=protected-access
return val.val
else:
return val

def process_primitive(self, primitive, tracers, params):
val_out = primitive.bind(*[t.val for t in tracers], **params)
vals = [self.to_val(t) for t in tracers]
val_out = primitive.bind_with_trace(self.parent_trace, vals, params)
if primitive is pjit.pjit_p:
f = jax.core.jaxpr_as_fun(params['jaxpr'])
f.__name__ = params['name']
fun = lu.wrap_init(f)
return self.process_call(primitive, fun, tracers, params)

inputs = [t.val for t in tracers]
outputs = list(jax.tree.leaves(val_out))

graph = graph_stack.peek()
node = Node(id=outputs[0], title=str(primitive), outputs=outputs)
graph.nodes.append(node)
graph.edges.extend([(i, outputs[0]) for i in inputs])
graph.edges.extend([(i, outputs[0]) for i in vals])

return jax.tree.map(lambda v: DotTracer(self, v), val_out)

def process_call(self, call_primitive, f, tracers, params):
assert call_primitive.multiple_results
if (call_primitive in (pjit.pjit_p,) and
params.get('inline', False)):
f = _interpret_subtrace(f, self.main)
vals_out = f.call_wrapped(*[t.val for t in tracers])
return [DotTracer(self, v) for v in vals_out]
f = _interpret_subtrace(f, self.tag)
with jax.core.set_current_trace(self.parent_trace):
vals_out = f.call_wrapped(*[self.to_val(t) for t in tracers])
return [DotTracer(self, v) for v in vals_out]

graph = Graph.create(title=f'{call_primitive} ({name_or_str(f.f)})')
graph_stack.peek().subgraphs.append(graph)
with graph_stack(graph):
f = _interpret_subtrace(f, self.main)
vals_out = f.call_wrapped(*[t.val for t in tracers])
return [DotTracer(self, v) for v in vals_out]
f = _interpret_subtrace(f, self.tag)
with jax.core.set_current_trace(self.parent_trace):
vals_out = f.call_wrapped(*[self.to_val(t) for t in tracers])
return [DotTracer(self, v) for v in vals_out]

process_map = process_call

def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *,
symbolic_zeros):
# Drop the custom differentiation rule.
del primitive, jvp, symbolic_zeros # Unused.
return fun.call_wrapped(*tracers)
with jax.core.set_current_trace(self.parent_trace):
return fun.call_wrapped(*tracers)

def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers,
out_trees, symbolic_zeros):
# Drop the custom differentiation rule.
del primitive, fwd, bwd, out_trees, symbolic_zeros # Unused.
return fun.call_wrapped(*tracers)
with jax.core.set_current_trace(self.parent_trace):
return fun.call_wrapped(*tracers)


def _format_val(val):
Expand Down

0 comments on commit a536c89

Please sign in to comment.