Automatic Differentiation
 
Loading...
Searching...
No Matches
binary_log_loss.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_BINARY_LOG_LOSS_HPP
2#define STAN_MATH_REV_FUN_BINARY_LOG_LOSS_HPP
3
7#include <cmath>
8
9namespace stan {
10namespace math {
11
46inline var binary_log_loss(int y, const var& y_hat) {
47 if (y == 0) {
48 return make_callback_var(-log1p(-y_hat.val()), [y_hat](auto& vi) mutable {
49 y_hat.adj() += vi.adj() / (1.0 - y_hat.val());
50 });
51 } else {
52 return make_callback_var(-std::log(y_hat.val()), [y_hat](auto& vi) mutable {
53 y_hat.adj() -= vi.adj() / y_hat.val();
54 });
55 }
56}
57
61template <typename Mat, require_eigen_t<Mat>* = nullptr>
62inline auto binary_log_loss(int y, const var_value<Mat>& y_hat) {
63 if (y == 0) {
64 return make_callback_var(
65 -(-y_hat.val().array()).log1p(), [y_hat](auto& vi) mutable {
66 y_hat.adj().array() += vi.adj().array() / (1.0 - y_hat.val().array());
67 });
68 } else {
69 return make_callback_var(
70 -y_hat.val().array().log(), [y_hat](auto& vi) mutable {
71 y_hat.adj().array() -= vi.adj().array() / y_hat.val().array();
72 });
73 }
74}
75
80template <typename StdVec, typename Mat, require_eigen_t<Mat>* = nullptr,
81 require_st_integral<StdVec>* = nullptr>
82inline auto binary_log_loss(const StdVec& y, const var_value<Mat>& y_hat) {
83 auto arena_y = to_arena(as_array_or_scalar(y).template cast<bool>());
84 auto ret_val
85 = -(arena_y == 0)
86 .select((-y_hat.val().array()).log1p(), y_hat.val().array().log());
87 return make_callback_var(ret_val, [y_hat, arena_y](auto& vi) mutable {
88 y_hat.adj().array()
89 += vi.adj().array()
90 / (arena_y == 0)
91 .select((1.0 - y_hat.val().array()), -y_hat.val().array());
92 });
93}
94
95} // namespace math
96} // namespace stan
97#endif
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
Definition select.hpp:148
T as_array_or_scalar(T &&v)
Returns specified input value.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback functor.
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules its destructor to be called when AD stack memory is recovered.
Definition to_arena.hpp:25
fvar< T > log1p(const fvar< T > &x)
Definition log1p.hpp:12
fvar< T > binary_log_loss(int y, const fvar< T > &y_hat)
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...