Skip to content

Commit

Permalink
Add new function starpu_mpi_wait_for_all_in_ctx()
Browse files Browse the repository at this point in the history
  • Loading branch information
nfurmento committed Jul 16, 2024
1 parent bff91d4 commit c5330b8
Show file tree
Hide file tree
Showing 12 changed files with 67 additions and 12 deletions.
9 changes: 2 additions & 7 deletions mpi/examples/native_fortran/nf_context.f90
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ end subroutine strtoptr
C_CHAR_"qrm_ctx"//C_NULL_CHAR, &
(/ FSTARPU_SCHED_CTX_POLICY_NAME, ptr, &
c_null_ptr /))
write(*, '("Created context: ",i1,"")')ctx

allocate(a(mpi_size))

Expand Down Expand Up @@ -117,16 +118,10 @@ end subroutine strtoptr
end if
end do

! if only wait for tasks in ctx I have a segfault
!call fstarpu_task_wait_for_all_in_ctx(ctx)

! if wait for all tasks (regardless of ctx) it works
call fstarpu_task_wait_for_all()

ret = fstarpu_mpi_wait_for_all_in_ctx(mpi_comm, ctx)
ret = fstarpu_mpi_barrier(mpi_comm)
! Only rank 0 reports completion; i0 prints the context id at its natural width
if(mpi_rank.eq.0) write(*,'("Yuppi, all the tasks in ctx",i0," have finished!")')ctx


call fstarpu_codelet_free(task_cl ); task_cl = c_null_ptr
call fstarpu_shutdown()
ret = fstarpu_mpi_shutdown()
Expand Down
9 changes: 9 additions & 0 deletions mpi/include/fstarpu_mpi_mod.f90
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,15 @@ function fstarpu_mpi_wait_for_all (mpi_comm) bind(C)
integer(c_int), value, intent(in) :: mpi_comm
end function fstarpu_mpi_wait_for_all

! int starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx);
! Wait until the tasks of scheduling context sched_ctx and the
! communications for communicator mpi_comm are completed.
! Both arguments are passed by value as C ints.
function fstarpu_mpi_wait_for_all_in_ctx (mpi_comm, sched_ctx) result(ret) bind(C)
        use iso_c_binding, only: c_int
        implicit none
        integer(c_int), value, intent(in) :: mpi_comm, sched_ctx
        integer(c_int) :: ret
end function fstarpu_mpi_wait_for_all_in_ctx

! int starpu_mpi_datatype_register(starpu_data_handle_t handle, starpu_mpi_datatype_allocate_func_t allocate_datatype_func, starpu_mpi_datatype_free_func_t free_datatype_func);
function fstarpu_mpi_datatype_register(dh, alloc_func, free_func) bind(C,name="starpu_mpi_datatype_register")
use iso_c_binding
Expand Down
6 changes: 6 additions & 0 deletions mpi/include/starpu_mpi.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,12 @@ int starpu_mpi_barrier(MPI_Comm comm);
*/
int starpu_mpi_wait_for_all(MPI_Comm comm);

/**
   Wait until all StarPU tasks in the scheduling context \p sched_ctx
   and all communications for the communicator \p comm are completed.
   This is the context-restricted counterpart of
   starpu_mpi_wait_for_all().
*/
int starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx);

/**
Post a standard-mode, non blocking send of \p data_handle to the
node \p dest using the message tag \p data_tag within the
Expand Down
18 changes: 16 additions & 2 deletions mpi/src/mpi/starpu_mpi_mpi.c
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ int _starpu_mpi_barrier(MPI_Comm comm)
return 0;
}

int _starpu_mpi_wait_for_all(MPI_Comm comm)
int _starpu_mpi_wait_for_all__(MPI_Comm comm, unsigned sched_ctx)
{
(void) comm;
_STARPU_MPI_LOG_IN();
Expand All @@ -839,7 +839,10 @@ int _starpu_mpi_wait_for_all(MPI_Comm comm)
newer_requests = 0;
STARPU_PTHREAD_MUTEX_UNLOCK(&progress_mutex);
/* Now wait for all tasks */
starpu_task_wait_for_all();
if (sched_ctx == STARPU_NMAX_SCHED_CTXS+1)
starpu_task_wait_for_all();
else
starpu_task_wait_for_all_in_ctx(sched_ctx);
STARPU_PTHREAD_MUTEX_LOCK(&progress_mutex);
/* Check newer_requests again, in case some MPI requests
* triggered by tasks completed and triggered tasks between
Expand All @@ -850,6 +853,17 @@ int _starpu_mpi_wait_for_all(MPI_Comm comm)
return 0;
}

/* Wait for tasks of every scheduling context: STARPU_NMAX_SCHED_CTXS+1 is
 * the value for which the shared implementation calls
 * starpu_task_wait_for_all() instead of the per-context wait. */
