diff --git a/mpi/examples/native_fortran/nf_context.f90 b/mpi/examples/native_fortran/nf_context.f90 index f496f686b0..a2a49ce086 100644 --- a/mpi/examples/native_fortran/nf_context.f90 +++ b/mpi/examples/native_fortran/nf_context.f90 @@ -76,6 +76,7 @@ end subroutine strtoptr C_CHAR_"qrm_ctx"//C_NULL_CHAR, & (/ FSTARPU_SCHED_CTX_POLICY_NAME, ptr, & c_null_ptr /)) + write(*, '("Created context: ",i1,"")')ctx allocate(a(mpi_size)) @@ -117,16 +118,10 @@ end subroutine strtoptr end if end do - ! if only wait for tasks in ctx I have a segfault - !call fstarpu_task_wait_for_all_in_ctx(ctx) - - ! if wait for all tasks (regardless of ctx) it works - call fstarpu_task_wait_for_all() - + ret = fstarpu_mpi_wait_for_all_in_ctx(mpi_comm, ctx) ret = fstarpu_mpi_barrier(mpi_comm) if(mpi_rank.eq.0) write(*,'("Yuppi, all the tasks in ctx",i1," ave finished!")')ctx - call fstarpu_codelet_free(task_cl ); task_cl = c_null_ptr call fstarpu_shutdown() ret = fstarpu_mpi_shutdown() diff --git a/mpi/include/fstarpu_mpi_mod.f90 b/mpi/include/fstarpu_mpi_mod.f90 index 1ea3393294..888338f78e 100644 --- a/mpi/include/fstarpu_mpi_mod.f90 +++ b/mpi/include/fstarpu_mpi_mod.f90 @@ -653,6 +653,15 @@ function fstarpu_mpi_wait_for_all (mpi_comm) bind(C) integer(c_int), value, intent(in) :: mpi_comm end function fstarpu_mpi_wait_for_all + ! int starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx); + function fstarpu_mpi_wait_for_all_in_ctx (mpi_comm, sched_ctx) bind(C) + use iso_c_binding + implicit none + integer(c_int) :: fstarpu_mpi_wait_for_all_in_ctx + integer(c_int), value, intent(in) :: mpi_comm + integer(c_int), value, intent(in) :: sched_ctx + end function fstarpu_mpi_wait_for_all_in_ctx + ! int starpu_mpi_datatype_register(starpu_data_handle_t handle, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func); function fstarpu_mpi_datatype_register(dh, alloc_func, free_func) bind(C,name="starpu_mpi_datatype_register") use iso_c_binding diff --git a/mpi/include/starpu_mpi.h b/mpi/include/starpu_mpi.h index 10f74cb68a..b7f44aaca3 100644 --- a/mpi/include/starpu_mpi.h +++ b/mpi/include/starpu_mpi.h @@ -327,6 +327,12 @@ int starpu_mpi_barrier(MPI_Comm comm); */ int starpu_mpi_wait_for_all(MPI_Comm comm); +/** + Wait until all StarPU tasks in the given context and communications + for the given communicator are completed +*/ +int starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx); + /** Post a standard-mode, non blocking send of \p data_handle to the node \p dest using the message tag \p data_tag within the diff --git a/mpi/src/mpi/starpu_mpi_mpi.c b/mpi/src/mpi/starpu_mpi_mpi.c index a1bab8eecd..fa5e69e4fa 100644 --- a/mpi/src/mpi/starpu_mpi_mpi.c +++ b/mpi/src/mpi/starpu_mpi_mpi.c @@ -819,7 +819,7 @@ int _starpu_mpi_barrier(MPI_Comm comm) return 0; } -int _starpu_mpi_wait_for_all(MPI_Comm comm) +int _starpu_mpi_wait_for_all__(MPI_Comm comm, unsigned sched_ctx) { (void) comm; _STARPU_MPI_LOG_IN(); @@ -839,7 +839,10 @@ int _starpu_mpi_wait_for_all(MPI_Comm comm) newer_requests = 0; STARPU_PTHREAD_MUTEX_UNLOCK(&progress_mutex); /* Now wait for all tasks */ - starpu_task_wait_for_all(); + if (sched_ctx == STARPU_NMAX_SCHED_CTXS+1) + starpu_task_wait_for_all(); + else + starpu_task_wait_for_all_in_ctx(sched_ctx); STARPU_PTHREAD_MUTEX_LOCK(&progress_mutex); /* Check newer_requests again, in case some MPI requests * triggered by tasks completed and triggered tasks between @@ -850,6 +853,17 @@ int _starpu_mpi_wait_for_all(MPI_Comm comm) return 0; } +int _starpu_mpi_wait_for_all(MPI_Comm comm) +{ + return _starpu_mpi_wait_for_all__(comm, STARPU_NMAX_SCHED_CTXS+1); +} + +int _starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx) +{ + return _starpu_mpi_wait_for_all__(comm, sched_ctx); +} + + /********************************************************/ /* */ /* Progression */ diff --git a/mpi/src/mpi/starpu_mpi_mpi.h b/mpi/src/mpi/starpu_mpi_mpi.h index faaf1b6b47..3c311f2a4f 100644 --- a/mpi/src/mpi/starpu_mpi_mpi.h +++ b/mpi/src/mpi/starpu_mpi_mpi.h @@ -41,6 +41,7 @@ void _starpu_mpi_wait_for_initialization(); int _starpu_mpi_barrier(MPI_Comm comm); int _starpu_mpi_wait_for_all(MPI_Comm comm); +int _starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx); int _starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status); int _starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status); diff --git a/mpi/src/mpi/starpu_mpi_mpi_backend.c b/mpi/src/mpi/starpu_mpi_mpi_backend.c index 5f3291105d..89f6239185 100644 --- a/mpi/src/mpi/starpu_mpi_mpi_backend.c +++ b/mpi/src/mpi/starpu_mpi_mpi_backend.c @@ -130,6 +130,7 @@ struct _starpu_mpi_backend _mpi_backend = ._starpu_mpi_backend_barrier = _starpu_mpi_barrier, ._starpu_mpi_backend_wait_for_all = _starpu_mpi_wait_for_all, + ._starpu_mpi_backend_wait_for_all_in_ctx = _starpu_mpi_wait_for_all_in_ctx, ._starpu_mpi_backend_wait = _starpu_mpi_wait, ._starpu_mpi_backend_test = _starpu_mpi_test, diff --git a/mpi/src/nmad/starpu_mpi_nmad.c b/mpi/src/nmad/starpu_mpi_nmad.c index 9d23f4ba5f..5c10dd7e3c 100644 --- a/mpi/src/nmad/starpu_mpi_nmad.c +++ b/mpi/src/nmad/starpu_mpi_nmad.c @@ -294,7 +294,7 @@ int _starpu_mpi_barrier(MPI_Comm comm) return ret; } -int _starpu_mpi_wait_for_all(MPI_Comm comm) +int _starpu_mpi_wait_for_all__(MPI_Comm comm, unsigned schd_ctx) { (void) comm; _STARPU_MPI_LOG_IN(); @@ -308,7 +308,10 @@ int _starpu_mpi_wait_for_all(MPI_Comm comm) STARPU_PTHREAD_COND_WAIT(&mpi_wait_for_all_running_cond, &mpi_wait_for_all_running_mutex); STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_wait_for_all_running_mutex); - starpu_task_wait_for_all(); + if (sched_ctx == STARPU_NMAX_SCHED_CTXS+1) + starpu_task_wait_for_all(); + else + starpu_task_wait_for_all_in_ctx(sched_ctx); STARPU_PTHREAD_MUTEX_LOCK(&mpi_wait_for_all_running_mutex); } while (nb_pending_requests); @@ -319,6 +322,16 @@ int _starpu_mpi_wait_for_all(MPI_Comm comm) return 0; } +int _starpu_mpi_wait_for_all(MPI_Comm comm) +{ + return _starpu_mpi_wait_for_all__(comm, STARPU_NMAX_SCHED_CTXS+1); +} + +int _starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx) +{ + return _starpu_mpi_wait_for_all__(comm, sched_ctx); +} + /********************************************************/ /* */ /* Progression */ diff --git a/mpi/src/nmad/starpu_mpi_nmad.h b/mpi/src/nmad/starpu_mpi_nmad.h index 8bfb5046ca..e2967c9cff 100644 --- a/mpi/src/nmad/starpu_mpi_nmad.h +++ b/mpi/src/nmad/starpu_mpi_nmad.h @@ -41,6 +41,7 @@ void _starpu_mpi_progress_shutdown(void **value); int _starpu_mpi_barrier(MPI_Comm comm); int _starpu_mpi_wait_for_all(MPI_Comm comm); +int _starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx); int _starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status); int _starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status); diff --git a/mpi/src/nmad/starpu_mpi_nmad_backend.c b/mpi/src/nmad/starpu_mpi_nmad_backend.c index 93b785f015..0181c66801 100644 --- a/mpi/src/nmad/starpu_mpi_nmad_backend.c +++ b/mpi/src/nmad/starpu_mpi_nmad_backend.c @@ -109,6 +109,7 @@ struct _starpu_mpi_backend _mpi_backend = ._starpu_mpi_backend_barrier = _starpu_mpi_barrier, ._starpu_mpi_backend_wait_for_all = _starpu_mpi_wait_for_all, + ._starpu_mpi_backend_wait_for_all = _starpu_mpi_wait_for_all_in_ctx, ._starpu_mpi_backend_wait = _starpu_mpi_wait, ._starpu_mpi_backend_test = _starpu_mpi_test, diff --git a/mpi/src/starpu_mpi.c b/mpi/src/starpu_mpi.c index 36fff8b9be..4fc1e5cf2a 100644 --- a/mpi/src/starpu_mpi.c +++ b/mpi/src/starpu_mpi.c @@ -685,6 +685,14 @@ int starpu_mpi_wait_for_all(MPI_Comm comm) return _mpi_backend._starpu_mpi_backend_wait_for_all(comm); } +int starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx) +{ + /* If the user forgets to call mpi_redux_data or insert R tasks on the reduced handles */ + /* then, we wrap reduction patterns for them. This is typical of benchmarks */ + _starpu_mpi_redux_wrapup_data_all(); + return _mpi_backend._starpu_mpi_backend_wait_for_all_in_ctx(comm, sched_ctx); +} + void starpu_mpi_comm_stats_disable() { _starpu_mpi_comm_stats_disable(); diff --git a/mpi/src/starpu_mpi_fortran.c b/mpi/src/starpu_mpi_fortran.c index 9a316df55f..5e3227a7e9 100644 --- a/mpi/src/starpu_mpi_fortran.c +++ b/mpi/src/starpu_mpi_fortran.c @@ -323,4 +323,9 @@ int fstarpu_mpi_wait_for_all(MPI_Fint comm) { return starpu_mpi_wait_for_all(MPI_Comm_f2c(comm)); } + +int fstarpu_mpi_wait_for_all_in_ctx(MPI_Fint comm, int sched_ctx) +{ + return starpu_mpi_wait_for_all_in_ctx(MPI_Comm_f2c(comm), (unsigned)sched_ctx); +} #endif diff --git a/mpi/src/starpu_mpi_private.h b/mpi/src/starpu_mpi_private.h index 462be17c60..f959057fb7 100644 --- a/mpi/src/starpu_mpi_private.h +++ b/mpi/src/starpu_mpi_private.h @@ -1,6 +1,6 @@ /* StarPU --- Runtime system for heterogeneous multicore architectures. * - * Copyright (C) 2010-2023 University of Bordeaux, CNRS (LaBRI UMR 5800), Inria + * Copyright (C) 2010-2024 University of Bordeaux, CNRS (LaBRI UMR 5800), Inria * * StarPU is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as published by @@ -389,6 +389,7 @@ struct _starpu_mpi_backend int (*_starpu_mpi_backend_barrier)(MPI_Comm comm); int (*_starpu_mpi_backend_wait_for_all)(MPI_Comm comm); + int (*_starpu_mpi_backend_wait_for_all_in_ctx)(MPI_Comm comm, unsigned sched_ctx); int (*_starpu_mpi_backend_wait)(starpu_mpi_req *public_req, MPI_Status *status); int (*_starpu_mpi_backend_test)(starpu_mpi_req *public_req, int *flag, MPI_Status *status);