diff --git a/flang/include/flang/Lower/StatementContext.h b/flang/include/flang/Lower/StatementContext.h index 7776edc93ed737..eef21d4bae5aab 100644 --- a/flang/include/flang/Lower/StatementContext.h +++ b/flang/include/flang/Lower/StatementContext.h @@ -92,10 +92,13 @@ class StatementContext { cufs.back().reset(); } + /// Pop the stack top list. + void pop() { cufs.pop_back(); } + /// Make cleanup calls. Pop the stack top list. void finalizeAndPop() { finalizeAndKeep(); - cufs.pop_back(); + pop(); } bool hasCode() const { diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index da53edf7e734b0..7f41742bf5e8b2 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -1621,13 +1621,19 @@ class FirConverter : public Fortran::lower::AbstractConverter { // Termination of symbolically referenced execution units //===--------------------------------------------------------------------===// - /// END of program + /// Exit of a routine /// - /// Generate the cleanup block before the program exits - void genExitRoutine() { - - if (blockIsUnterminated()) - builder->create(toLocation()); + /// Generate the cleanup block before the routine exits + void genExitRoutine(bool earlyReturn, mlir::ValueRange retval = {}) { + if (blockIsUnterminated()) { + bridge.openAccCtx().finalizeAndKeep(); + bridge.fctCtx().finalizeAndKeep(); + builder->create(toLocation(), retval); + } + if (!earlyReturn) { + bridge.openAccCtx().pop(); + bridge.fctCtx().pop(); + } } /// END of procedure-like constructs @@ -1684,9 +1690,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { resultRef = builder->createConvert(loc, resultRefType, resultRef); return builder->create(loc, resultRef); }); - bridge.openAccCtx().finalizeAndPop(); - bridge.fctCtx().finalizeAndPop(); - builder->create(loc, resultVal); + genExitRoutine(false, resultVal); } /// Get the return value of a call to \p symbol, which is a subroutine entry @@ -1712,13 +1716,9 @@ class FirConverter : public Fortran::lower::AbstractConverter { } else if (Fortran::semantics::HasAlternateReturns(symbol)) { mlir::Value retval = builder->create( toLocation(), getAltReturnResult(symbol)); - bridge.openAccCtx().finalizeAndPop(); - bridge.fctCtx().finalizeAndPop(); - builder->create(toLocation(), retval); + genExitRoutine(false, retval); } else { - bridge.openAccCtx().finalizeAndPop(); - bridge.fctCtx().finalizeAndPop(); - genExitRoutine(); + genExitRoutine(false); } } @@ -5018,8 +5018,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { it->stmtCtx.finalizeAndKeep(); } if (funit->isMainProgram()) { - bridge.fctCtx().finalizeAndKeep(); - genExitRoutine(); + genExitRoutine(true); return; } mlir::Location loc = toLocation(); @@ -5478,9 +5477,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { void endNewFunction(Fortran::lower::pft::FunctionLikeUnit &funit) { setCurrentPosition(Fortran::lower::pft::stmtSourceLoc(funit.endStmt)); if (funit.isMainProgram()) { - bridge.openAccCtx().finalizeAndPop(); - bridge.fctCtx().finalizeAndPop(); - genExitRoutine(); + genExitRoutine(false); } else { genFIRProcedureExit(funit, funit.getSubprogramSymbol()); } diff --git a/flang/test/Lower/CUDA/cuda-return01.cuf b/flang/test/Lower/CUDA/cuda-return01.cuf new file mode 100644 index 00000000000000..c9f9a8b57ef041 --- /dev/null +++ b/flang/test/Lower/CUDA/cuda-return01.cuf @@ -0,0 +1,14 @@ +! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s + +! Check if finalization works with a return statement + +program main + integer, device :: a(10) + return +end + +! CHECK: func.func @_QQmain() attributes {fir.bindc_name = "main"} { +! CHECK: %[[DECL:.*]]:2 = hlfir.declare +! CHECK-NEXT: cuf.free %[[DECL]]#1 : !fir.ref> +! CHECK-NEXT: return +! CHECK-NEXT: } diff --git a/flang/test/Lower/CUDA/cuda-return02.cuf b/flang/test/Lower/CUDA/cuda-return02.cuf new file mode 100644 index 00000000000000..5d01f0a24b420b --- /dev/null +++ b/flang/test/Lower/CUDA/cuda-return02.cuf @@ -0,0 +1,48 @@ +! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s + +! Check if finalization works with multiple return statements + +program test + integer, device :: a(10) + logical :: l + + if (l) then + return + end if + + return +end + +! CHECK: func.func @_QQmain() attributes {fir.bindc_name = "test"} { +! CHECK: %[[DECL:.*]]:2 = hlfir.declare +! CHECK: cf.cond_br %{{.*}}, ^bb1, ^bb2 +! CHECK-NEXT: ^bb1: +! CHECK-NEXT: cuf.free %[[DECL]]#1 : !fir.ref> +! CHECK-NEXT: return +! CHECK-NEXT: ^bb2: +! CHECK-NEXT: cuf.free %[[DECL]]#1 : !fir.ref> +! CHECK-NEXT: return +! CHECK-NEXT: } + +subroutine sub(l) + integer, device :: a(10) + logical :: l + + if (l) then + l = .false. + return + end if + + return +end + +! CHECK: func.func @_QPsub(%arg0: !fir.ref> {fir.bindc_name = "l"}) { +! CHECK: %[[DECL:.*]]:2 = hlfir.declare +! CHECK: cf.cond_br %6, ^bb1, ^bb2 +! CHECK: ^bb1: +! CHECK: cf.br ^bb3 +! CHECK: ^bb2: +! CHECK: cf.br ^bb3 +! CHECK: ^bb3: +! CHECK: cuf.free %[[DECL]]#1 : !fir.ref> +! CHECK: }