fastglmm
Massively scalable generalized linear mixed models
Loading...
Searching...
No Matches
fastglmm_fit.h
Go to the documentation of this file.
1/***************************************************************
2 * @file fastglmm_fit.h
3 * @author Gabriel Hoffman
4 * @email gabriel.hoffman@mssm.edu
5 * @brief Fit generalized linear mixed model
6 * Copyright (C) 2024 Gabriel Hoffman
7 **************************************************************/
8
9#ifndef _FASTGLMM_FIT_H_
10#define _FASTGLMM_FIT_H_
11
12// if -D USE_R, use RcppArmadillo library
13#ifdef USE_R
14// [[Rcpp::depends(RcppParallel)]]
15#include <RcppArmadillo.h>
16#else
17#include <armadillo>
18#endif
19
20#include "fastlmm_fit.h"
21#include "glm_family.h"
22#include "glm.h"
23#include "ModelFit.h"
24#include "spectralDecomp.h"
25
26using namespace arma;
27using namespace std;
28using namespace fastglmmLib;
29
30namespace fastglmmLib {
31
32// Order of template variables
33// T1 y
34// T2 X
35// T3 U
// Generalized linear mixed model fit by repeated weighted fastlmm fits
// (the "PQL iterations" in the constructor). Template order: T1 = y, T2 = X, T3 = U.
// NOTE(review): listing lines 41, 66, 69 and 73 are absent from this extract;
// the definitions below use members `fit` and `dcmp` and a get_result() method
// that are presumably declared on those lines — confirm against the real header.
36template <typename T1, typename T2, typename T3>
37class fastglmm {
38 public:
39
40 // constructor, minimal: runs the whole fit; results are read back through the accessors below
42
43 fastglmm( const T1 &y, // response; copied into the member `y`
44 const T2 &X, // covariates; X.n_cols is used as the number of coefficients
45 const spectralDecomp<T3> &dcmp, // copied, then reweighted by the working weights each iteration
46 const vec &weights, // entries equal to 0.0 are excluded from the active-sample count
47 const vec &offset, // added to the fitted linear predictor
48 const string &family, // passed to getGLMFamily(); "nb" triggers estimation of theta
49 const ModelDetail md = LOW, // detail level of the final fit (intermediate fits use LEAST)
50 const double &tol = 1e-5, // tolerance passed to fit.estimate_delta()
51 const double &tol_eta = 1e-7, // stop when norm of the change in eta falls below this
52 const int &maxit = 100, // upper bound on the number of PQL iterations
53 const double &lambda = 0, // forwarded to GLM() and fastlmm()
54 const double &delta = -1, // if > 0, passed to fit.eval_delta(); otherwise delta is estimated
55 const double &left = -10, // passed with `right` to fit.estimate_delta()
56 const double &right = 10,
57 const bool &returnUS = false, // forwarded to fit.get_result()
58 const bool &doCoxReid = false); // forwarded to nb_theta_ml()
59
60 const vec residuals() const ; // Response residuals: y - mu
61 const vec residuals_pearson() const ; // Pearson residuals: (y - mu) * sqrt(weights) / sqrt(variance(mu))
62 const vec fitted() const ; // fitted values after applying the inverse link
63 const vec devianceResiduals() const; // signed square root of the family's deviance residuals
64
65 // extract results
67
68 private:
70 vec y; // copy of the response
71 mat X; // copy of the covariates
72 vec weights, mu, eta, offset; // mu/eta are set at the end of the constructor (NaN-filled if the fit failed)
74 string family; // may be rewritten by the constructor: "poisson/log" while fitting, then "nb:<theta>"
75 double lambda;
76 bool returnUS;
77 int niter_pql; // loop counter of the PQL loop; passed to ModelFitGLMM in get_result()
78 double w_mean; // sum of working weights divided by the number of non-zero-weight entries
79 double mu_mean = datum::nan; // robust_mean(mu, 4)
80 double y_mean = datum::nan; // NOTE(review): never assigned in the visible code; get_result() uses mean(y) directly
81 double eta_var = datum::nan; // var(eta)
82 ModelDetail md;
83 shared_ptr<GLMFamily> fam; // obtained from getGLMFamily(family)
84 bool isValid = true; // set to false when the working response or weights contain NaN
85};
86
87template <typename T1, typename T2, typename T3>
89 const T1 &y,
90 const T2 &X,
91 const spectralDecomp<T3> &dcmp,
92 const vec &weights,
93 const vec &offset,
94 const string &family,
95 const ModelDetail md,
96 const double &tol,
97 const double &tol_eta,
98 const int &maxit,
99 const double &lambda,
100 const double &delta,
101 const double &left,
102 const double &right,
103 const bool &returnUS,
104 const bool &doCoxReid):
105 y(y),
106 X(X),
107 weights(weights),
108 offset(offset),
109 dcmp(dcmp),
110 family(family),
111 lambda(lambda),
112 returnUS(returnUS),
113 md(md)
114 {
115
116 fam = getGLMFamily( family );
117
118 // if Negative Binomial with unspecified theta
119 // estimate theta, and initialize with Poisson GLM
120 bool estimateTheta = family == "nb" ? true : false;
121 if( estimateTheta ){
122 this->family = "poisson/log";
123 }
124
125 checkResponse(y, this->family);
126
127 GLMWork *work = new GLMWork();
128 vec eta_old;
129
130 // Initialize eta
131 // just need a rough starting value
132 ModelFitGLM fit_init = GLM(X, y, this->family, LEAST, weights, offset, work, {}, 1e-2, 3, lambda);
133
134 int iter_in = 0;
135 double theta;
136 uvec idx_drop = find(weights == 0.0);
137 double n_active = weights.n_elem - idx_drop.n_elem;
138
139 CountTable ct;
140 if( estimateTheta ){
141 // Precompute lgamma() on each unique count
142 ct = CreateLUT(y, weights);
143 }
144
145 // PQL iterations
146 for(niter_pql=0; niter_pql<maxit; niter_pql++){
147
148 // if Negative Binomial with unspecified theta
149 if( estimateTheta ){
150 theta = nb_theta_ml(y, work->mu, y.n_elem, weights, X, doCoxReid, ct, -5, 20);
151 fam->setOverdispersion( theta );
152 }
153
154 // update mu, eta, z, w, eta,
155 if( niter_pql == 0){
156 work->eta = work->eta + offset;
157 }else{
158 eta_old = work->eta;
159 work->eta = fit.fitted() + offset;
160
161 // convergence criteria based on norm of eta change
162 if( norm(work->eta - eta_old) < tol_eta){
163 break;
164 }
165 }
166
167 // entries with zero weights have NAN value
168 work->eta.elem( idx_drop ).zeros();
169
170 // mu <- family$linkinv(eta)
171 work->mu = fam->linkinv( work->eta );
172
173 // mu.eta.val <- family$mu.eta(eta)
174 work->gprime = fam->mu_eta( work->eta );
175
176 // zz <- eta + (y.orig - mu)/mu.eta.val - offset
177 work->z = (work->eta - offset) + (y - work->mu) / work->gprime;
178
179 // wz <- w * mu.eta.val^2/family$variance(mu)
180 work->w = square(work->gprime) % (weights / fam->variance( work->mu ));
181
182 // wz <- wz / mean(wz)
183 w_mean = sum(work->w) / n_active;
184 work->w = work->w / w_mean;
185
186 // if model has nan values in z or w,
187 // it can't be fit
188 // so set beta values to nan
189 // and set isValid to false
190 if( work->z.has_nan() || work->w.has_nan() ){
191 fit.set_model_failure();
192 isValid = false;
193 break;
194 }
195
196 // recompute U and s since work->w changed
197 this->dcmp.reweight(work->w);
198
199 // fit fastlmm
200 fit = fastlmm(work->z, X, this->dcmp, work->w, LEAST, lambda);
201
202 if( delta > 0 ){
203 fit.eval_delta( delta );
204 }else{
205 fit.estimate_delta(left, right, tol);
206 }
207
208 // increment interation count
209 iter_in += fit.get_iter();
210 }
211
212 // Final fit with ModelDetail md
213 if( isValid && (md > LEAST) ){
214
215 // fit fastlmm
216 fit = fastlmm(work->z, X, this->dcmp, work->w, md, lambda);
217
218 if( delta > 0 ){
219 fit.eval_delta( delta );
220 }else{
221 fit.estimate_delta(left, right, tol);
222 }
223 }
224
225 if( estimateTheta ){
226 // update family to include estimated theta
227 this->family = "nb:" + to_string(theta);
228 }
229
230 // Use result of lmm and inverse link
231 // to get final value of mu
232 if( isValid ){
233 eta = fit.fitted() + offset;
234 mu = fam->linkinv(eta);
235 }else{
236 eta = vec(offset.n_elem, fill::value(datum::nan));
237 mu = vec(offset.n_elem, fill::value(datum::nan));
238 }
239 // Compute mean of mu
240 // use robust mean to avoid influence of outliers
241 // This can happen with many zero and a few large values
242 mu_mean = robust_mean(mu, 4);
243
244 eta_var = var(eta);
245
246 delete work;
247}
248
249
250
251template <typename T1, typename T2, typename T3>
253
254 // Response residuals
255 return( y - mu );
256}
257
258
259template <typename T1, typename T2, typename T3>
261
262 // Pearson residuals
263 // (y - mu) * sqrt(wts) / sqrt(fam$variance(mu))
264 return (y - mu) % sqrt(weights) / sqrt(fam->variance(mu));
265}
266
267
268template <typename T1, typename T2, typename T3>
270
271 return fam->linkinv( fit.fitted() );
272}
273
274
275template <typename T1, typename T2, typename T3>
277
278 // transform from residuals.glm
279 // d.res <- sqrt(pmax((object$family$dev.resids)(y, mu,
280 // wts), 0))
281 // ifelse(y > mu, d.res, -d.res)
282
283 // compute raw deviance residuals
284 vec dr = fam->dev_resids(y, mu, weights);
285
286 vec drMod = sqrt(pmax(dr, 0));
287 uvec idx = find(y <= mu);
288 drMod.elem(idx) = -1.0*drMod.elem(idx);
289
290 return drMod;
291}
292
293
// Package the fit into a ModelFitGLMM: take the fastlmm result, overwrite it
// with NaN when the model failed, then attach summary values and the dispersion.
// NOTE(review): the signature (listing line 295) is absent from this extract;
// the cross-reference index gives it as `ModelFitGLMM get_result()`.
294template <typename T1, typename T2, typename T3>
296
297 ModelFitLMM res1 = fit.get_result(returnUS);
298 res1.set_w_mean( w_mean );
299
300 // if model is not valid, it failed before final calculations
301 // so set values to NAN matching ModelDetail
302 if( ! isValid ){
303 // number of coefs
304 int p = X.n_cols;
305 res1.coef = vec(p, fill::value(datum::nan));
306
307 // No case below ends in `break` until LEAST, so each detail level
307 // falls through and also fills everything the lower levels fill
307 switch( md ){
308 case MAX:
309 // res1.hatvalues = hatvalues();
310 case MOST:
311 res1.hatvalues.fill(datum::nan);
312 case HIGH:
313 res1.residuals.fill(datum::nan);
314 case MEDIUM:
315 res1.vcov = mat(p, p, fill::value(datum::nan));
316 case LOW:
317 res1.se = vec(p, fill::value(datum::nan));
318 res1.rdf = datum::nan;
319 case LEAST:
320 break;
321 }
322 }
323
324 ModelFitGLMM mf(res1, family, niter_pql);
325
326 mf.mu_mean = mu_mean;
327 mf.y_mean = mean(y);
328 mf.varFitted = eta_var;
329
330 // QL dispersion based on Pearson residuals
331 // sum((w*r^2)[w > 0]) / df.r
332 // keep only the finite residuals (drops NaN and Inf entries)
333 vec rp = residuals_pearson();
334 vec rp_clean = rp.elem(find_finite(rp));
335 double disp = dot(rp_clean, rp_clean) / res1.rdf ;
336
337 if( fam->estimateDispersion() ){
338 mf.dispersion = disp;
339 }else{
340 mf.dispersion = 1.0;
341 // unscale variances by dispersion
342 mf.vcov /= disp;
343 mf.se /= sqrt(disp);
344 }
345
346 if( md == MAX ){ // NOTE(review): listing line 347 (the body of this branch) is absent from this extract
348 }
349
350 if( md >= HIGH ){
351 // Response residuals
352 mf.residuals = residuals();
353 }
354
355 return mf;
356}
357
358
359} // end namespace
360#endif
Definition ModelFit.h:165
Definition ModelFit.h:355
double mu_mean
Definition ModelFit.h:387
double y_mean
Definition ModelFit.h:388
vec coef
Definition ModelFit.h:40
double rdf
Definition ModelFit.h:43
vec residuals
Definition ModelFit.h:46
vec devianceResiduals
Definition ModelFit.h:49
mat vcov
Definition ModelFit.h:45
vec se
Definition ModelFit.h:41
double varFitted
Definition ModelFit.h:50
double dispersion
Definition ModelFit.h:42
vec hatvalues
Definition ModelFit.h:47
Definition ModelFit.h:201
void set_w_mean(const double &value)
Definition ModelFit.h:349
ModelFitGLMM get_result()
Definition fastglmm_fit.h:295
fastglmm()
Definition fastglmm_fit.h:41
const vec fitted() const
Definition fastglmm_fit.h:269
const vec devianceResiduals() const
Definition fastglmm_fit.h:276
const vec residuals() const
Definition fastglmm_fit.h:252
const vec residuals_pearson() const
Definition fastglmm_fit.h:260
Definition fastlmm_fit.h:35
Definition spectralDecomp.h:29
void reweight(const vec &weights, const bool &sort=false)
Definition spectralDecomp.h:76
Definition CleanData.h:17
ModelDetail
Definition ModelFit.h:26
@ MOST
Definition ModelFit.h:31
@ MEDIUM
Definition ModelFit.h:29
@ LEAST
Definition ModelFit.h:27
@ HIGH
Definition ModelFit.h:30
@ MAX
Definition ModelFit.h:32
@ LOW
Definition ModelFit.h:28
unordered_map< long, double > CountTable
Definition nb_theta.h:18
Definition glm.h:74
vec w
Definition glm.h:75
vec gprime
Definition glm.h:75
vec z
Definition glm.h:75
vec eta
Definition glm.h:75
vec mu
Definition glm.h:75