fastglmm
Massively scalable generalized linear mixed models
Loading...
Searching...
No Matches
fastlmm_fit.h
Go to the documentation of this file.
1/***************************************************************
2 * @file fastlmm_fit.h
3 * @author Gabriel Hoffman
4 * @email gabriel.hoffman@mssm.edu
5 * @brief Fit linear mixed model
6 * Copyright (C) 2024 Gabriel Hoffman
7 **************************************************************/
8
9#ifndef _FASTLMM_FIT_H_
10#define _FASTLMM_FIT_H_
11
12// if -D USE_R, use RcppArmadillo library
13#ifdef USE_R
14// [[Rcpp::depends(RcppParallel)]]
15#include <RcppArmadillo.h>
16#else
17#include <armadillo>
18#endif
19
20using namespace arma;
21
22#include "local_min.h"
23#include "misc.h"
24#include "ModelFit.h"
25#include "spectralDecomp.h"
26#include "satterthwaite.h"
27
28namespace fastglmmLib {
29
30// Order of template variables
31// T1 Y
32// T2 X
33// T3 U
34template <typename T1, typename T2, typename T3>
35class fastlmm {
36 public:
37
38 // constructor, minimal
40
41 fastlmm(const T1 &Y_,
42 const T2 &X_,
43 const spectralDecomp<T3> &dcmp,
44 const vec &weights_,
45 const ModelDetail md = LOW,
46 const double &lambda = 0,
47 const bool REML = false);
48
49 // constructor, precompute Yu, Xu
50 fastlmm(const T1 &Y_,
51 const T2 &X_,
52 const spectralDecomp<T3> &dcmp,
53 const vec &weights_,
54 const vec &Yu_,
55 const mat &Xu_,
56 const ModelDetail md = LOW,
57 const double &lambda = 0,
58 const bool REML = false);
59
60 // constructor, precompute Yu, Xu, Gamma_XX, Gamma_XY
61 fastlmm(const T1 &Y_,
62 const T2 &X_,
63 const spectralDecomp<T3> &dcmp,
64 const vec &weights_,
65 const vec &Yu_,
66 const mat &Xu_,
67 const mat &Gamma_XX_,
68 const mat &Gamma_XY_,
69 const ModelDetail md = LOW,
70 const double &lambda = 0,
71 const bool REML = false);
72
73 // constructor without response
74 fastlmm(const T2 &X_,
75 const spectralDecomp<T3> &dcmp,
76 const ModelDetail md = LOW,
77 const double &lambda = 0,
78 const bool REML = false);
79
80 void update_response(const T1 &Y_, const vec &weights_);
81 void update_response(const T1 &Y_,
82 const vec &weights_,
83 const mat &Yu_);
84
85 // extract results
86 ModelFitLMM get_result(const bool &returnUS = false);
87
88 // Accessors
89 const double get_logLik() const { return this->logLik; }
90 const vec get_beta() const { return this->beta; }
91 const double get_sigSq_g() const { return sigSq_g;}
92 const double get_sigSq_e() const {
93 return this->delta_hat * this->sigSq_g;
94 }
95 const int get_iter() const { return this->iter;}
96 const double get_delta() const { return this->delta_hat;}
97 const mat get_vcov() const {
98 return this->sigSq_g * inv_sympd(this->QXX, inv_opts::allow_approx) ;
99 }
100 const mat get_beta_se() const {
101 return sqrt(diagvec(get_vcov()));
102 }
103
104 // if model fails, set beta to nan
106 beta.fill(datum::nan);
107 }
108
111 const double get_rdf() const;
112
113 const vec hatvalues() const; // diag of hat matrix
114 const vec residuals() const;
115 const vec fitted() const;
116
117 // Best linear unbiased predictor of random effect
118 // same as ranef() in R
119 const vec blup() const;
120
121 // compute log likelihood
122 double ll(const double &delta);
123
124 void estimate_delta( const double &left,
125 const double &right,
126 const double &tol);
127 // evaluate logLik, beta, etc at delta value
128 void eval_delta( const double &delta){
129 this->logLik = ll( delta );
130 this->delta_hat = delta;
131 }
132
133 // Score test
134 double score_test(const vec &x_);
135
136 // Update Y, keeping rest constant
137 void update_Y( const T1 &Y_);
138
139 // Update X, keeping rest constant
140 void update_X( const vec &X_);
141
142 vec get_weights(){ return weights;}
143
144 vec get_ru(){ return ru;}
145 vec get_r(){ return r;}
146 vec get_y(){ return Y;}
147
148 private:
149 vec wsqrt;
150 T1 Y;
151 T2 X;
152 T3 U;
153 vec s, weights;
154 T1 Yu;
155 T2 Xu;
156 mat Gamma_XX, Gamma_XY;
157 vec inv_s_delta;
158 mat inv_s_delta_Xu;
159 mat QXX, QXY;
160 vec beta;
161 vec r, ru;
162 ModelDetail md;
163 double lambda;
164 bool REML;
165 double logLik, sigSq_g, delta_hat;
166 int iter = 0;
167 int n_active; // sample size with non-zero weight
169};
170
171
172// constructor, minimal
173template <typename T1, typename T2, typename T3>
175 const T1 &Y_,
176 const T2 &X_,
177 const spectralDecomp<T3> &dcmp,
178 const vec &weights_,
179 const ModelDetail md,
180 const double &lambda,
181 const bool REML):
182 wsqrt(sqrt(weights_)),
183 Y(Y_ % wsqrt),
184 X(scaleEachCol(X_, wsqrt)),
185 weights(weights_),
186 md(md),
187 lambda(lambda),
188 REML(REML),
189 dcmp(dcmp) {
190
191 this->dcmp.reweight(weights);
192 U = this->dcmp.get_U();
193 s = this->dcmp.get_s();
194
195 n_active = accu(weights != 0.0);
196 Yu = U.t() * Y;
197 Xu = U.t() * X;
198 Gamma_XX = X.t() * X - Xu.t() * Xu;
199 Gamma_XY = X.t() * Y - Xu.t() * Yu;
200 inv_s_delta_Xu = mat( Xu.n_rows, Xu.n_cols);
201}
202
203
204
205
206// constructor, precompute Yu, Xu
207template <typename T1, typename T2, typename T3>
209 const T1 &Y_,
210 const T2 &X_,
211 const spectralDecomp<T3> &dcmp,
212 const vec &weights_,
213 const vec &Yu_,
214 const mat &Xu_,
215 const ModelDetail md,
216 const double &lambda,
217 const bool REML):
218 wsqrt(sqrt(weights_)),
219 Y(Y_ % wsqrt),
220 X(scaleEachCol(X_, wsqrt)),
221 dcmp(dcmp),
222 weights(weights_),
223 md(md),
224 lambda(lambda),
225 REML(REML) {
226
227 this->dcmp.reweight(weights);
228 U = this->dcmp.get_U();
229 s = this->dcmp.get_s();
230
231 n_active = accu(weights != 0.0);
232 Yu = Yu_;
233 Xu = Xu_;
234 Gamma_XX = X.t() * X - Xu.t() * Xu;
235 Gamma_XY = X.t() * Y - Xu.t() * Yu;
236 inv_s_delta_Xu = mat( Xu.n_rows, Xu.n_cols);
237}
238
239// constructor, precompute Yu, Xu, Gamma_XX, Gamma_XY
240template <typename T1, typename T2, typename T3>
242 const T1 &Y_,
243 const T2 &X_,
244 const spectralDecomp<T3> &dcmp,
245 const vec &weights_,
246 const vec &Yu_,
247 const mat &Xu_,
248 const mat &Gamma_XX_,
249 const mat &Gamma_XY_,
250 const ModelDetail md,
251 const double &lambda,
252 const bool REML):
253 wsqrt(sqrt(weights_)),
254 Y(Y_ % wsqrt),
255 X(scaleEachCol(X_, wsqrt)),
256 dcmp(dcmp),
257 weights(weights_),
258 md(md),
259 lambda(lambda),
260 REML(REML) {
261
262 this->dcmp.reweight(weights);
263 U = this->dcmp.get_U();
264 s = this->dcmp.get_s();
265
266 n_active = accu(weights != 0.0);
267 Yu = Yu_;
268 Xu = Xu_;
269 Gamma_XX = Gamma_XX_;
270 Gamma_XY = Gamma_XY_;
271 inv_s_delta_Xu = mat( Xu.n_rows, Xu.n_cols);
272}
273
274
275template <typename T1, typename T2, typename T3>
277 const T2 &X_,
278 const spectralDecomp<T3> &dcmp,
279 const ModelDetail md,
280 const double &lambda,
281 const bool REML):
282 X(X_),
283 dcmp(dcmp),
284 md(md),
285 lambda(lambda),
286 REML(REML) {
287
288 this->dcmp.reweight(weights);
289 U = this->dcmp.get_U();
290 s = this->dcmp.get_s();
291
292 Xu = U.t() * X;
293 Gamma_XX = X.t() * X - Xu.t() * Xu;
294 inv_s_delta_Xu = mat( Xu.n_rows, Xu.n_cols);
295}
296
297
298
299
300template <typename T1, typename T2, typename T3>
301const double fastlmm<T1, T2, T3>::get_rdf() const {
302
303 int n = n_active;
304 int k = s.n_elem;
305
306 // X is already scaled
307 // sum(h1)
308 // h1.sum <- with(object, delta*sum(1/(s+delta))) + (n-k)
309 double h1_sum = delta_hat*(sum(1/(s+delta_hat)) + (n-k) / delta_hat);
310
311 // sum(h2)
312 // A <- with(object, X / delta - U %*% ((s/(delta*s + delta^2)) * crossprod(U, X)))
313 // D <- solve(crossprod(A, X))
314 // h2.sum <- object$delta * sum(A * (A %*% D))
315 vec w = (s/(delta_hat*s + pow(delta_hat,2)));
316 mat A = mat(X / delta_hat - U * scaleEachCol( Xu, w));
317 double h2_sum = delta_hat * trace(solve(A.t() * X, A.t() * A));
318
319 return h1_sum - h2_sum;
320}
321
322template <typename T1, typename T2, typename T3>
324
325 // Usq <- model$U^2
326 T3 Usq = square(U);
327
328 // h1 <- model$delta*with(model, Usq %*% (1/(s+delta))) + (1 - rowSums(Usq))
329 vec h1 = delta_hat * Usq * (1/(s+delta_hat)) + (1 - sum(Usq, 1));
330
331 // A <- with(model, X / delta - U %*% ((s/(delta*s + delta^2)) * crossprod(U, X)))
332 // D <- solve(crossprod(A, X))
333 // h2 <- model$delta * rowSums(A * (A %*% D))
334 vec w = (s/(delta_hat*s + pow(delta_hat,2)));
335 mat A = mat(X / delta_hat - U * scaleEachCol( Xu, w));
336 mat D_A_t = solve(A.t() * X, A.t());
337 vec h2 = delta_hat * sum(A % D_A_t.t(), 1);
338
339 // hatvalues
340 return 1 - h1 + h2;
341}
342
343
344
345template <typename T1, typename T2, typename T3>
347
348 return (Y / sqrt(weights)) - fitted();
349}
350
351// return predict(fit)
352template <typename T1, typename T2, typename T3>
354
355 // ** need to scale X because it was transformed at the start
356 // a <- object$U %*% (sqrt(object$s) * ranef.fastlmm(object))
357 // a / sqrt(object$weights) + object$design %*% coef(object)
358 return ((U * (sqrt(s) % blup())) + X * beta) / sqrt(weights);
359}
360
361
362template <typename T1, typename T2, typename T3>
363const vec fastlmm<T1, T2, T3>::blup() const {
364
365 // Zw <- c(sqrt(fit$weights)) * fit$Z
366 // A <- crossprod(fit$U, Zw)
367 // A <- with(fit, crossprod(fit$U, c(sqrt(weights)) * U * sqrt(s)))
368 // b <- fit$ru / (fit$s + fit$delta)
369 // crossprod(A, b)
370
371 // T3 A = U.t() * scaleRowsCols(U, sqrt(weights), sqrt(s));
372 // vec b = ru / (s + delta_hat);
373 // return A.t() * b;
374
375 // # since U^T U is identity if the GRM is full rank
376 // v <- with(object, sqrt(s)*ru / (s + delta))
377
378 return (sqrt(s) % ru) / (s + delta_hat);
379}
380
381template <typename T1, typename T2, typename T3>
382double fastlmm<T1, T2, T3>::ll(const double &delta ) {
383
384 double n = n_active;
385 double rank = Xu.n_rows;
386
387 inv_s_delta = 1 / (s+delta);
388
389 // inv_s_delta_Xu <- inv_s_delta * Xu
390 inv_s_delta_Xu = scaleEachCol(Xu, inv_s_delta);
391
392 // QXX = crossprod(Xu, inv_s_delta_Xu) + Gamma_XX / delta
393 QXX = Xu.t() * inv_s_delta_Xu + Gamma_XX / delta;
394
395 // Ridge penalty
396 // QXX.diag() += lambda;
397 // but not on intercept
398 for( uword i=1; i<QXX.n_rows; ++i){
399 QXX(i,i) += lambda;
400 }
401
402 // QXY = crossprod(Xu, inv_s_delta_Yu) + Gamma_XY / delta
403 QXY = Xu.t() * (inv_s_delta % Yu) + Gamma_XY / delta;
404
405 // beta <<- solve( QXX, QXY)
406 // beta = solve(QXX, QXY, solve_opts::likely_sympd);
407 int status = solve(beta, QXX, QXY,
408 solve_opts::likely_sympd + arma::solve_opts::no_approx);
409
410 // if model failed
411 if( ! status ){
412 // set beta to NAN
413 beta.set_size(QXX.n_rows);
414 beta.fill(datum::nan);
415 }
416
417 // # Eval sig_g
418 // ru <- Yu - Xu %*% beta
419 ru = Yu - Xu * beta;
420
421 // r <- Y - X %*% beta
422 r = Y - X * beta;
423
424 // Qrr <- crossprod(ru, inv_s_delta_ru) + (crossprod(r)[1] - crossprod(ru)[1])/ delta
425 // sig_g <<- Qrr[1] / n
426 double QRR = dot(ru, (inv_s_delta % ru)) + (dot(r,r) - dot(ru,ru)) / delta;
427 sigSq_g = QRR / n;
428
429 // use 2.0 to ensure double precision
430 double logLik = -n/2.0 * log(2.0*M_PI*sigSq_g) - 1.0/2.0 * (sum( log(s + delta ) ) + (n-rank) * log(delta)) - n/2.0;
431
432 // this is fixed, so don't eval every time,
433 // just after estimation
434 // + sum(log(weights))/2.0;
435
436 return logLik;
437}
438
439
440
441// function to be minimized
442static inline double ll_alone_mat( double delta_log, void *arg){
443
444 auto *fit = (fastlmm<mat,mat,mat> *) arg;
445
446 // search is done in log space
447 // to give faster convergence
448 fit->eval_delta( exp(delta_log) );
449
450 return -1.0*fit->get_logLik();
451}
452
453// sparse version
454static inline double ll_alone_spmat( double delta_log, void *arg){
455
456 auto *fit = (fastlmm<mat,mat,sp_mat> *) arg;
457
458 // search is done in log space
459 // to give faster convergence
460 fit->eval_delta( exp(delta_log) );
461
462 return -1.0*fit->get_logLik();
463}
464
465
466template <typename T1, typename T2, typename T3>
467void fastlmm<T1, T2, T3>::estimate_delta( const double &left, const double &right, const double &tol ){
468
469 double leftIn = left;
470 double rightIn = right;
471 iter = 0;
472
473 // initialize function
474 funcStruct F;
475 F.params = this;
476
477 // Since F.function can't take templated function
478 if( isSpMatrix( U ) ){
479 F.function = & ll_alone_spmat;
480 }else{
481 F.function = & ll_alone_mat;
482 }
483
484 // get maximize log-likelihood
485 // need to mutliply but -1 since it actually minimizes
486 // evaluated at minimum value
487 double res;
488 this->logLik = -1*local_min(leftIn, rightIn, tol, &F, res, iter);
489
490 // augment with value this is constant for varying delta's
491 // weights with zero value, give Inf log values
492 // so use omit_nonfinite
493 this->logLik += sum(omit_nonfinite(log(weights)))/2.0;
494
495 this->delta_hat = exp(res);
496}
497
498
499template <typename T1, typename T2, typename T3>
501 const vec &weights_){
502
503 update_response(Y, weights_, U.t() * Y_);
504}
505
506
507template <typename T1, typename T2, typename T3>
509 const vec &weights_,
510 const mat &Yu_){
511
512 // indicator_decomp
513 // modiy this->U and this->s internally
514 // compute sqrt(weights) for
515 // Y <- Y * sqrt(weights)
516 // X <- X * sqrt(weights)
517 // vec sqrtW = sqrt(weights_);
518 // update_weights( Y_, X_, U_, s_, weights_);
519 // Need to save X, U, s unmodified so it
520 // can be weighted later
521
522 n_active = accu(weights_ != 0.0);
523 this->weights = weights_;
524 this->Y = Y_;
525 this->Yu = Yu_;
526 this->Gamma_XY = X.t() * Y - Xu.t() * Yu;
527}
528
529
530template <typename T1, typename T2, typename T3>
532 const bool &returnUS){
533
534 // initialize with standard entries
535 ModelFitLMM res = ModelFitLMM( true,
536 get_logLik(),
537 get_weights(),
538 get_ru(),
539 get_y(),
540 get_delta(),
541 get_sigSq_g(),
542 get_sigSq_e(),
543 get_iter(),
544 1.0,
545 get_beta());
546
547 // set additional values based on ModelDetail md
548 mat V = this->get_vcov();
549
550 switch( md ){
551 case MAX:
552 case MOST:
553 res.hatvalues = this->hatvalues();
554 case HIGH:
555 res.residuals = this->residuals();
556 case MEDIUM:
557 res.vcov = V;
558 case LOW:
559 res.se = sqrt(diagvec(V));
560 res.rdf = this->get_rdf();
561 case LEAST:
562 break;
563 }
564
565 // Precompute values for Satterthwaite DDF to be
566 // used later with V and L specified
567 Satterthwaite ddf_sat(res.y.n_elem, res.sigSq_g, res.sigSq_e, s, Xu, Gamma_XX, inv_s_delta, inv_s_delta_Xu);
568
569 // save precomputed values
570 res.hessian_vc = ddf_sat.get_hessian();
571 res.A_sat = ddf_sat.get_A();
572 res.B_sat = ddf_sat.get_B();
573
574 // if returnUS
575 // return U and s
576 if( returnUS ){
577 switch( dcmp.get_type() ){
578 case GENERAL:
579 res.setUS(U, s, dcmp.get_V());
580 break;
581 case CATEGORICAL:
582 // V is identity
583 res.setUS(U, s, eye<sp_mat>(U.n_cols, U.n_cols));
584 break;
585 };
586 }
587
588 return res;
589}
590
591
592
593} // end namespace
594
595
596#endif
double rdf
Definition ModelFit.h:43
vec residuals
Definition ModelFit.h:46
mat vcov
Definition ModelFit.h:45
vec se
Definition ModelFit.h:41
vec hatvalues
Definition ModelFit.h:47
Definition ModelFit.h:201
double sigSq_g
Definition ModelFit.h:206
mat A_sat
Definition ModelFit.h:219
mat hessian_vc
Definition ModelFit.h:219
vec y
Definition ModelFit.h:205
double sigSq_e
Definition ModelFit.h:206
mat B_sat
Definition ModelFit.h:219
void setUS(const mat &U_, const vec &s_, const mat &V_)
Definition ModelFit.h:317
Definition satterthwaite.h:27
const mat get_hessian() const
Definition satterthwaite.h:48
const mat get_A() const
Definition satterthwaite.h:49
const mat get_B() const
Definition satterthwaite.h:50
Definition fastlmm_fit.h:35
void set_model_failure()
Definition fastlmm_fit.h:105
const double get_delta() const
Definition fastlmm_fit.h:96
double ll(const double &delta)
Definition fastlmm_fit.h:382
fastlmm()
Definition fastlmm_fit.h:39
vec get_r()
Definition fastlmm_fit.h:145
const double get_logLik() const
Definition fastlmm_fit.h:89
vec get_y()
Definition fastlmm_fit.h:146
double score_test(const vec &x_)
void estimate_delta(const double &left, const double &right, const double &tol)
Definition fastlmm_fit.h:467
void update_Y(const T1 &Y_)
const mat get_vcov() const
Definition fastlmm_fit.h:97
void update_X(const vec &X_)
vec get_weights()
Definition fastlmm_fit.h:142
const vec fitted() const
Definition fastlmm_fit.h:353
const vec blup() const
Definition fastlmm_fit.h:363
const vec residuals() const
Definition fastlmm_fit.h:346
const double get_sigSq_e() const
Definition fastlmm_fit.h:92
ModelFitLMM get_result(const bool &returnUS=false)
Definition fastlmm_fit.h:531
const double get_rdf() const
Definition fastlmm_fit.h:301
vec get_ru()
Definition fastlmm_fit.h:144
const vec get_beta() const
Definition fastlmm_fit.h:90
const vec hatvalues() const
Definition fastlmm_fit.h:323
const double get_sigSq_g() const
Definition fastlmm_fit.h:91
void update_response(const T1 &Y_, const vec &weights_)
Definition fastlmm_fit.h:500
const mat get_beta_se() const
Definition fastlmm_fit.h:100
const int get_iter() const
Definition fastlmm_fit.h:95
void eval_delta(const double &delta)
Definition fastlmm_fit.h:128
Definition spectralDecomp.h:29
void reweight(const vec &weights, const bool &sort=false)
Definition spectralDecomp.h:76
vec get_s() const
Definition spectralDecomp.h:107
T get_U() const
Definition spectralDecomp.h:102
double local_min(double a, double b, double t, funcStruct *f, double &x, int &calls)
Definition local_min.h:25
mat scaleEachCol(const mat &X, const vec &w)
Definition misc.h:16
bool isSpMatrix(const T &t)
Definition misc.h:64
Definition CleanData.h:17
@ CATEGORICAL
Definition spectralDecomp.h:20
@ GENERAL
Definition spectralDecomp.h:19
ModelDetail
Definition ModelFit.h:26
@ MOST
Definition ModelFit.h:31
@ MEDIUM
Definition ModelFit.h:29
@ LEAST
Definition ModelFit.h:27
@ HIGH
Definition ModelFit.h:30
@ MAX
Definition ModelFit.h:32
@ LOW
Definition ModelFit.h:28
Definition local_min.h:18
double(* function)(double x, void *params)
Definition local_min.h:19
void * params
Definition local_min.h:20