scatter.hpp 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. // Copyright (C) 2005, 2006 Douglas Gregor.
  2. // Use, modification and distribution is subject to the Boost Software
  3. // License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
  4. // http://www.boost.org/LICENSE_1_0.txt)
  5. // Message Passing Interface 1.1 -- Section 4.6. Scatter
  6. #ifndef BOOST_MPI_SCATTER_HPP
  7. #define BOOST_MPI_SCATTER_HPP
  8. #include <boost/mpi/exception.hpp>
  9. #include <boost/mpi/datatype.hpp>
  10. #include <vector>
  11. #include <boost/mpi/packed_oarchive.hpp>
  12. #include <boost/mpi/packed_iarchive.hpp>
  13. #include <boost/mpi/detail/point_to_point.hpp>
  14. #include <boost/mpi/communicator.hpp>
  15. #include <boost/mpi/environment.hpp>
  16. #include <boost/mpi/detail/offsets.hpp>
  17. #include <boost/mpi/detail/antiques.hpp>
  18. #include <boost/assert.hpp>
  19. namespace boost { namespace mpi {
  20. namespace detail {
  21. // We're scattering from the root for a type that has an associated MPI
  22. // datatype, so we'll use MPI_Scatter to do all of the work.
  23. template<typename T>
  24. void
  25. scatter_impl(const communicator& comm, const T* in_values, T* out_values,
  26. int n, int root, mpl::true_)
  27. {
  28. MPI_Datatype type = get_mpi_datatype<T>(*in_values);
  29. BOOST_MPI_CHECK_RESULT(MPI_Scatter,
  30. (const_cast<T*>(in_values), n, type,
  31. out_values, n, type, root, comm));
  32. }
  33. // We're scattering from a non-root for a type that has an associated MPI
  34. // datatype, so we'll use MPI_Scatter to do all of the work.
  35. template<typename T>
  36. void
  37. scatter_impl(const communicator& comm, T* out_values, int n, int root,
  38. mpl::true_)
  39. {
  40. MPI_Datatype type = get_mpi_datatype<T>(*out_values);
  41. BOOST_MPI_CHECK_RESULT(MPI_Scatter,
  42. (0, n, type,
  43. out_values, n, type,
  44. root, comm));
  45. }
  46. // Fill the sendbuf while keeping trac of the slot's footprints
  47. // Used in the first steps of both scatter and scatterv
  48. // Nslots contains the number of slots being sent
  49. // to each process (identical values for scatter).
  50. // skiped_slots, if present, is deduced from the
  51. // displacement array authorised be the MPI API,
  52. // for some yet to be determined reason.
  53. template<typename T>
  54. void
  55. fill_scatter_sendbuf(const communicator& comm, T const* values,
  56. int const* nslots, int const* skipped_slots,
  57. packed_oarchive::buffer_type& sendbuf, std::vector<int>& archsizes) {
  58. int nproc = comm.size();
  59. archsizes.resize(nproc);
  60. for (int dest = 0; dest < nproc; ++dest) {
  61. if (skipped_slots) { // wee need to keep this for backward compatibility
  62. for(int k= 0; k < skipped_slots[dest]; ++k) ++values;
  63. }
  64. packed_oarchive procarchive(comm);
  65. for (int i = 0; i < nslots[dest]; ++i) {
  66. procarchive << *values++;
  67. }
  68. int archsize = procarchive.size();
  69. sendbuf.resize(sendbuf.size() + archsize);
  70. archsizes[dest] = archsize;
  71. char const* aptr = static_cast<char const*>(procarchive.address());
  72. std::copy(aptr, aptr+archsize, sendbuf.end()-archsize);
  73. }
  74. }
  75. template<typename T, class A>
  76. T*
  77. non_const_data(std::vector<T,A> const& v) {
  78. using detail::c_data;
  79. return const_cast<T*>(c_data(v));
  80. }
  81. // Dispatch the sendbuf among proc.
  82. // Used in the second steps of both scatter and scatterv
  83. // in_value is only provide in the non variadic case.
  84. template<typename T>
  85. void
  86. dispatch_scatter_sendbuf(const communicator& comm,
  87. packed_oarchive::buffer_type const& sendbuf, std::vector<int> const& archsizes,
  88. T const* in_values,
  89. T* out_values, int n, int root) {
  90. // Distribute the sizes
  91. int myarchsize;
  92. BOOST_MPI_CHECK_RESULT(MPI_Scatter,
  93. (non_const_data(archsizes), 1, MPI_INT,
  94. &myarchsize, 1, MPI_INT, root, comm));
  95. std::vector<int> offsets;
  96. if (root == comm.rank()) {
  97. sizes2offsets(archsizes, offsets);
  98. }
  99. // Get my proc archive
  100. packed_iarchive::buffer_type recvbuf;
  101. recvbuf.resize(myarchsize);
  102. BOOST_MPI_CHECK_RESULT(MPI_Scatterv,
  103. (non_const_data(sendbuf), non_const_data(archsizes), c_data(offsets), MPI_BYTE,
  104. c_data(recvbuf), recvbuf.size(), MPI_BYTE,
  105. root, MPI_Comm(comm)));
  106. // Unserialize
  107. if ( in_values != 0 && root == comm.rank()) {
  108. // Our own local values are already here: just copy them.
  109. std::copy(in_values + root * n, in_values + (root + 1) * n, out_values);
  110. } else {
  111. // Otherwise deserialize:
  112. packed_iarchive iarchv(comm, recvbuf);
  113. for (int i = 0; i < n; ++i) {
  114. iarchv >> out_values[i];
  115. }
  116. }
  117. }
  118. // We're scattering from the root for a type that does not have an
  119. // associated MPI datatype, so we'll need to serialize it.
  120. template<typename T>
  121. void
  122. scatter_impl(const communicator& comm, const T* in_values, T* out_values,
  123. int n, int root, mpl::false_)
  124. {
  125. packed_oarchive::buffer_type sendbuf;
  126. std::vector<int> archsizes;
  127. if (root == comm.rank()) {
  128. std::vector<int> nslots(comm.size(), n);
  129. fill_scatter_sendbuf(comm, in_values, c_data(nslots), (int const*)0, sendbuf, archsizes);
  130. }
  131. dispatch_scatter_sendbuf(comm, sendbuf, archsizes, in_values, out_values, n, root);
  132. }
  133. template<typename T>
  134. void
  135. scatter_impl(const communicator& comm, T* out_values, int n, int root,
  136. mpl::false_ is_mpi_type)
  137. {
  138. scatter_impl(comm, (T const*)0, out_values, n, root, is_mpi_type);
  139. }
  140. } // end namespace detail
  141. template<typename T>
  142. void
  143. scatter(const communicator& comm, const T* in_values, T& out_value, int root)
  144. {
  145. detail::scatter_impl(comm, in_values, &out_value, 1, root, is_mpi_datatype<T>());
  146. }
  147. template<typename T>
  148. void
  149. scatter(const communicator& comm, const std::vector<T>& in_values, T& out_value,
  150. int root)
  151. {
  152. using detail::c_data;
  153. ::boost::mpi::scatter<T>(comm, c_data(in_values), out_value, root);
  154. }
  155. template<typename T>
  156. void scatter(const communicator& comm, T& out_value, int root)
  157. {
  158. BOOST_ASSERT(comm.rank() != root);
  159. detail::scatter_impl(comm, &out_value, 1, root, is_mpi_datatype<T>());
  160. }
  161. template<typename T>
  162. void
  163. scatter(const communicator& comm, const T* in_values, T* out_values, int n,
  164. int root)
  165. {
  166. detail::scatter_impl(comm, in_values, out_values, n, root, is_mpi_datatype<T>());
  167. }
  168. template<typename T>
  169. void
  170. scatter(const communicator& comm, const std::vector<T>& in_values,
  171. T* out_values, int n, int root)
  172. {
  173. ::boost::mpi::scatter(comm, detail::c_data(in_values), out_values, n, root);
  174. }
  175. template<typename T>
  176. void scatter(const communicator& comm, T* out_values, int n, int root)
  177. {
  178. BOOST_ASSERT(comm.rank() != root);
  179. detail::scatter_impl(comm, out_values, n, root, is_mpi_datatype<T>());
  180. }
  181. } } // end namespace boost::mpi
  182. #endif // BOOST_MPI_SCATTER_HPP