Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove presere ops #6360

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/pytorch.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
d1b87e26e5c4343f5b56bb1e6f89b479b389bfac
export-D64151426
16 changes: 10 additions & 6 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,9 +925,12 @@ def _gen_edge_manager_for_partitioners(
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
all_ops_no_decomp |= set(curr_ops_no_decomp)

program = program.run_decompositions(
_default_decomposition_table(), _preserve_ops=tuple(all_ops_no_decomp)
)
table = _default_decomposition_table()

for op in all_ops_no_decomp:
table.pop(op, None)

program = program.run_decompositions(table)
# Among all the preserved aten ops, use the check_op_fn to do an additional
# check on which ops need to be preserved and which ops need to be decomposed
# Those which are truly preserved will be replaced with transformed ops
Expand Down Expand Up @@ -1097,9 +1100,10 @@ def to_edge_with_preserved_ops(

for name, program in aten_programs.items():
# Decompose to Core ATen
program = program.run_decompositions(
_default_decomposition_table(), _preserve_ops=preserve_ops
)
table = _default_decomposition_table()
for op in preserve_ops:
table.pop(op, None)
program = program.run_decompositions(table)
edge_programs[name] = _generate_edge_program(
name, config, program, list(preserve_ops)
)
Expand Down
8 changes: 4 additions & 4 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,10 +573,10 @@ def get_num_nondecomposed_ops(self, ep, partitioner):
# which pass the filter_ops fn given by the partitioner
reference_ep = copy.deepcopy(ep)
aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep)
reference_decomp_ep = reference_ep.run_decompositions(
decomp_table=_default_decomposition_table(),
_preserve_ops=tuple(aten_ops_not_decomposed),
)
table = _default_decomposition_table()
for op in aten_ops_not_decomposed:
table.pop(op, None)
reference_decomp_ep = reference_ep.run_decompositions(decomp_table=table)
num_non_decomposed_aten_ops = 0
for node in reference_decomp_ep.graph.nodes:
if (
Expand Down
5 changes: 3 additions & 2 deletions exir/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@
from executorch.exir.types import ValueSpec

from torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass # @manual
from torch._decomp import core_aten_decompositions, get_decompositions
from torch._decomp import get_decompositions
from torch._dynamo.guards import Guard
from torch._functorch.eager_transforms import _maybe_unwrap_functional_tensor
from torch.export import default_decompositions
from torch.func import functionalize
from torch.fx.operator_schemas import normalize_function
from torch.utils._pytree import TreeSpec
Expand Down Expand Up @@ -631,7 +632,7 @@ def _default_decomposition_table(
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
return get_decompositions(decomp_opset)
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
return core_aten_decompositions()
return default_decompositions()


def dynamo_trace(
Expand Down
Loading