123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288 |
- //
- // Copyright (c) 2018-2019, Cem Bassoy, [email protected]
- //
- // Distributed under the Boost Software License, Version 1.0. (See
- // accompanying file LICENSE_1_0.txt or copy at
- // http://www.boost.org/LICENSE_1_0.txt)
- //
- // The authors gratefully acknowledge the support of
- // Fraunhofer IOSB, Ettlingen, Germany
- //
- #ifndef _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
- #define _BOOST_UBLAS_TENSOR_EXPRESSIONS_EVALUATION_HPP_
- #include <type_traits>
- #include <stdexcept>
// Forward declarations of the uBLAS class templates that the evaluation
// helpers in this header refer to. Only the names are introduced here; the
// definitions are not part of this header.
namespace boost::numeric::ublas {

/// Tensor class template (declaration only).
template<class value_t, class format_t, class storage_t>
class tensor;

/// Extents class template (declaration only).
template<class int_t>
class basic_extents;

} // namespace boost::numeric::ublas
namespace boost::numeric::ublas::detail {

// Expression node templates, declared here only so that the trait below can
// pattern-match on them.
template<class T, class D>
struct tensor_expression;

template<class T, class EL, class ER, class OP>
struct binary_tensor_expression;

template<class T, class E, class OP>
struct unary_tensor_expression;

/** @brief Compile-time check whether @c E is @c T or an expression containing @c T
 *
 * The primary template yields @c false. The specializations below yield
 * @c true when @c E is exactly @c T, or when @c E is a tensor_expression,
 * binary_tensor_expression or unary_tensor_expression parameterized on @c T
 * whose operand type(s) are @c T or (recursively) contain @c T.
 *
 * @tparam T type searched for
 * @tparam E type that is inspected
 */
template<class T, class E>
struct has_tensor_types
    : std::false_type {};

/// @c E is the searched type itself.
template<class T>
struct has_tensor_types<T,T>
    : std::true_type {};

/// tensor_expression: inspect the derived type @c D.
template<class T, class D>
struct has_tensor_types<T, tensor_expression<T,D>>
    : std::bool_constant<
          std::is_same_v<T,D> ||
          has_tensor_types<T,D>::value
      > {};

/// binary_tensor_expression: inspect the left (@c EL) and right (@c ER) operand types.
template<class T, class EL, class ER, class OP>
struct has_tensor_types<T, binary_tensor_expression<T,EL,ER,OP>>
    : std::bool_constant<
          std::is_same_v<T,EL> ||
          std::is_same_v<T,ER> ||
          has_tensor_types<T,EL>::value ||
          has_tensor_types<T,ER>::value
      > {};

/// unary_tensor_expression: inspect the operand type @c E.
template<class T, class E, class OP>
struct has_tensor_types<T, unary_tensor_expression<T,E,OP>>
    : std::bool_constant<
          std::is_same_v<T,E> ||
          has_tensor_types<T,E>::value
      > {};

} // namespace boost::numeric::ublas::detail
- namespace boost::numeric::ublas::detail {
- /** @brief Retrieves extents of the tensor
- *
- */
- template<class T, class F, class A>
- auto retrieve_extents(tensor<T,F,A> const& t)
- {
- return t.extents();
- }
- /** @brief Retrieves extents of the tensor expression
- *
- * @note tensor expression must be a binary tree with at least one tensor type
- *
- * @returns extents of the child expression if it is a tensor or extents of one child of its child.
- */
- template<class T, class D>
- auto retrieve_extents(tensor_expression<T,D> const& expr)
- {
- static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
- "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
- auto const& cast_expr = static_cast<D const&>(expr);
- if constexpr ( std::is_same<T,D>::value )
- return cast_expr.extents();
- else
- return retrieve_extents(cast_expr);
- }
- /** @brief Retrieves extents of the binary tensor expression
- *
- * @note tensor expression must be a binary tree with at least one tensor type
- *
- * @returns extents of the (left and if necessary then right) child expression if it is a tensor or extents of a child of its (left and if necessary then right) child.
- */
- template<class T, class EL, class ER, class OP>
- auto retrieve_extents(binary_tensor_expression<T,EL,ER,OP> const& expr)
- {
- static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
- "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
- if constexpr ( std::is_same<T,EL>::value )
- return expr.el.extents();
- if constexpr ( std::is_same<T,ER>::value )
- return expr.er.extents();
- else if constexpr ( detail::has_tensor_types<T,EL>::value )
- return retrieve_extents(expr.el);
- else if constexpr ( detail::has_tensor_types<T,ER>::value )
- return retrieve_extents(expr.er);
- }
- /** @brief Retrieves extents of the binary tensor expression
- *
- * @note tensor expression must be a binary tree with at least one tensor type
- *
- * @returns extents of the child expression if it is a tensor or extents of a child of its child.
- */
- template<class T, class E, class OP>
- auto retrieve_extents(unary_tensor_expression<T,E,OP> const& expr)
- {
- static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
- "Error in boost::numeric::ublas::detail::retrieve_extents: Expression to evaluate should contain tensors.");
- if constexpr ( std::is_same<T,E>::value )
- return expr.e.extents();
- else if constexpr ( detail::has_tensor_types<T,E>::value )
- return retrieve_extents(expr.e);
- }
- } // namespace boost::numeric::ublas::detail
- ///////////////
- namespace boost::numeric::ublas::detail {
- template<class T, class F, class A, class S>
- auto all_extents_equal(tensor<T,F,A> const& t, basic_extents<S> const& extents)
- {
- return extents == t.extents();
- }
- template<class T, class D, class S>
- auto all_extents_equal(tensor_expression<T,D> const& expr, basic_extents<S> const& extents)
- {
- static_assert(detail::has_tensor_types<T,tensor_expression<T,D>>::value,
- "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
- auto const& cast_expr = static_cast<D const&>(expr);
- if constexpr ( std::is_same<T,D>::value )
- if( extents != cast_expr.extents() )
- return false;
- if constexpr ( detail::has_tensor_types<T,D>::value )
- if ( !all_extents_equal(cast_expr, extents))
- return false;
- return true;
- }
- template<class T, class EL, class ER, class OP, class S>
- auto all_extents_equal(binary_tensor_expression<T,EL,ER,OP> const& expr, basic_extents<S> const& extents)
- {
- static_assert(detail::has_tensor_types<T,binary_tensor_expression<T,EL,ER,OP>>::value,
- "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
- if constexpr ( std::is_same<T,EL>::value )
- if(extents != expr.el.extents())
- return false;
- if constexpr ( std::is_same<T,ER>::value )
- if(extents != expr.er.extents())
- return false;
- if constexpr ( detail::has_tensor_types<T,EL>::value )
- if(!all_extents_equal(expr.el, extents))
- return false;
- if constexpr ( detail::has_tensor_types<T,ER>::value )
- if(!all_extents_equal(expr.er, extents))
- return false;
- return true;
- }
- template<class T, class E, class OP, class S>
- auto all_extents_equal(unary_tensor_expression<T,E,OP> const& expr, basic_extents<S> const& extents)
- {
- static_assert(detail::has_tensor_types<T,unary_tensor_expression<T,E,OP>>::value,
- "Error in boost::numeric::ublas::detail::all_extents_equal: Expression to evaluate should contain tensors.");
- if constexpr ( std::is_same<T,E>::value )
- if(extents != expr.e.extents())
- return false;
- if constexpr ( detail::has_tensor_types<T,E>::value )
- if(!all_extents_equal(expr.e, extents))
- return false;
- return true;
- }
- } // namespace boost::numeric::ublas::detail
- namespace boost::numeric::ublas::detail {
- /** @brief Evaluates expression for a tensor
- *
- * Assigns the results of the expression to the tensor.
- *
- * \note Checks if shape of the tensor matches those of all tensors within the expression.
- */
- template<class tensor_type, class derived_type>
- void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr)
- {
- if constexpr (detail::has_tensor_types<tensor_type, tensor_expression<tensor_type,derived_type> >::value )
- if(!detail::all_extents_equal(expr, lhs.extents() ))
- throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
- #pragma omp parallel for
- for(auto i = 0u; i < lhs.size(); ++i)
- lhs(i) = expr()(i);
- }
- /** @brief Evaluates expression for a tensor
- *
- * Applies a unary function to the results of the expressions before the assignment.
- * Usually applied needed for unary operators such as A += C;
- *
- * \note Checks if shape of the tensor matches those of all tensors within the expression.
- */
- template<class tensor_type, class derived_type, class unary_fn>
- void eval(tensor_type& lhs, tensor_expression<tensor_type, derived_type> const& expr, unary_fn const fn)
- {
- if constexpr (detail::has_tensor_types< tensor_type, tensor_expression<tensor_type,derived_type> >::value )
- if(!detail::all_extents_equal( expr, lhs.extents() ))
- throw std::runtime_error("Error in boost::numeric::ublas::tensor: expression contains tensors with different shapes.");
- #pragma omp parallel for
- for(auto i = 0u; i < lhs.size(); ++i)
- fn(lhs(i), expr()(i));
- }
/** @brief Applies a function to every element of a tensor
 *
 * For every index @c i in <tt>[0, lhs.size())</tt> calls <tt>fn(lhs(i))</tt>.
 * No expression is involved, so no extents check is performed.
 *
 * @note The loop carries an OpenMP <tt>parallel for</tt> pragma.
 *
 * @param lhs tensor whose elements are passed to @c fn
 * @param fn  callable taking an element of @c lhs
 */
template<class tensor_type, class unary_fn>
void eval(tensor_type& lhs, unary_fn const fn)
{
    // Index with the tensor's own size type: an `unsigned int` counter would
    // wrap around and never reach sizes above its maximum value.
    using index_type = std::decay_t<decltype(lhs.size())>;
    index_type const count = lhs.size();

#pragma omp parallel for
    for(index_type i = 0; i < count; ++i)
        fn(lhs(i));
}
- }
- #endif
|