Skip to content

Commit

Permalink
Cleanup of APLP code (#271)
Browse files Browse the repository at this point in the history
This PR address comments from the last commit
regarding reducing the overhead of instantiating a pool of works at
every iteration.

Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod authored Nov 14, 2024
1 parent 327b84b commit d8ce8d2
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
11 changes: 3 additions & 8 deletions iree/turbine/kernel/wave/scheduling/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sympy
import math
from functools import partial
from ..utils import safe_subs
import multiprocessing as mp

T = index_symbol("$INITIATION_INTERVAL")
Expand Down Expand Up @@ -202,9 +203,7 @@ def all_pairs_longest_paths_symbolic(


def all_pairs_longest_paths(
graph: fx.Graph,
edges: list[Edge],
T: int,
graph: fx.Graph, edges: list[Edge], T: int, pool: mp.Pool
) -> dict[tuple[fx.Node, fx.Node], IndexExpr]:
"""
For each node in the graph, compute the longest path to all other nodes.
Expand All @@ -228,14 +227,11 @@ def all_pairs_longest_paths(
D[i, j] = edge.weight.delay - edge.weight.iteration_difference * T

# Parallel implementation
pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
for k in range(N):
func = partial(all_pairs_longest_path_parallel, N, D, k)
results = pool.map(func, range(N))
for result in results:
D[result[0]] = result[1]
pool.close()
pool.join()

# Convert from index to node based representation.
G: dict[tuple[fx.Node, fx.Node], int] = {}
Expand All @@ -257,8 +253,7 @@ def evaluate_all_pairs_longest_paths(
"""
D_static = dict(D)
for key in D_static:
if isinstance(D_static[key], sympy.Expr):
D_static[key] = D_static[key].subs(T, initiation_interval)
D_static[key] = safe_subs(D_static[key], [(T, initiation_interval)])
# Remove the negative infinity values and edges to self.
for k in list(D_static.keys()):
if math.isinf(D_static[k]) or k[0] == k[1]:
Expand Down
6 changes: 5 additions & 1 deletion iree/turbine/kernel/wave/scheduling/modulo_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Callable
import numpy as np
import math
import multiprocessing as mp

logger = get_logger("turbine.wave.modulo_scheduling")

Expand Down Expand Up @@ -110,10 +111,11 @@ def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]:
# TODO: Come up with a better heuristic on an upper bound for the initiation interval.
T_max_range = 3 * T0
success = False
pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
for T in range(T0, T0 + T_max_range):
logger.debug(f"Trying initiation interval: {T}.")
self.RT = np.zeros((T, len(self.resources)))
self.e_star = all_pairs_longest_paths(self.graph, self.edges, T)
self.e_star = all_pairs_longest_paths(self.graph, self.edges, T, pool)
logger.debug(f"All Pairs Longest Paths: {self.e_star}.")
self.schedule: dict[fx.Node, int] = {}
for _, scc in topological_sort(sccs).items():
Expand Down Expand Up @@ -148,6 +150,8 @@ def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]:
break
else:
raise Exception("Failed to schedule the graph.")
pool.close()
pool.join()

self._initiation_interval = T
return self.schedule, success
Expand Down
4 changes: 3 additions & 1 deletion tests/kernel/wave/scheduling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
import torch.fx as fx
import numpy as np
import multiprocessing as mp
from iree.turbine.kernel.wave.visualization import visualize_graph
from iree.turbine.kernel.wave.scheduling.graph_utils import (
find_strongly_connected_components,
Expand Down Expand Up @@ -180,7 +181,8 @@ def testGraphUtils(self):
def testAPLP(self):
graph, weighted_edges, nodes = self.create_weighted_graph()
T = 4
D3 = all_pairs_longest_paths(graph, weighted_edges, T)
pool = mp.get_context("fork").Pool(processes=mp.cpu_count())
D3 = all_pairs_longest_paths(graph, weighted_edges, T, pool)
assert D3[(nodes["a"], nodes["b"])] == 2
assert D3[(nodes["a"], nodes["c"])] == 3
assert D3[(nodes["a"], nodes["d"])] == 4
Expand Down

0 comments on commit d8ce8d2

Please sign in to comment.