123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- // Copyright Oliver Kowalke 2017.
- // Distributed under the Boost Software License, Version 1.0.
- // (See accompanying file LICENSE_1_0.txt or copy at
- // http://www.boost.org/LICENSE_1_0.txt)
- #ifndef BOOST_FIBERS_CUDA_WAITFOR_H
- #define BOOST_FIBERS_CUDA_WAITFOR_H
- #include <initializer_list>
- #include <mutex>
- #include <iostream>
- #include <set>
- #include <tuple>
- #include <vector>
- #include <boost/assert.hpp>
- #include <boost/config.hpp>
- #include <hip/hip_runtime.h>
- #include <boost/fiber/detail/config.hpp>
- #include <boost/fiber/detail/is_all_same.hpp>
- #include <boost/fiber/condition_variable.hpp>
- #include <boost/fiber/mutex.hpp>
- #ifdef BOOST_HAS_ABI_HEADERS
- # include BOOST_ABI_PREFIX
- #endif
- namespace boost {
- namespace fibers {
- namespace cuda {
- namespace detail {
- template< typename Rendezvous >
- static void trampoline( hipStream_t st, hipError_t status, void * vp) {
- Rendezvous * data = static_cast< Rendezvous * >( vp);
- data->notify( st, status);
- }
- class single_stream_rendezvous {
- public:
- single_stream_rendezvous( hipStream_t st) {
- unsigned int flags = 0;
- hipError_t status = ::hipStreamAddCallback( st, trampoline< single_stream_rendezvous >, this, flags);
- if ( hipSuccess != status) {
- st_ = st;
- status_ = status;
- done_ = true;
- }
- }
- void notify( hipStream_t st, hipError_t status) noexcept {
- std::unique_lock< mutex > lk{ mtx_ };
- st_ = st;
- status_ = status;
- done_ = true;
- lk.unlock();
- cv_.notify_one();
- }
- std::tuple< hipStream_t, hipError_t > wait() {
- std::unique_lock< mutex > lk{ mtx_ };
- cv_.wait( lk, [this]{ return done_; });
- return std::make_tuple( st_, status_);
- }
- private:
- mutex mtx_{};
- condition_variable cv_{};
- hipStream_t st_{};
- hipError_t status_{ hipErrorUnknown };
- bool done_{ false };
- };
- class many_streams_rendezvous {
- public:
- many_streams_rendezvous( std::initializer_list< hipStream_t > l) :
- stx_{ l } {
- results_.reserve( stx_.size() );
- for ( hipStream_t st : stx_) {
- unsigned int flags = 0;
- hipError_t status = ::hipStreamAddCallback( st, trampoline< many_streams_rendezvous >, this, flags);
- if ( hipSuccess != status) {
- std::unique_lock< mutex > lk{ mtx_ };
- stx_.erase( st);
- results_.push_back( std::make_tuple( st, status) );
- }
- }
- }
- void notify( hipStream_t st, hipError_t status) noexcept {
- std::unique_lock< mutex > lk{ mtx_ };
- stx_.erase( st);
- results_.push_back( std::make_tuple( st, status) );
- if ( stx_.empty() ) {
- lk.unlock();
- cv_.notify_one();
- }
- }
- std::vector< std::tuple< hipStream_t, hipError_t > > wait() {
- std::unique_lock< mutex > lk{ mtx_ };
- cv_.wait( lk, [this]{ return stx_.empty(); });
- return results_;
- }
- private:
- mutex mtx_{};
- condition_variable cv_{};
- std::set< hipStream_t > stx_;
- std::vector< std::tuple< hipStream_t, hipError_t > > results_;
- };
- }
- void waitfor_all();
- inline
- std::tuple< hipStream_t, hipError_t > waitfor_all( hipStream_t st) {
- detail::single_stream_rendezvous rendezvous( st);
- return rendezvous.wait();
- }
- template< typename ... STP >
- std::vector< std::tuple< hipStream_t, hipError_t > > waitfor_all( hipStream_t st0, STP ... stx) {
- static_assert( boost::fibers::detail::is_all_same< hipStream_t, STP ...>::value, "all arguments must be of type `CUstream*`.");
- detail::many_streams_rendezvous rendezvous{ st0, stx ... };
- return rendezvous.wait();
- }
- }}}
- #ifdef BOOST_HAS_ABI_HEADERS
- # include BOOST_ABI_SUFFIX
- #endif
- #endif // BOOST_FIBERS_CUDA_WAITFOR_H
|