10#define _FASTLMM_FIT_H_
15#include <RcppArmadillo.h>
34template <
typename T1,
typename T2,
typename T3>
46 const double &lambda = 0,
47 const bool REML =
false);
57 const double &lambda = 0,
58 const bool REML =
false);
70 const double &lambda = 0,
71 const bool REML =
false);
77 const double &lambda = 0,
78 const bool REML =
false);
90 const vec
get_beta()
const {
return this->beta; }
93 return this->delta_hat * this->sigSq_g;
95 const int get_iter()
const {
return this->iter;}
96 const double get_delta()
const {
return this->delta_hat;}
98 return this->sigSq_g * inv_sympd(this->QXX, inv_opts::allow_approx) ;
106 beta.fill(datum::nan);
119 const vec
blup()
const;
122 double ll(
const double &delta);
129 this->logLik =
ll( delta );
130 this->delta_hat = delta;
156 mat Gamma_XX, Gamma_XY;
165 double logLik, sigSq_g, delta_hat;
173template <
typename T1,
typename T2,
typename T3>
180 const double &lambda,
182 wsqrt(sqrt(weights_)),
192 U = this->dcmp.
get_U();
193 s = this->dcmp.
get_s();
195 n_active = accu(weights != 0.0);
198 Gamma_XX = X.t() * X - Xu.t() * Xu;
199 Gamma_XY = X.t() * Y - Xu.t() * Yu;
200 inv_s_delta_Xu = mat( Xu.n_rows, Xu.n_cols);
207template <
typename T1,
typename T2,
typename T3>
216 const double &lambda,
218 wsqrt(sqrt(weights_)),
228 U = this->dcmp.
get_U();
229 s = this->dcmp.
get_s();
231 n_active = accu(weights != 0.0);
234 Gamma_XX = X.t() * X - Xu.t() * Xu;
235 Gamma_XY = X.t() * Y - Xu.t() * Yu;
236 inv_s_delta_Xu = mat( Xu.n_rows, Xu.n_cols);
240template <
typename T1,
typename T2,
typename T3>
248 const mat &Gamma_XX_,
249 const mat &Gamma_XY_,
251 const double &lambda,
253 wsqrt(sqrt(weights_)),
263 U = this->dcmp.
get_U();
264 s = this->dcmp.
get_s();
266 n_active = accu(weights != 0.0);
269 Gamma_XX = Gamma_XX_;
270 Gamma_XY = Gamma_XY_;
271 inv_s_delta_Xu = mat( Xu.n_rows, Xu.n_cols);
275template <
typename T1,
typename T2,
typename T3>
280 const double &lambda,
289 U = this->dcmp.
get_U();
290 s = this->dcmp.
get_s();
293 Gamma_XX = X.t() * X - Xu.t() * Xu;
294 inv_s_delta_Xu = mat( Xu.n_rows, Xu.n_cols);
300template <
typename T1,
typename T2,
typename T3>
309 double h1_sum = delta_hat*(sum(1/(s+delta_hat)) + (n-k) / delta_hat);
315 vec w = (s/(delta_hat*s + pow(delta_hat,2)));
317 double h2_sum = delta_hat * trace(solve(A.t() * X, A.t() * A));
319 return h1_sum - h2_sum;
322template <
typename T1,
typename T2,
typename T3>
329 vec h1 = delta_hat * Usq * (1/(s+delta_hat)) + (1 - sum(Usq, 1));
334 vec w = (s/(delta_hat*s + pow(delta_hat,2)));
336 mat D_A_t = solve(A.t() * X, A.t());
337 vec h2 = delta_hat * sum(A % D_A_t.t(), 1);
345template <
typename T1,
typename T2,
typename T3>
348 return (Y / sqrt(weights)) -
fitted();
352template <
typename T1,
typename T2,
typename T3>
358 return ((U * (sqrt(s) %
blup())) + X * beta) / sqrt(weights);
362template <
typename T1,
typename T2,
typename T3>
378 return (sqrt(s) % ru) / (s + delta_hat);
381template <
typename T1,
typename T2,
typename T3>
385 double rank = Xu.n_rows;
387 inv_s_delta = 1 / (s+delta);
393 QXX = Xu.t() * inv_s_delta_Xu + Gamma_XX / delta;
398 for( uword i=1; i<QXX.n_rows; ++i){
403 QXY = Xu.t() * (inv_s_delta % Yu) + Gamma_XY / delta;
407 int status = solve(beta, QXX, QXY,
408 solve_opts::likely_sympd + arma::solve_opts::no_approx);
413 beta.set_size(QXX.n_rows);
414 beta.fill(datum::nan);
426 double QRR = dot(ru, (inv_s_delta % ru)) + (dot(r,r) - dot(ru,ru)) / delta;
430 double logLik = -n/2.0 * log(2.0*M_PI*sigSq_g) - 1.0/2.0 * (sum( log(s + delta ) ) + (n-rank) * log(delta)) - n/2.0;
442static inline double ll_alone_mat(
double delta_log,
void *arg){
448 fit->eval_delta( exp(delta_log) );
450 return -1.0*fit->get_logLik();
454static inline double ll_alone_spmat(
double delta_log,
void *arg){
456 auto *fit = (fastlmm<mat,mat,sp_mat> *) arg;
460 fit->eval_delta( exp(delta_log) );
462 return -1.0*fit->get_logLik();
466template <
typename T1,
typename T2,
typename T3>
469 double leftIn = left;
470 double rightIn = right;
488 this->logLik = -1*
local_min(leftIn, rightIn, tol, &F, res, iter);
493 this->logLik += sum(omit_nonfinite(log(weights)))/2.0;
495 this->delta_hat = exp(res);
499template <
typename T1,
typename T2,
typename T3>
501 const vec &weights_){
507template <
typename T1,
typename T2,
typename T3>
522 n_active = accu(weights_ != 0.0);
523 this->weights = weights_;
526 this->Gamma_XY = X.t() * Y - Xu.t() * Yu;
530template <
typename T1,
typename T2,
typename T3>
532 const bool &returnUS){
559 res.
se = sqrt(diagvec(V));
577 switch( dcmp.get_type() ){
579 res.
setUS(U, s, dcmp.get_V());
583 res.
setUS(U, s, eye<sp_mat>(U.n_cols, U.n_cols));
double rdf
Definition ModelFit.h:43
vec residuals
Definition ModelFit.h:46
mat vcov
Definition ModelFit.h:45
vec se
Definition ModelFit.h:41
vec hatvalues
Definition ModelFit.h:47
Definition ModelFit.h:201
double sigSq_g
Definition ModelFit.h:206
mat A_sat
Definition ModelFit.h:219
mat hessian_vc
Definition ModelFit.h:219
vec y
Definition ModelFit.h:205
double sigSq_e
Definition ModelFit.h:206
mat B_sat
Definition ModelFit.h:219
void setUS(const mat &U_, const vec &s_, const mat &V_)
Definition ModelFit.h:317
Definition satterthwaite.h:27
const mat get_hessian() const
Definition satterthwaite.h:48
const mat get_A() const
Definition satterthwaite.h:49
const mat get_B() const
Definition satterthwaite.h:50
Definition fastlmm_fit.h:35
void set_model_failure()
Definition fastlmm_fit.h:105
const double get_delta() const
Definition fastlmm_fit.h:96
double ll(const double &delta)
Definition fastlmm_fit.h:382
fastlmm()
Definition fastlmm_fit.h:39
vec get_r()
Definition fastlmm_fit.h:145
const double get_logLik() const
Definition fastlmm_fit.h:89
vec get_y()
Definition fastlmm_fit.h:146
double score_test(const vec &x_)
void estimate_delta(const double &left, const double &right, const double &tol)
Definition fastlmm_fit.h:467
void update_Y(const T1 &Y_)
const mat get_vcov() const
Definition fastlmm_fit.h:97
void update_X(const vec &X_)
vec get_weights()
Definition fastlmm_fit.h:142
const vec fitted() const
Definition fastlmm_fit.h:353
const vec blup() const
Definition fastlmm_fit.h:363
const vec residuals() const
Definition fastlmm_fit.h:346
const double get_sigSq_e() const
Definition fastlmm_fit.h:92
ModelFitLMM get_result(const bool &returnUS=false)
Definition fastlmm_fit.h:531
const double get_rdf() const
Definition fastlmm_fit.h:301
vec get_ru()
Definition fastlmm_fit.h:144
const vec get_beta() const
Definition fastlmm_fit.h:90
const vec hatvalues() const
Definition fastlmm_fit.h:323
const double get_sigSq_g() const
Definition fastlmm_fit.h:91
void update_response(const T1 &Y_, const vec &weights_)
Definition fastlmm_fit.h:500
const mat get_beta_se() const
Definition fastlmm_fit.h:100
const int get_iter() const
Definition fastlmm_fit.h:95
void eval_delta(const double &delta)
Definition fastlmm_fit.h:128
Definition spectralDecomp.h:29
void reweight(const vec &weights, const bool &sort=false)
Definition spectralDecomp.h:76
vec get_s() const
Definition spectralDecomp.h:107
T get_U() const
Definition spectralDecomp.h:102
double local_min(double a, double b, double t, funcStruct *f, double &x, int &calls)
Definition local_min.h:25
mat scaleEachCol(const mat &X, const vec &w)
Definition misc.h:16
bool isSpMatrix(const T &t)
Definition misc.h:64
Definition CleanData.h:17
@ CATEGORICAL
Definition spectralDecomp.h:20
@ GENERAL
Definition spectralDecomp.h:19
ModelDetail
Definition ModelFit.h:26
@ MOST
Definition ModelFit.h:31
@ MEDIUM
Definition ModelFit.h:29
@ LEAST
Definition ModelFit.h:27
@ HIGH
Definition ModelFit.h:30
@ MAX
Definition ModelFit.h:32
@ LOW
Definition ModelFit.h:28
Definition local_min.h:18
double(* function)(double x, void *params)
Definition local_min.h:19
void * params
Definition local_min.h:20