4#include <RcppArmadillo.h>
22static CountTable CreateLUT(
const vec &y,
const vec &weights = {},
const int &maxUniqueEntries = 1000){
26 w = vec(y.n_elem, fill::ones);
30 unordered_map<long, double> lut;
31 lut.reserve(maxUniqueEntries);
35 for(
int i=0; i<y.n_elem; i++){
40 lut[(long) y[i]] += w[i];
45 if( lut.size() > maxUniqueEntries){
79 const bool &doCoxReid,
83 vec y_theta = y + theta;
89 vec value = lgamma(y_theta) - lgamma(theta) + theta*log(theta) - y_theta%log(mu + theta);
90 if( weights.is_empty() ){
93 res = dot(value, weights) / n;
97 double sum_lgamma_y_theta = 0;
99 sum_lgamma_y_theta += it.second * lgamma(it.first + theta);
102 if( weights.is_empty() ){
104 res = sum_lgamma_y_theta / n -
105 n*lgamma(theta) / n +
106 n*theta*log(theta)/n -
107 dot(y_theta, log(mu + theta)) / n;
109 double sum_w = sum(weights);
112 res = sum_lgamma_y_theta / n -
113 sum_w*lgamma(theta) / n +
114 sum_w*theta*log(theta)/n;
116 res -= dot(weights,y_theta%log(mu + theta)) / n;
131 if( weights.is_empty() ){
132 w = 1 / (1.0/mu + 1.0/theta);
134 w = weights / (1.0/mu + 1.0/theta);
137 bool success = log_det_sympd(ld, X.t() * (X.each_col()%w));
138 cr = -0.5 * ld * 0.99;
139 if( ! success ) cr = 0;
148static inline double nb_ll(
double theta_log,
void *arg){
151 auto *data = (
nbData *) arg;
154 return -1.0*nb_ll(data->y, data->mu, data->n, data->weights, data->X, data->doCoxReid, data->ct, exp(theta_log));
166static double nb_theta_ml(
170 const vec &weights = {},
172 const bool doCoxReid =
true,
174 const double &left = -5,
175 const double &right = 20,
176 const double &tol = 0.0001220703){
189 if( weights.is_empty() || all(weights == 1.0) ){
190 d =
new nbData(y, mu, n, {}, X, doCoxReid, ct);
192 d =
new nbData(y, mu, n, weights, X, doCoxReid, ct);
199 local_min(left, right, tol, &F, theta_log, iter);
203 return exp(theta_log);
double local_min(double a, double b, double t, funcStruct *f, double &x, int &calls)
Definition local_min.h:25
unordered_map< long, double > CountTable
Definition nb_theta.h:18
Definition local_min.h:18
double(* function)(double x, void *params)
Definition local_min.h:19
void * params
Definition local_min.h:20
vec y
Definition nb_theta.h:56
vec mu
Definition nb_theta.h:57
double n
Definition nb_theta.h:58
mat X
Definition nb_theta.h:60
vec weights
Definition nb_theta.h:59
bool doCoxReid
Definition nb_theta.h:61
nbData()
Definition nb_theta.h:65
CountTable ct
Definition nb_theta.h:62
nbData(const vec &y, const vec &mu, const double &n, const vec &weights, const mat &X, const bool &doCoxReid, const CountTable &ct={})
Definition nb_theta.h:68