diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp index 780c7caee767c1..a2ff9c99469179 100644 --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -112,7 +112,8 @@ CallGraph::CallGraph(Operation *op) /// Get or add a call graph node for the given region. CallGraphNode *CallGraph::getOrAddNode(Region *region, CallGraphNode *parentNode) { - assert(region && isa(region->getParentOp()) && + Operation *parentOp = region->getParentOp(); + assert(region && isa(parentOp) && "expected parent operation to be callable"); std::unique_ptr &node = nodes[region]; if (!node) { @@ -122,13 +123,12 @@ CallGraphNode *CallGraph::getOrAddNode(Region *region, if (parentNode) { parentNode->addChildEdge(node.get()); } else { - // Otherwise, connect all callable nodes to the external node, this allows - // for conservatively including all callable nodes within the graph. - // FIXME This isn't correct, this is only necessary for callable nodes - // that *could* be called from external sources. This requires extending - // the interface for callables to check if they may be referenced - // externally. - externalCallerNode.addAbstractEdge(node.get()); + // Otherwise, connect all symbol nodes with public visibility + // to the external node, which is a set including callable nodes + // may be referenced externally. + if (isa(parentOp) && + cast(parentOp).isPublic()) + externalCallerNode.addAbstractEdge(node.get()); } } return node.get(); @@ -199,9 +199,8 @@ void CallGraph::print(raw_ostream &os) const { os << " : " << attrs; }; - for (auto &nodeIt : nodes) { - const CallGraphNode *node = nodeIt.second.get(); - + // Functor used to emit the given node and edges. + auto emitNodeAndEdge = [&](const CallGraphNode *node) { // Dump the header for this node. os << "// - Node : "; emitNodeName(node); @@ -214,13 +213,21 @@ void CallGraph::print(raw_ostream &os) const { os << "Call"; else if (edge.isChild()) os << "Child"; + else if (edge.isAbstract()) + os << "Abstract"; os << "-Edge : "; emitNodeName(edge.getTarget()); os << "\n"; } os << "//\n"; - } + }; + + // Emit all graph nodes including ExternalCallerNode and UnknownCalleeNode. + for (auto &nodeIt : nodes) + emitNodeAndEdge(nodeIt.second.get()); + emitNodeAndEdge(getExternalCallerNode()); + emitNodeAndEdge(getUnknownCalleeNode()); os << "// -- SCCs --\n"; diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index 8acfc96d2b611b..978bf7f5c0b70f 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -434,6 +434,9 @@ class Inliner::Impl { CGUseList &useList, CallGraphSCC ¤tSCC, MLIRContext *context); + void collectDeadNodeAfterInline(CallGraph &cg, CGUseList &useList, + InlinerInterfaceImpl &inlinerIface); + private: /// Optimize the nodes within the given SCC with one of the held optimization /// pass pipelines. Returns failure if an error occurred during the @@ -748,6 +751,27 @@ bool Inliner::Impl::shouldInline(ResolvedCall &resolvedCall) { return true; } +/// Iteratively clean up dead nodes until no change happened. +void Inliner::Impl::collectDeadNodeAfterInline( + CallGraph &cg, CGUseList &useList, InlinerInterfaceImpl &inlinerIface) { + auto eraseDeadNode = [&](void) { + bool changed = false; + for (CallGraphNode *node : cg) { + if (useList.isDead(node)) { + useList.eraseNode(node); + inlinerIface.markForDeletion(node); + changed = true; + } + } + return changed; + }; + + while (1) { + if (!eraseDeadNode()) + break; + } +} + LogicalResult Inliner::doInlining() { Impl impl(*this); auto *context = op->getContext(); @@ -765,6 +789,7 @@ LogicalResult Inliner::doInlining() { return result; // After inlining, make sure to erase any callables proven to be dead. + impl.collectDeadNodeAfterInline(cg, useList, inlinerIface); inlinerIface.eraseDeadCallables(); return success(); } diff --git a/mlir/test/Analysis/test-callgraph.mlir b/mlir/test/Analysis/test-callgraph.mlir index f6c9ff5006e053..8a00966bea61dd 100644 --- a/mlir/test/Analysis/test-callgraph.mlir +++ b/mlir/test/Analysis/test-callgraph.mlir @@ -8,24 +8,25 @@ module attributes {test.name = "simple"} { return } + // CHECK-NOT: Node{{.*}}func_b func.func private @func_b() - // CHECK: Node{{.*}}func_c + // CHECK: Node{{.*}}func_c{{.*}}private // CHECK-NEXT: Call-Edge{{.*}}Unknown-Callee-Node - func.func @func_c() { + func.func private @func_c() { call @func_b() : () -> () return } // CHECK: Node{{.*}}func_d - // CHECK-NEXT: Call-Edge{{.*}}func_c + // CHECK-NEXT: Call-Edge{{.*}}func_c{{.*}}private func.func @func_d() { call @func_c() : () -> () return } // CHECK: Node{{.*}}func_e - // CHECK-DAG: Call-Edge{{.*}}func_c + // CHECK-DAG: Call-Edge{{.*}}func_c{{.*}}private // CHECK-DAG: Call-Edge{{.*}}func_d // CHECK-DAG: Call-Edge{{.*}}func_e func.func @func_e() { @@ -49,6 +50,16 @@ module attributes {test.name = "simple"} { call_indirect %fn() : () -> () return } + + // CHECK: Node{{.*}}External-Caller-Node + // CHECK: Edge{{.*}}func_a + // CHECK-NOT: Edge{{.*}}func_b + // CHECK-NOT: Edge{{.*}}func_c + // CHECK: Edge{{.*}}func_d + // CHECK: Edge{{.*}}func_e + // CHECK: Edge{{.*}}func_f + + // CHECK: Node{{.*}}Unknown-Callee-Node } // ----- @@ -57,17 +68,23 @@ module attributes {test.name = "simple"} { module attributes {test.name = "nested"} { module @nested_module { // CHECK: Node{{.*}}func_a - func.func @func_a() { + func.func nested @func_a() { return } } // CHECK: Node{{.*}}func_b - // CHECK: Call-Edge{{.*}}func_a + // CHECK: Call-Edge{{.*}}func_a{{.*}}nested func.func @func_b() { "test.conversion_call_op"() { callee = @nested_module::@func_a } : () -> () return } + + // CHECK: Node{{.*}}External-Caller-Node + // CHECK: Edge{{.*}}func_b + // CHECK-NOT: Edge{{.*}}func_a + + // CHECK: Node{{.*}}Unknown-Callee-Node } // ----- diff --git a/mlir/test/Transforms/inlining-dce.mlir b/mlir/test/Transforms/inlining-dce.mlir index d167c1b4baae98..45b3ebc1e01772 100644 --- a/mlir/test/Transforms/inlining-dce.mlir +++ b/mlir/test/Transforms/inlining-dce.mlir @@ -10,7 +10,7 @@ func.func private @dead_function() { // Function becomes dead after inlining. // CHECK-NOT: func private @dead_function_b -func.func @dead_function_b() { +func.func private @dead_function_b() { return } @@ -44,6 +44,19 @@ func.func @live_function_c() { return } +// A transitive example, but no one be called by live-function. + +// CHECK-NOT: func private @dead_function_e +func.func private @dead_function_e() { + call @live_function_b() : () -> () + return +} +// CHECK-NOT: func private @dead_function_f +func.func private @dead_function_f() { + call @dead_function_e() : () -> () + return +} + // Function is referenced by non-callable top-level user. // CHECK: func private @live_function_d func.func private @live_function_d() {