diff --git a/src/ThreadPools.jl b/src/ThreadPools.jl index 7ea8f38..eac9cd6 100644 --- a/src/ThreadPools.jl +++ b/src/ThreadPools.jl @@ -8,6 +8,7 @@ export tmap, bmap, qmap, qbmap export logtmap, logbmap, logqmap, logqbmap export tforeach, bforeach, qforeach, qbforeach export logtforeach, logbforeach, logqforeach, logqbforeach +export spawn_background, checked_fetch include("interface.jl") include("macros.jl") @@ -17,6 +18,7 @@ include("qpool.jl") include("logstaticpool.jl") include("logqpool.jl") include("simplefuncs.jl") +include("spawn_background.jl") export @pthreads, pwith diff --git a/src/spawn_background.jl b/src/spawn_background.jl new file mode 100644 index 0000000..3e9fec7 --- /dev/null +++ b/src/spawn_background.jl @@ -0,0 +1,61 @@ +const AVAILABLE_THREADS = Base.RefValue{Channel{Int}}() + +# Somehow, fetch doesn't do a very good job at preserving +# stacktraces. So, we catch any error in spawn_background +# And return it as a CapturedException, and then use checked_fetch to +# rethrow any exception in that case +function checked_fetch(future) + value = fetch(future) + value isa Exception && throw(value) + return value +end + +""" + spawnbg(f) + +Spawn work on any available background thread. +Captures any exception thrown in the thread, to give better stacktraces. + +You can use `checked_fetch(spawnbg(f))` to rethrow any exception. + + ** Warning ** this doesn't compose with other ways of scheduling threads + So, one should use `spawn_background` exclusively in each Julia process. +""" +function spawnbg(f) + # -1, because we don't spawn on foreground thread 1 + nbackground = Threads.nthreads() - 1 + if nbackground == 0 + # we don't run in threaded mode, so we just run things async + # to not block forever + @warn("No threads available, running in foreground thread") + return @async try + return f() + catch e + # If we don't do this, we get pretty bad stack traces... not sure why!? + return CapturedException(e, Base.catch_backtrace()) + end + end + # Initialize dynamically, could also do this in __init__ but it's nice to keep things in one place + if !isassigned(AVAILABLE_THREADS) + # Allocate a Channel with n background threads + c = Channel{Int}(nbackground) + AVAILABLE_THREADS[] = c + # fill queue with available threads + foreach(i -> put!(c, i + 1), 1:nbackground) + end + # take the next free thread... Will block/wait until a thread becomes free + thread_id = take!(AVAILABLE_THREADS[]) + + return ThreadPools.@tspawnat thread_id begin + try + return f() + catch e + # If we don't do this, we get pretty bad stack traces... + # not sure why something so basic just doesn't work well \_(ツ)_/¯ + return CapturedException(e, Base.catch_backtrace()) + finally + # Make thread available again after work is done! + put!(AVAILABLE_THREADS[], thread_id) + end + end +end diff --git a/test/Project.toml b/test/Project.toml index 617f165..6924127 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ThreadPools = "b189fb0b-2eb5-4ed4-bc0c-d34c51242431" diff --git a/test/runtests_exec.jl b/test/runtests_exec.jl index 69ff10c..575837d 100644 --- a/test/runtests_exec.jl +++ b/test/runtests_exec.jl @@ -1,4 +1,4 @@ - +using Test, ThreadPools include("teststatic.jl") include("testlogstatic.jl") include("testq.jl") @@ -14,3 +14,4 @@ include("testmultiarg.jl") include("testmisc.jl") include("testplots.jl") include("errorhandling.jl") +include("spawn_background.jl") diff --git a/test/spawn_background.jl b/test/spawn_background.jl new file mode 100644 index 0000000..b63a16a --- /dev/null +++ b/test/spawn_background.jl @@ -0,0 +1,97 @@ +module TestSpawnBackground + +using Test +using ThreadPools +using Statistics + +# Put in the function to test this in compiled form +# to make sure there is no yield etc introduced from running interpreted +function uses_all_threads() + bg_nthreads = Threads.nthreads() - 1 + bg_threads = zeros(bg_nthreads) + futures = map(1:bg_nthreads) do i + return spawnbg() do + id = Threads.threadid() + bg_threads[id - 1] = id + return id + end + end + foreach(wait, futures) + return sum(bg_threads) == sum(2:(bg_nthreads + 1)) +end + +function busy_wait(time_s) + t = time() + while time() - t < time_s + end + return +end + +function count_occurence(list) + occurences = Dict{Int,Int}() + for elem in list + i = get!(occurences, elem, 0) + occurences[elem] = i + 1 + end + return occurences +end + +function spam_threads(f, spam_factor) + bg_nthreads = Threads.nthreads() - 1 + n_executions = bg_nthreads * spam_factor + thread_ids = [] + time_spent = @elapsed begin + futures = map(1:n_executions) do i + return spawnbg() do + f() + return Threads.threadid() + end + end + thread_ids = map(fetch, futures) + end + return time_spent, thread_ids +end + +function spam_threads_busy(time_waiting, spam_factor) + return spam_threads(spam_factor) do + return busy_wait(time_waiting) + end +end + +@testset "threading" begin + nthreads = Threads.nthreads() + bg_nthreads = nthreads - 1 + if bg_nthreads == 0 + @test fetch(spawnbg(()-> Threads.threadid())) == 1 + else + @testset "scheduling" begin + # When we quickly schedule nthreads work items, the implementation should use all threads + @test uses_all_threads() + + spam_factor = 5 + time_spent, thread_ids = spam_threads(() -> nothing, spam_factor) + occurences = count_occurence(thread_ids) + # We should spread out work to all threads when spamming lots of tasks + @test all(x -> x in keys(occurences), 2:bg_nthreads) + # a few threads may get more work items, but the mean should be equal to the spamfactor + @test spam_factor == mean(values(occurences)) + + time_spent, thread_ids = spam_threads_busy(0.5, spam_factor) + occurences = count_occurence(thread_ids) + @test all(x -> x in keys(occurences), 2:bg_nthreads) + @test spam_factor == mean(values(occurences)) + # I'm not sure how stable this will be on the CI, we may need to tweak the atol + @test time_spent ≈ 0.5 * spam_factor atol = 0.1 + end + @testset "Queue contains all threads, after work is done" begin + @test length(unique(ThreadPools.AVAILABLE_THREADS[].data)) == bg_nthreads + end + end + + @testset "error handling" begin + @test_throws CapturedException checked_fetch(spawnbg(() -> error("hey"))) + end + +end + +end