int _starpu_mpi_wait_for_all(MPI_Comm comm)
{
	const unsigned all_ctxs = STARPU_NMAX_SCHED_CTXS+1;

	return _starpu_mpi_wait_for_all__(comm, all_ctxs);
}

/* Context-restricted wait: forwards sched_ctx unchanged to the shared
 * implementation and returns its result. */
int _starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx)
{
	int ret = _starpu_mpi_wait_for_all__(comm, sched_ctx);

	return ret;
}


/********************************************************/
/* */
/* Progression */
Expand Down
1 change: 1 addition & 0 deletions mpi/src/mpi/starpu_mpi_mpi.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ void _starpu_mpi_wait_for_initialization();

int _starpu_mpi_barrier(MPI_Comm comm);
int _starpu_mpi_wait_for_all(MPI_Comm comm);
int _starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx);
int _starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status);
int _starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status);

Expand Down
1 change: 1 addition & 0 deletions mpi/src/mpi/starpu_mpi_mpi_backend.c
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ struct _starpu_mpi_backend _mpi_backend =

._starpu_mpi_backend_barrier = _starpu_mpi_barrier,
._starpu_mpi_backend_wait_for_all = _starpu_mpi_wait_for_all,
._starpu_mpi_backend_wait_for_all_in_ctx = _starpu_mpi_wait_for_all_in_ctx,
._starpu_mpi_backend_wait = _starpu_mpi_wait,
._starpu_mpi_backend_test = _starpu_mpi_test,

Expand Down
17 changes: 15 additions & 2 deletions mpi/src/nmad/starpu_mpi_nmad.c
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ int _starpu_mpi_barrier(MPI_Comm comm)
return ret;
}

int _starpu_mpi_wait_for_all(MPI_Comm comm)
/* Shared implementation of _starpu_mpi_wait_for_all() and
 * _starpu_mpi_wait_for_all_in_ctx(). Passing STARPU_NMAX_SCHED_CTXS+1 as
 * sched_ctx waits for the tasks of all contexts. */
int _starpu_mpi_wait_for_all__(MPI_Comm comm, unsigned sched_ctx)
{
(void) comm;
_STARPU_MPI_LOG_IN();
Expand All @@ -308,7 +308,10 @@ int _starpu_mpi_wait_for_all(MPI_Comm comm)
STARPU_PTHREAD_COND_WAIT(&mpi_wait_for_all_running_cond, &mpi_wait_for_all_running_mutex);
STARPU_PTHREAD_MUTEX_UNLOCK(&mpi_wait_for_all_running_mutex);

starpu_task_wait_for_all();
if (sched_ctx == STARPU_NMAX_SCHED_CTXS+1)
starpu_task_wait_for_all();
else
starpu_task_wait_for_all_in_ctx(sched_ctx);

STARPU_PTHREAD_MUTEX_LOCK(&mpi_wait_for_all_running_mutex);
} while (nb_pending_requests);
Expand All @@ -319,6 +322,16 @@ int _starpu_mpi_wait_for_all(MPI_Comm comm)
return 0;
}

/* Wait regardless of scheduling context: the shared implementation treats
 * STARPU_NMAX_SCHED_CTXS+1 as "no specific context" and then calls
 * starpu_task_wait_for_all(). */
int _starpu_mpi_wait_for_all(MPI_Comm comm)
{
	const unsigned any_ctx = STARPU_NMAX_SCHED_CTXS+1;

	return _starpu_mpi_wait_for_all__(comm, any_ctx);
}

/* Wait restricted to one scheduling context: hands sched_ctx as-is to the
 * shared implementation and propagates its return value. */
