set_union.hpp 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. //---------------------------------------------------------------------------//
  2. // Copyright (c) 2014 Roshan <[email protected]>
  3. //
  4. // Distributed under the Boost Software License, Version 1.0
  5. // See accompanying file LICENSE_1_0.txt or copy at
  6. // http://www.boost.org/LICENSE_1_0.txt
  7. //
  8. // See http://boostorg.github.com/compute for more information.
  9. //---------------------------------------------------------------------------//
  10. #ifndef BOOST_COMPUTE_ALGORITHM_SET_UNION_HPP
  11. #define BOOST_COMPUTE_ALGORITHM_SET_UNION_HPP
  12. #include <iterator>
  13. #include <boost/static_assert.hpp>
  14. #include <boost/compute/algorithm/detail/balanced_path.hpp>
  15. #include <boost/compute/algorithm/detail/compact.hpp>
  16. #include <boost/compute/algorithm/exclusive_scan.hpp>
  17. #include <boost/compute/algorithm/fill_n.hpp>
  18. #include <boost/compute/container/vector.hpp>
  19. #include <boost/compute/detail/iterator_range_size.hpp>
  20. #include <boost/compute/detail/meta_kernel.hpp>
  21. #include <boost/compute/system.hpp>
  22. #include <boost/compute/type_traits/is_device_iterator.hpp>
  23. namespace boost {
  24. namespace compute {
  25. namespace detail {
  26. ///
  27. /// \brief Serial set union kernel class
  28. ///
  29. /// Subclass of meta_kernel to perform serial set union after tiling
  30. ///
  31. class serial_set_union_kernel : meta_kernel
  32. {
  33. public:
  34. unsigned int tile_size;
  35. serial_set_union_kernel() : meta_kernel("set_union")
  36. {
  37. tile_size = 4;
  38. }
  39. template<class InputIterator1, class InputIterator2,
  40. class InputIterator3, class InputIterator4,
  41. class OutputIterator1, class OutputIterator2>
  42. void set_range(InputIterator1 first1,
  43. InputIterator2 first2,
  44. InputIterator3 tile_first1,
  45. InputIterator3 tile_last1,
  46. InputIterator4 tile_first2,
  47. OutputIterator1 result,
  48. OutputIterator2 counts)
  49. {
  50. m_count = iterator_range_size(tile_first1, tile_last1) - 1;
  51. *this <<
  52. "uint i = get_global_id(0);\n" <<
  53. "uint start1 = " << tile_first1[expr<uint_>("i")] << ";\n" <<
  54. "uint end1 = " << tile_first1[expr<uint_>("i+1")] << ";\n" <<
  55. "uint start2 = " << tile_first2[expr<uint_>("i")] << ";\n" <<
  56. "uint end2 = " << tile_first2[expr<uint_>("i+1")] << ";\n" <<
  57. "uint index = i*" << tile_size << ";\n" <<
  58. "uint count = 0;\n" <<
  59. "while(start1<end1 && start2<end2)\n" <<
  60. "{\n" <<
  61. " if(" << first1[expr<uint_>("start1")] << " == " <<
  62. first2[expr<uint_>("start2")] << ")\n" <<
  63. " {\n" <<
  64. result[expr<uint_>("index")] <<
  65. " = " << first1[expr<uint_>("start1")] << ";\n" <<
  66. " index++; count++;\n" <<
  67. " start1++; start2++;\n" <<
  68. " }\n" <<
  69. " else if(" << first1[expr<uint_>("start1")] << " < " <<
  70. first2[expr<uint_>("start2")] << ")\n" <<
  71. " {\n" <<
  72. result[expr<uint_>("index")] <<
  73. " = " << first1[expr<uint_>("start1")] << ";\n" <<
  74. " index++; count++;\n" <<
  75. " start1++;\n" <<
  76. " }\n" <<
  77. " else\n" <<
  78. " {\n" <<
  79. result[expr<uint_>("index")] <<
  80. " = " << first2[expr<uint_>("start2")] << ";\n" <<
  81. " index++; count++;\n" <<
  82. " start2++;\n" <<
  83. " }\n" <<
  84. "}\n" <<
  85. "while(start1<end1)\n" <<
  86. "{\n" <<
  87. result[expr<uint_>("index")] <<
  88. " = " << first1[expr<uint_>("start1")] << ";\n" <<
  89. " index++; count++;\n" <<
  90. " start1++;\n" <<
  91. "}\n" <<
  92. "while(start2<end2)\n" <<
  93. "{\n" <<
  94. result[expr<uint_>("index")] <<
  95. " = " << first2[expr<uint_>("start2")] << ";\n" <<
  96. " index++; count++;\n" <<
  97. " start2++;\n" <<
  98. "}\n" <<
  99. counts[expr<uint_>("i")] << " = count;\n";
  100. }
  101. event exec(command_queue &queue)
  102. {
  103. if(m_count == 0) {
  104. return event();
  105. }
  106. return exec_1d(queue, 0, m_count);
  107. }
  108. private:
  109. size_t m_count;
  110. };
  111. } //end detail namespace
  112. ///
  113. /// \brief Set union algorithm
  114. ///
  115. /// Finds the union of the sorted range [first1, last1) with the sorted
  116. /// range [first2, last2) and stores it in range starting at result
  117. /// \return Iterator pointing to end of union
  118. ///
  119. /// \param first1 Iterator pointing to start of first set
  120. /// \param last1 Iterator pointing to end of first set
  121. /// \param first2 Iterator pointing to start of second set
  122. /// \param last2 Iterator pointing to end of second set
  123. /// \param result Iterator pointing to start of range in which the union
  124. /// will be stored
  125. /// \param queue Queue on which to execute
  126. ///
  127. /// Space complexity:
  128. /// \Omega(2(distance(\p first1, \p last1) + distance(\p first2, \p last2)))
  129. template<class InputIterator1, class InputIterator2, class OutputIterator>
  130. inline OutputIterator set_union(InputIterator1 first1,
  131. InputIterator1 last1,
  132. InputIterator2 first2,
  133. InputIterator2 last2,
  134. OutputIterator result,
  135. command_queue &queue = system::default_queue())
  136. {
  137. BOOST_STATIC_ASSERT(is_device_iterator<InputIterator1>::value);
  138. BOOST_STATIC_ASSERT(is_device_iterator<InputIterator2>::value);
  139. BOOST_STATIC_ASSERT(is_device_iterator<OutputIterator>::value);
  140. typedef typename std::iterator_traits<InputIterator1>::value_type value_type;
  141. int tile_size = 1024;
  142. int count1 = detail::iterator_range_size(first1, last1);
  143. int count2 = detail::iterator_range_size(first2, last2);
  144. vector<uint_> tile_a((count1+count2+tile_size-1)/tile_size+1, queue.get_context());
  145. vector<uint_> tile_b((count1+count2+tile_size-1)/tile_size+1, queue.get_context());
  146. // Tile the sets
  147. detail::balanced_path_kernel tiling_kernel;
  148. tiling_kernel.tile_size = tile_size;
  149. tiling_kernel.set_range(first1, last1, first2, last2,
  150. tile_a.begin()+1, tile_b.begin()+1);
  151. fill_n(tile_a.begin(), 1, 0, queue);
  152. fill_n(tile_b.begin(), 1, 0, queue);
  153. tiling_kernel.exec(queue);
  154. fill_n(tile_a.end()-1, 1, count1, queue);
  155. fill_n(tile_b.end()-1, 1, count2, queue);
  156. vector<value_type> temp_result(count1+count2, queue.get_context());
  157. vector<uint_> counts((count1+count2+tile_size-1)/tile_size + 1, queue.get_context());
  158. fill_n(counts.end()-1, 1, 0, queue);
  159. // Find individual unions
  160. detail::serial_set_union_kernel union_kernel;
  161. union_kernel.tile_size = tile_size;
  162. union_kernel.set_range(first1, first2, tile_a.begin(), tile_a.end(),
  163. tile_b.begin(), temp_result.begin(), counts.begin());
  164. union_kernel.exec(queue);
  165. exclusive_scan(counts.begin(), counts.end(), counts.begin(), queue);
  166. // Compact the results
  167. detail::compact_kernel compact_kernel;
  168. compact_kernel.tile_size = tile_size;
  169. compact_kernel.set_range(temp_result.begin(), counts.begin(), counts.end(), result);
  170. compact_kernel.exec(queue);
  171. return result + (counts.end() - 1).read(queue);
  172. }
  173. } //end compute namespace
  174. } //end boost namespace
  175. #endif // BOOST_COMPUTE_ALGORITHM_SET_UNION_HPP