fastglmm
Massively scalable generalized linear mixed models
Loading...
Searching...
No Matches
fastglmm_fit.h
Go to the documentation of this file.
1/***************************************************************
2 * @file fastglmm_fit.h
3 * @author Gabriel Hoffman
4 * @email gabriel.hoffman@mssm.edu
5 * @brief Fit generalized linear mixed model
6 * Copyright (C) 2024 Gabriel Hoffman
7 **************************************************************/
8
9#ifndef _FASTGLMM_FIT_H_
10#define _FASTGLMM_FIT_H_
11
12// if -D USE_R, use RcppArmadillo library
13#ifdef USE_R
14// [[Rcpp::depends(RcppParallel)]]
15#include <RcppArmadillo.h>
16#else
17#include <armadillo>
18#endif
19
20#include "fastlmm_fit.h"
21#include "glm_family.h"
22#include "glm.h"
23#include "ModelFit.h"
24#include "spectralDecomp.h"
25
26using namespace arma;
27using namespace std;
28using namespace fastglmmLib;
29
30namespace fastglmmLib {
31
32// Order of template variables
33// T1 y: type of the response argument (copied into the vec member y)
34// T2 X: type of the design matrix argument (copied into the mat member X)
35// T3 U: type parameter of spectralDecomp<T3> (presumably the eigenvector matrix U; TODO confirm)
/// @brief Generalized linear mixed model fit by penalized quasi-likelihood (PQL).
///
/// The constructor performs the whole fit. Each PQL iteration turns the GLM
/// into a weighted linear mixed model on the working response z with working
/// weights w, reweights the spectral decomposition with w, and fits it with
/// fastlmm. Results are then queried through the accessors below.
36template <typename T1, typename T2, typename T3>
37class fastglmm {
38 public:
39
40 // constructor, minimal
42
// Fit the model; see the out-of-line definition for parameter details.
43 fastglmm( const T1 &y,
44 const T2 &X,
45 const spectralDecomp<T3> &dcmp,
46 const vec &weights,
47 const vec &offset,
48 const string &family,
49 const ModelDetail md = LOW,
50 const double &tol = 1e-5,
51 const double &tol_eta = 1e-7,
52 const int &maxit = 100,
53 const double &lambda = 0,
54 const double &delta = -1,
55 const double &left = -10,
56 const double &right = 10,
57 const bool &returnUS = false,
58 const bool &doCoxReid = false);
59
60 const vec residuals() const ; // Response residuals: y - mu
61 const vec residuals_pearson() const ; // Pearson residuals: (y - mu) * sqrt(w) / sqrt(V(mu))
62 const vec fitted() const ; // fitted values on the response scale (via the inverse link)
63 const vec devianceResiduals() const; // signed deviance residuals, as in R's residuals.glm
64
65 // extract results as a ModelFitGLMM (see get_result())
67
68 private:
70 vec y; // response, copied from the T1 argument
71 mat X; // design matrix, copied from the T2 argument
72 vec weights, mu, eta, offset; // prior weights; final fitted mean and linear predictor (both include offset, NaN if the fit failed); offset
74 string family; // family string; "nb" becomes "nb:<theta>" after theta is estimated
75 double lambda; // penalty forwarded to GLM() and fastlmm
76 bool returnUS; // forwarded to fit.get_result() in get_result()
77 int niter_pql; // PQL loop counter; holds the iteration at which the loop stopped
78 double w_mean; // mean working weight over nonzero-weight observations (last iteration), before rescaling to mean 1
79 double mu_mean = datum::nan; // robust_mean(mu, 4) of the final fitted mean
80 double y_mean = datum::nan; // NOTE(review): never assigned in this file; get_result() uses mean(y) directly
81 double eta_var = datum::nan; // var(eta) of the final linear predictor
82 ModelDetail md; // detail level of the final fit and of get_result()
83 shared_ptr<GLMFamily> fam; // family object from getGLMFamily(); theta is set on it when family == "nb"
84 bool isValid = true; // set to false when the working response or weights contain NaN
85};
86
87template <typename T1, typename T2, typename T3>
89 const T1 &y,
90 const T2 &X,
91 const spectralDecomp<T3> &dcmp,
92 const vec &weights,
93 const vec &offset,
94 const string &family,
95 const ModelDetail md,
96 const double &tol,
97 const double &tol_eta,
98 const int &maxit,
99 const double &lambda,
100 const double &delta,
101 const double &left,
102 const double &right,
103 const bool &returnUS,
104 const bool &doCoxReid):
105 y(y),
106 X(X),
107 weights(weights),
108 offset(offset),
109 dcmp(dcmp),
110 family(family),
111 lambda(lambda),
112 returnUS(returnUS),
113 md(md)
114 {
115
116 fam = getGLMFamily( family );
117
118 // if Negative Binomial with unspecified theta
119 // estimate theta, and initialize with Poisson GLM
120 bool estimateTheta = family == "nb" ? true : false;
121 if( estimateTheta ){
122 this->family = "poisson/log";
123 }
124
125 checkResponse(y, this->family);
126
127 GLMWork *work = new GLMWork();
128 vec eta_old;
129
130 // Initialize eta
131 // just need a rough starting value
132 ModelFitGLM fit_init = GLM(X, y, this->family, LEAST, weights, offset, work, {}, 1e-2, 3, lambda);
133
134 int iter_in = 0;
135 double theta;
136 uvec idx_drop = find(weights == 0.0);
137 double n_active = weights.n_elem - idx_drop.n_elem;
138
139 CountTable ct;
140 if( estimateTheta ){
141 // Precompute lgamma() on each unique count
142 ct = CreateLUT(y, weights);
143 }
144
145 // PQL iterations
146 for(niter_pql=0; niter_pql<maxit; niter_pql++){
147
148 // if Negative Binomial with unspecified theta
149 if( estimateTheta ){
150 theta = nb_theta_ml(y, work->mu, y.n_elem, weights, X, doCoxReid, ct, -5, 20);
151 fam->setOverdispersion( theta );
152 }
153
154 // update mu, eta, z, w, eta,
155 if( niter_pql == 0){
156 work->eta = work->eta + offset;
157 }else{
158 eta_old = work->eta;
159 work->eta = fit.fitted() + offset;
160
161 // convergence criteria based on norm of eta change
162 if( norm(work->eta - eta_old) < tol_eta){
163 break;
164 }
165 }
166
167 // entries with zero weights have NAN value
168 work->eta.elem( idx_drop ).zeros();
169
170 // mu <- family$linkinv(eta)
171 work->mu = fam->linkinv( work->eta );
172
173 // mu.eta.val <- family$mu.eta(eta)
174 work->gprime = fam->mu_eta( work->eta );
175
176 // zz <- eta + (y.orig - mu)/mu.eta.val - offset
177 work->z = (work->eta - offset) + (y - work->mu) / work->gprime;
178
179 // wz <- w * mu.eta.val^2/family$variance(mu)
180 work->w = square(work->gprime) % (weights / fam->variance( work->mu ));
181
182 // wz <- wz / mean(wz)
183 w_mean = sum(work->w) / n_active;
184 work->w = work->w / w_mean;
185
186 // if model has nan values in z or w,
187 // it can't be fit
188 // so set beta values to nan
189 // and set isValid to false
190 if( work->z.has_nan() || work->w.has_nan() ){
191 fit.set_model_failure();
192 isValid = false;
193 break;
194 }
195
196 // recompute U and s since work->w changed
197 this->dcmp.reweight(work->w);
198
199 // fit fastlmm
200 fit = fastlmm(work->z, X, this->dcmp, work->w, LEAST, lambda);
201
202 if( delta > 0 ){
203 fit.eval_delta( delta );
204 }else{
205 fit.estimate_delta(left, right, tol);
206 }
207
208 // increment interation count
209 iter_in += fit.get_iter();
210 }
211
212 // Final fit with ModelDetail md
213 if( isValid && (md > LEAST) ){
214
215 // fit fastlmm
216 fit = fastlmm(work->z, X, this->dcmp, work->w, md, lambda);
217
218 if( delta > 0 ){
219 fit.eval_delta( delta );
220 }else{
221 fit.estimate_delta(left, right, tol);
222 }
223 }
224
225 if( estimateTheta ){
226 // update family to include estimated theta
227 this->family = "nb:" + to_string(theta);
228 }
229
230 // Use result of lmm and inverse link
231 // to get final value of mu
232 if( isValid ){
233 eta = fit.fitted() + offset;
234 mu = fam->linkinv(eta);
235 }else{
236 eta = vec(offset.n_elem, fill::value(datum::nan));
237 mu = vec(offset.n_elem, fill::value(datum::nan));
238 }
239
240 // Compute mean of mu
241 // use robust mean to avoid influence of outliers
242 // This can happen with many zero and a few large values
243 mu_mean = robust_mean(mu, 4);
244
245 eta_var = var(eta);
246
247 delete work;
248}
249
250
251
252template <typename T1, typename T2, typename T3>
254
255 // Response residuals
256 return( y - mu );
257}
258
259
260template <typename T1, typename T2, typename T3>
262
263 // Pearson residuals
264 // (y - mu) * sqrt(wts) / sqrt(fam$variance(mu))
265 return (y - mu) % sqrt(weights) / sqrt(fam->variance(mu));
266}
267
268
269template <typename T1, typename T2, typename T3>
271
272 return fam->linkinv( fit.fitted() );
273}
274
275
276template <typename T1, typename T2, typename T3>
278
279 // transform from residuals.glm
280 // d.res <- sqrt(pmax((object$family$dev.resids)(y, mu,
281 // wts), 0))
282 // ifelse(y > mu, d.res, -d.res)
283
284 // compute raw deviance residuals
285 vec dr = fam->dev_resids(y, mu, weights);
286
287 vec drMod = sqrt(pmax(dr, 0));
288 uvec idx = find(y <= mu);
289 drMod.elem(idx) = -1.0*drMod.elem(idx);
290
291 return drMod;
292}
293
294
295template <typename T1, typename T2, typename T3>
297
298 ModelFitLMM res1 = fit.get_result(returnUS);
299 res1.set_w_mean( w_mean );
300
301 // if model is not valid, it failed before final calculations
302 // so set values to NAN matching ModelDetail
303 if( ! isValid ){
304 // number of coefs
305 int p = X.n_cols;
306 res1.coef = vec(p, fill::value(datum::nan));
307
308 switch( md ){
309 case MAX:
310 case MOST:
311 res1.mu.fill(datum::nan);
312 res1.hatvalues.fill(datum::nan);
313 case HIGH:
314 res1.residuals.fill(datum::nan);
315 case MEDIUM:
316 res1.vcov = mat(p, p, fill::value(datum::nan));
317 case LOW:
318 res1.se = vec(p, fill::value(datum::nan));
319 res1.rdf = datum::nan;
320 case LEAST:
321 break;
322 }
323 }
324
325 ModelFitGLMM mf(res1, family, niter_pql);
326
327 mf.mu_mean = mu_mean;
328 mf.y_mean = mean(y);
329 mf.varFitted = eta_var;
330
331 // QL dispersion based on Pearson residuals
332 // sum((w*r^2)[w > 0]) / df.r
333 // need to drop elements without nan
334 vec rp = residuals_pearson();
335 vec rp_clean = rp.elem(find_finite(rp));
336 double disp = dot(rp_clean, rp_clean) / res1.rdf ;
337
338 if( fam->estimateDispersion() ){
339 mf.dispersion = disp;
340 }else{
341 mf.dispersion = 1.0;
342 // unscale variances by dispersion
343 mf.vcov /= disp;
344 mf.se /= sqrt(disp);
345 }
346
347 if( md == MAX ){
349 }
350
351 if( md >= HIGH ){
352 // response residuals
353 mf.residuals = residuals();
354 }
355
356 if( md >= MOST ){
357 // Respones residuals
358 mf.mu = fitted();
359 // mf.hatvalues = hatvalues();
360 }
361
362 return mf;
363}
364
365
366} // end namespace
367#endif
Definition ModelFit.h:165
Definition ModelFit.h:355
double mu_mean
Definition ModelFit.h:387
double y_mean
Definition ModelFit.h:388
vec coef
Definition ModelFit.h:40
double rdf
Definition ModelFit.h:43
vec residuals
Definition ModelFit.h:46
vec devianceResiduals
Definition ModelFit.h:49
mat vcov
Definition ModelFit.h:45
vec se
Definition ModelFit.h:41
vec mu
Definition ModelFit.h:48
double varFitted
Definition ModelFit.h:50
double dispersion
Definition ModelFit.h:42
vec hatvalues
Definition ModelFit.h:47
Definition ModelFit.h:201
void set_w_mean(const double &value)
Definition ModelFit.h:349
ModelFitGLMM get_result()
Definition fastglmm_fit.h:296
fastglmm()
Definition fastglmm_fit.h:41
const vec fitted() const
Definition fastglmm_fit.h:270
const vec devianceResiduals() const
Definition fastglmm_fit.h:277
const vec residuals() const
Definition fastglmm_fit.h:253
const vec residuals_pearson() const
Definition fastglmm_fit.h:261
Definition fastlmm_fit.h:35
Definition spectralDecomp.h:29
void reweight(const vec &weights, const bool &sort=false)
Definition spectralDecomp.h:76
Definition CleanData.h:17
ModelDetail
Definition ModelFit.h:26
@ MOST
Definition ModelFit.h:31
@ MEDIUM
Definition ModelFit.h:29
@ LEAST
Definition ModelFit.h:27
@ HIGH
Definition ModelFit.h:30
@ MAX
Definition ModelFit.h:32
@ LOW
Definition ModelFit.h:28
unordered_map< long, double > CountTable
Definition nb_theta.h:18
Definition glm.h:80
vec w
Definition glm.h:81
vec gprime
Definition glm.h:81
vec z
Definition glm.h:81
vec eta
Definition glm.h:81
vec mu
Definition glm.h:81