Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a universal function to replace subgraph in OptimizationRule #3353

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mars/core/entity/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def __init__(self, *args, **kwargs):
def op(self):
return self._op

@property
def outputs(self):
return self._op.outputs

@property
def inputs(self):
return self.op.inputs
Expand Down
78 changes: 77 additions & 1 deletion mars/optimization/logical/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Type, Set

from ...core import OperandType, EntityType, enter_mode
from ...core import OperandType, EntityType, enter_mode, Entity
from ...core.graph import EntityGraph
from ...utils import implements

Expand Down Expand Up @@ -130,6 +131,77 @@ def _replace_node(self, original_node: EntityType, new_node: EntityType):
for succ in successors:
self._graph.add_edge(new_node, succ)

def _replace_subgraph(
self,
graph: Optional[EntityGraph],
removed_nodes: Optional[Set[EntityType]],
ericpai marked this conversation as resolved.
Show resolved Hide resolved
new_results: Optional[List[Entity]] = None,
):
"""
Replace the subgraph from the self._graph represented by a list of nodes with input graph.
It will delete the nodes in removed_nodes with all linked edges first, and then add (or update if it's still
existed in self._graph) the nodes and edges of the input graph.

Parameters
----------
graph : EntityGraph, optional
The input graph. If it's none, no new node and edge will be added.
removed_nodes : Set[EntityType], optional
The nodes to be removed. All the edges connected with them are removed as well.
new_results : List[EntityType], optional, default None
The updated results of the graph. If it's None, then the results will not be updated.

Raises
------
ReplaceSubgraphError
If the input key of the removed node's successor can't be found in the subgraph.
Or some of the nodes of the subgraph are in removed ones.
"""
infected_successors = set()

output_to_node = dict()
removed_nodes = removed_nodes or set()
if graph is not None:
# Add the output key -> node of the subgraph
for node in graph.iter_nodes():
if node in removed_nodes:
raise ReplaceSubgraphError(f"The node {node} is in the removed set")
ericpai marked this conversation as resolved.
Show resolved Hide resolved
for output in node.outputs:
output_to_node[output.key] = node

for node in removed_nodes:
for infected_successor in self._graph.iter_successors(node):
ericpai marked this conversation as resolved.
Show resolved Hide resolved
if infected_successor not in removed_nodes:
infected_successors.add(infected_successor)
# Check whether infected successors' inputs are in subgraph
for infected_successor in infected_successors:
for inp in infected_successor.inputs:
if inp.key not in output_to_node:
raise ReplaceSubgraphError(
f"The output {inp} of node {infected_successor} is missing in the subgraph"
)
for node in removed_nodes:
self._graph.remove_node(node)

if graph is None:
return

# Add the output key -> node of the original graph
for node in self._graph.iter_nodes():
for output in node.outputs:
output_to_node[output.key] = node

for node in graph.iter_nodes():
self._graph.add_node(node)

for node in itertools.chain(graph.iter_nodes(), infected_successors):
for inp in node.inputs:
pred_node = output_to_node[inp.key]
self._graph.add_edge(pred_node, node)

if new_results is not None:
self._graph.results = new_results.copy()
ericpai marked this conversation as resolved.
Show resolved Hide resolved

def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
pred_original = self._records.get_original_entity(predecessor, predecessor)
if predecessor not in self._preds_to_remove:
Expand Down Expand Up @@ -283,3 +355,7 @@ def optimize(cls, graph: EntityGraph) -> OptimizationRecords:
graph.results = new_results

return records


class ReplaceSubgraphError(Exception):
pass
13 changes: 13 additions & 0 deletions mars/optimization/logical/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
222 changes: 222 additions & 0 deletions mars/optimization/logical/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import pytest


from ..core import OptimizationRule, ReplaceSubgraphError
from .... import tensor as mt
from .... import dataframe as md


class _MockRule(OptimizationRule):
def apply(self) -> bool:
pass

def replace_subgraph(self, graph, removed_nodes, new_results=None):
self._replace_subgraph(graph, removed_nodes, new_results)


def test_replace_tileable_subgraph():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v4 ----> v6(output) <--- v5 <--- c5 <--- s5
| ^
| |
V |
v3 ------|
^
|
s2 ---> c2 ---> v2

Target Graph:
s1 ---> c1 ---> v1 ---> v7 ----> v8(output) <--- v5 <--- c5 <--- s5
^
|
s2 ---> c2 ---> v2