int _starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx)
{
	int ret = _starpu_mpi_wait_for_all__(comm, sched_ctx);

	return ret;
}

/********************************************************/
/* */
/* Progression */
Expand Down
1 change: 1 addition & 0 deletions mpi/src/nmad/starpu_mpi_nmad.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ void _starpu_mpi_progress_shutdown(void **value);

int _starpu_mpi_barrier(MPI_Comm comm);
int _starpu_mpi_wait_for_all(MPI_Comm comm);
int _starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx);
int _starpu_mpi_wait(starpu_mpi_req *public_req, MPI_Status *status);
int _starpu_mpi_test(starpu_mpi_req *public_req, int *flag, MPI_Status *status);

Expand Down
1 change: 1 addition & 0 deletions mpi/src/nmad/starpu_mpi_nmad_backend.c
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ struct _starpu_mpi_backend _mpi_backend =

._starpu_mpi_backend_barrier = _starpu_mpi_barrier,
._starpu_mpi_backend_wait_for_all = _starpu_mpi_wait_for_all,
._starpu_mpi_backend_wait_for_all_in_ctx = _starpu_mpi_wait_for_all_in_ctx,
._starpu_mpi_backend_wait = _starpu_mpi_wait,
._starpu_mpi_backend_test = _starpu_mpi_test,

Expand Down
8 changes: 8 additions & 0 deletions mpi/src/starpu_mpi.c
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,14 @@ int starpu_mpi_wait_for_all(MPI_Comm comm)
return _mpi_backend._starpu_mpi_backend_wait_for_all(comm);
}

int starpu_mpi_wait_for_all_in_ctx(MPI_Comm comm, unsigned sched_ctx)
{
	int ret;

	/* Users (typically benchmarks) may forget to call mpi_redux_data or to
	 * insert R tasks on the reduced handles: wrap up the reduction
	 * patterns on their behalf before waiting. */
	_starpu_mpi_redux_wrapup_data_all();

	/* The actual wait is delegated to the selected backend. */
	ret = _mpi_backend._starpu_mpi_backend_wait_for_all_in_ctx(comm, sched_ctx);
	return ret;
}

void starpu_mpi_comm_stats_disable()
{
_starpu_mpi_comm_stats_disable();
Expand Down
5 changes: 5 additions & 0 deletions mpi/src/starpu_mpi_fortran.c
Original file line number Diff line number Diff line change
Expand Up @@ -323,4 +323,9 @@ int fstarpu_mpi_wait_for_all(MPI_Fint comm)
{
return starpu_mpi_wait_for_all(MPI_Comm_f2c(comm));
}

/* Fortran entry point for starpu_mpi_wait_for_all_in_ctx(): converts the
 * Fortran communicator handle to an MPI_Comm and the context id to
 * unsigned before forwarding. */
int fstarpu_mpi_wait_for_all_in_ctx(MPI_Fint comm, int sched_ctx)
{
	MPI_Comm c_comm = MPI_Comm_f2c(comm);
	unsigned c_ctx = (unsigned)sched_ctx;

	return starpu_mpi_wait_for_all_in_ctx(c_comm, c_ctx);
}
#endif
3 changes: 2 additions & 1 deletion mpi/src/starpu_mpi_private.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* StarPU --- Runtime system for heterogeneous multicore architectures.
*
* Copyright (C) 2010-2023 University of Bordeaux, CNRS (LaBRI UMR 5800), Inria
* Copyright (C) 2010-2024 University of Bordeaux, CNRS (LaBRI UMR 5800), Inria
*
* StarPU is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
Expand Down Expand Up @@ -389,6 +389,7 @@ struct _starpu_mpi_backend

int (*_starpu_mpi_backend_barrier)(MPI_Comm comm);
int (*_starpu_mpi_backend_wait_for_all)(MPI_Comm comm);
int (*_starpu_mpi_backend_wait_for_all_in_ctx)(MPI_Comm comm, unsigned sched_ctx);
int (*_starpu_mpi_backend_wait)(starpu_mpi_req *public_req, MPI_Status *status);
int (*_starpu_mpi_backend_test)(starpu_mpi_req *public_req, int *flag, MPI_Status *status);

Expand Down

0 comments on commit c5330b8

Please sign in to comment.