diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 79d30faf..127881c2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,7 +10,7 @@ jobs: fail-fast: false matrix: version: - - '1.5' + - '1.6' os: - ubuntu-latest arch: diff --git a/src/MPIBackend.jl b/src/MPIBackend.jl index 9d24c369..439f2551 100644 --- a/src/MPIBackend.jl +++ b/src/MPIBackend.jl @@ -22,12 +22,32 @@ function prun(driver::Function,b::MPIBackend,nparts) if !MPI.Initialized() MPI.Init() end - #try - part = get_part_ids(b,nparts) - driver(part) - #finally - # MPI.Finalize() - #end + try + part = get_part_ids(b,nparts) + driver(part) + catch e + @error "" exception=(e, catch_backtrace()) + if MPI.Initialized() && !MPI.Finalized() + MPI.Abort(MPI.COMM_WORLD,0) + end + end + # We are NOT invoking MPI.Finalize() here because we rely on + # MPI.jl, which registers MPI.Finalize() in atexit() +end + +# Useful to debug an MPI program when executed interactively +# on the REPL, i.e., with a single MPI task +function prun_debug(driver::Function,b::MPIBackend,nparts) + if !MPI.Initialized() + MPI.Init() + end + if (length(nparts) != 1) + MPI.Abort(MPI.COMM_WORLD,0) + end + part = get_part_ids(b,nparts) + driver(part) + # We are NOT invoking MPI.Finalize() here because we rely on + # MPI.jl, which registers MPI.Finalize() in atexit() end struct MPIData{T,N} <: AbstractPData{T,N} @@ -156,7 +176,7 @@ function scatter(snd::MPIData) MPI.Scatter!(MPI.UBuffer(snd.part,1),MPI.IN_PLACE,MAIN-1,snd.comm) else rcv = Vector{eltype(snd.part)}(undef,1) - MPI.Scatter!(nothing,rcv,MAIN-1,snd.comm) + MPI.Scatter!(nothing,rcv,MAIN-1,snd.comm) part = rcv[1] end MPIData(part,snd.comm,snd.size) @@ -171,7 +191,7 @@ function scatter(snd::MPIData{<:Table}) rcv = snd.part[MAIN] else rcv = eltype(snd.part)(undef,counts_scat.part) - MPI.Scatterv!(nothing,rcv,MAIN-1,snd.comm) + MPI.Scatterv!(nothing,rcv,MAIN-1,snd.comm) end MPIData(rcv,snd.comm,snd.size) end @@ -282,4 +302,3 @@ function async_exchange!( t_out = MPIData(t2,comm,s) t_out end - diff --git a/src/PartitionedArrays.jl b/src/PartitionedArrays.jl index fb920c8a..ec598dde 100644 --- a/src/PartitionedArrays.jl +++ b/src/PartitionedArrays.jl @@ -9,7 +9,7 @@ import IterativeSolvers import Distances export AbstractBackend -export prun +export prun, prun_debug export AbstractPData export SequentialData export MPIData diff --git a/src/SequentialBackend.jl b/src/SequentialBackend.jl index 795358f6..b0b320be 100644 --- a/src/SequentialBackend.jl +++ b/src/SequentialBackend.jl @@ -13,6 +13,10 @@ function get_part_ids(b::SequentialBackend,nparts::Tuple) SequentialData(parts) end +function prun_debug(driver::Function,b::SequentialBackend,nparts) + prun(driver,b,nparts) +end + struct SequentialData{T,N} <: AbstractPData{T,N} parts::Array{T,N} end diff --git a/test/mpi/ExceptionTests.jl b/test/mpi/ExceptionTests.jl new file mode 100644 index 00000000..a7e4080a --- /dev/null +++ b/test/mpi/ExceptionTests.jl @@ -0,0 +1,6 @@ +module ExceptionTests + +include("mpiexec.jl") +run_mpi_driver(procs=8,file="driver_exception.jl") + +end # module diff --git a/test/mpi/driver_exception.jl b/test/mpi/driver_exception.jl new file mode 100644 index 00000000..1226130f --- /dev/null +++ b/test/mpi/driver_exception.jl @@ -0,0 +1,3 @@ +include("../test_exception.jl") +nparts = (2,2,2) +prun(test_exception,mpi,nparts) diff --git a/test/mpi/runtests.jl b/test/mpi/runtests.jl index cddee559..27b9132e 100644 --- a/test/mpi/runtests.jl +++ b/test/mpi/runtests.jl @@ -14,4 +14,7 @@ using Test @testset "PTimers" begin include("PTimersTests.jl") end +@testset "ExceptionTests" begin include("ExceptionTests.jl") end + + end diff --git a/test/sequential/FDMTests.jl b/test/sequential/FDMTests.jl index 2485ecc5..54439fce 100644 --- a/test/sequential/FDMTests.jl +++ b/test/sequential/FDMTests.jl @@ -3,9 +3,9 @@ module FDMTests include("../test_fdm.jl") nparts = (2,2,2) -prun(test_fdm,sequential,nparts) +prun_debug(test_fdm,sequential,nparts) nparts = 4 -prun(test_fdm,sequential,nparts) +prun_debug(test_fdm,sequential,nparts) end # module diff --git a/test/test_exception.jl b/test/test_exception.jl new file mode 100644 index 00000000..f2356079 --- /dev/null +++ b/test/test_exception.jl @@ -0,0 +1,13 @@ + +using PartitionedArrays + +function throw_assert(parts) + nparts=length(parts) + map_parts(parts) do part + @assert rand(1:nparts) != part + end +end + +function test_exception(parts) + throw_assert(parts) +end