logsumexp.hpp 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. // (C) Copyright Matt Borland 2022.
  2. // Use, modification and distribution are subject to the
  3. // Boost Software License, Version 1.0. (See accompanying file
  4. // LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
#include <algorithm>
#include <cmath>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <type_traits>
#include <utility>
#include <boost/math/special_functions/logaddexp.hpp>
  12. namespace boost { namespace math {
  13. // https://nhigham.com/2021/01/05/what-is-the-log-sum-exp-function/
  14. // See equation (#)
  15. template <typename ForwardIterator, typename Real = typename std::iterator_traits<ForwardIterator>::value_type>
  16. Real logsumexp(ForwardIterator first, ForwardIterator last)
  17. {
  18. using std::exp;
  19. using std::log1p;
  20. const auto elem = std::max_element(first, last);
  21. const Real max_val = *elem;
  22. Real arg = 0;
  23. while (first != last)
  24. {
  25. if (first != elem)
  26. {
  27. arg += exp(*first - max_val);
  28. }
  29. ++first;
  30. }
  31. return max_val + log1p(arg);
  32. }
  33. template <typename Container, typename Real = typename Container::value_type>
  34. inline Real logsumexp(const Container& c)
  35. {
  36. return logsumexp(std::begin(c), std::end(c));
  37. }
  38. template <typename... Args, typename Real = typename std::common_type<Args...>::type,
  39. typename std::enable_if<std::is_floating_point<Real>::value, bool>::type = true>
  40. inline Real logsumexp(Args&& ...args)
  41. {
  42. std::initializer_list<Real> list {std::forward<Args>(args)...};
  43. if(list.size() == 2)
  44. {
  45. return logaddexp(*list.begin(), *std::next(list.begin()));
  46. }
  47. return logsumexp(list.begin(), list.end());
  48. }
  49. }} // Namespace boost::math