Automatic Differentiation
 
Loading...
Searching...
No Matches
binary_log_loss.hpp
Go to the documentation of this file.
1#ifndef STAN_MATH_REV_FUN_BINARY_LOG_LOSS_HPP
2#define STAN_MATH_REV_FUN_BINARY_LOG_LOSS_HPP
3
#include <cmath>
#include <stdexcept>
#include <string>
9
10namespace stan {
11namespace math {
12
47inline var binary_log_loss(int y, const var& y_hat) {
48 if (y == 0) {
49 return make_callback_var(-log1p(-y_hat.val()), [y_hat](auto& vi) mutable {
50 y_hat.adj() += vi.adj() / (1.0 - y_hat.val());
51 });
52 } else {
53 return make_callback_var(-std::log(y_hat.val()), [y_hat](auto& vi) mutable {
54 y_hat.adj() -= vi.adj() / y_hat.val();
55 });
56 }
57}
58
62template <typename Mat, require_eigen_t<Mat>* = nullptr>
63inline auto binary_log_loss(int y, const var_value<Mat>& y_hat) {
64 if (y == 0) {
65 return make_callback_var(
66 -(-y_hat.val().array()).log1p(), [y_hat](auto& vi) mutable {
67 y_hat.adj().array() += vi.adj().array() / (1.0 - y_hat.val().array());
68 });
69 } else {
70 return make_callback_var(
71 -y_hat.val().array().log(), [y_hat](auto& vi) mutable {
72 y_hat.adj().array() -= vi.adj().array() / y_hat.val().array();
73 });
74 }
75}
76
81template <typename StdVec, typename Mat, require_eigen_t<Mat>* = nullptr,
82 require_st_integral<StdVec>* = nullptr>
83inline auto binary_log_loss(const StdVec& y, const var_value<Mat>& y_hat) {
84 auto arena_y = to_arena(as_array_or_scalar(y).template cast<bool>());
85 auto ret_val
86 = -(arena_y == 0)
87 .select((-y_hat.val().array()).log1p(), y_hat.val().array().log());
88 return make_callback_var(ret_val, [y_hat, arena_y](auto& vi) mutable {
89 y_hat.adj().array()
90 += vi.adj().array()
91 / (arena_y == 0)
92 .select((1.0 - y_hat.val().array()), -y_hat.val().array());
93 });
94}
95
96} // namespace math
97} // namespace stan
98#endif
select_< as_operation_cl_t< T_condition >, as_operation_cl_t< T_then >, as_operation_cl_t< T_else > > select(T_condition &&condition, T_then &&then, T_else &&els)
Selection operation on kernel generator expressions.
Definition select.hpp:148
T as_array_or_scalar(T &&v)
Returns specified input value.
var_value< plain_type_t< T > > make_callback_var(T &&value, F &&functor)
Creates a new var initialized with a callback_vari with a given value and reverse-pass callback funct...
arena_t< T > to_arena(const T &a)
Converts given argument into a type that either has any dynamic allocation on AD stack or schedules i...
Definition to_arena.hpp:25
fvar< T > log1p(const fvar< T > &x)
Definition log1p.hpp:12
fvar< T > binary_log_loss(int y, const fvar< T > &y_hat)
The lgamma implementation in stan-math is based on either the reentrant safe lgamma_r implementation ...