The nodes [v3, v4, v6] will be removed.
Subgraph only contains [v7, v8]
"""
s1 = mt.random.randint(0, 100, size=(5, 4))
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
s2 = mt.random.randint(0, 100, size=(5, 4))
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
v3 = v1.add(v2)
v4 = v3.add(v1)
s5 = mt.random.randint(0, 100, size=(5, 4))
v5 = md.DataFrame(s5, columns=list("ABCD"), chunk_size=4)
v6 = v5.sub(v4)
g1 = v6.build_graph()
v7 = v1.sub(v2)
v8 = v7.add(v5)
g2 = v8.build_graph()

# Here we use a trick way to construct the subgraph for test only
key_to_node = dict()
for node in g2.iter_nodes():
key_to_node[node.key] = node
for key, node in key_to_node.items():
if key != v7.key and key != v8.key:
g2.remove_node(node)
r = _MockRule(g1, None, None)
for node in g1.iter_nodes():
key_to_node[node.key] = node

c1 = g1.successors(key_to_node[s1.key])[0]
c2 = g1.successors(key_to_node[s2.key])[0]
c5 = g1.successors(key_to_node[s5.key])[0]

expected_results = [v8.outputs[0]]
r.replace_subgraph(
g2, {key_to_node[op.key] for op in [v3, v4, v6]}, expected_results
)
assert g1.results == expected_results

expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8}
assert set(g1) == {key_to_node[n.key] for n in expected_nodes}

expected_edges = {
s1: [c1],
c1: [v1],
v1: [v7],
s2: [c2],
c2: [v2],
v2: [v7],
s5: [c5],
c5: [v5],
v5: [v8],
v7: [v8],
v8: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])


def test_replace_null_subgraph():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v3 <--- v2 <--- c2 <--- s2

Target Graph:
c1 ---> v1 ---> v3 <--- v2 <--- c2

The nodes [s1, s2] will be removed.
Subgraph is None
"""
s1 = mt.random.randint(0, 100, size=(10, 4))
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
s2 = mt.random.randint(0, 100, size=(10, 4))
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
v3 = v1.add(v2)
g1 = v3.build_graph()
key_to_node = {node.key: node for node in g1.iter_nodes()}
c1 = g1.successors(key_to_node[s1.key])[0]
c2 = g1.successors(key_to_node[s2.key])[0]
r = _MockRule(g1, None, None)
expected_results = [v3.outputs[0]]
# delete c5 s5 will fail
with pytest.raises(ReplaceSubgraphError) as e:
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
assert g1.results == expected_results
assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}}
expected_edges = {
s1: [c1],
c1: [v1],
v1: [v3],
s2: [c2],
c2: [v2],
v2: [v3],
v3: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])

c1.inputs.clear()
c2.inputs.clear()
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
assert g1.results == expected_results
assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
expected_edges = {
c1: [v1],
v1: [v3],
c2: [v2],
v2: [v3],
v3: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])


def test_replace_subgraph_without_removing_nodes():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2

Target Graph:
s1 ---> c1 ---> v1 ---> v4 <--- v2 <--- c2 <--- s2
s3 ---> c3 ---> v3

Nothing will be removed.
Subgraph only contains [s3, c3, v3]
"""
s1 = mt.random.randint(0, 100, size=(10, 4))
v1 = md.DataFrame(s1, columns=list("ABCD"), chunk_size=5)
s2 = mt.random.randint(0, 100, size=(10, 4))
v2 = md.DataFrame(s2, columns=list("ABCD"), chunk_size=5)
v4 = v1.add(v2)
g1 = v4.build_graph()

s3 = mt.random.randint(0, 100, size=(10, 4))
v3 = md.DataFrame(s3, columns=list("ABCD"), chunk_size=5)
g2 = v3.build_graph()
key_to_node = {
node.key: node for node in itertools.chain(g1.iter_nodes(), g2.iter_nodes())
}
expected_results = [v3.outputs[0], v4.outputs[0]]
c1 = g1.successors(key_to_node[s1.key])[0]
c2 = g1.successors(key_to_node[s2.key])[0]
c3 = g2.successors(key_to_node[s3.key])[0]
r = _MockRule(g1, None, None)
r.replace_subgraph(g2, None, expected_results)
assert g1.results == expected_results
assert set(g1) == {
key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4}
}
expected_edges = {
s1: [c1],
c1: [v1],
v1: [v4],
s2: [c2],
c2: [v2],
v2: [v4],
s3: [c3],
c3: [v3],
v3: [],
v4: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])