library(GridLMM)
library(glmnet)
library(itertools)

This demonstrates the use of the GridLMMnet function.

Create some simulated data:

# Simulation design: nG groups with nR replicates each, one row per observation.
nG <- 90
nR <- 10
n <- nG * nR
data <- expand.grid(Group = factor(seq_len(nG)), Rep = seq_len(nR))
data$ID <- seq_len(n)
p <- 100
# Group-level binary design, expanded to one row per observation ...
X <- matrix(sample(c(0, 1), size = nG * p, replace = TRUE), nrow = nG, ncol = p)
# X = BSFG::rstdnorm_mat(nG,p)
X <- X[data$Group, ]
# ... which is then replaced by an observation-level binary design.
# The first draw is kept because it advances the random number stream.
X <- matrix(
  sample(c(0, 1), size = nG * nR * p, replace = TRUE),
  nrow = nG * nR,
  ncol = p
)
rownames(X) <- data$ID

# Variance components and sparse effects: the first b1 covariates have
# effect 0.2, the remaining b2 have effect 0.
s2_g <- 0.9
s2_e <- 1 - s2_g
b1 <- 2
b2 <- p - b1
beta <- c(rep(0.2, b1), rep(0, b2))


# K = tcrossprod(model.matrix(~0+Group,data))
# V = s2_g*K + diag(s2_e,n)
# sum(diag(solve(V)))/n
# V = s2_g/s2_e*K + diag(1,n)
# sum(diag(solve(V)))/n
# #
# h2s = seq(0.3,.69,length=20)
# trs = sapply(h2s,function(h2) {
#   V = h2*K + diag(1-h2,n)
#   # V = h2/(1-h2)*K + diag(1,n)
#   sum(diag(solve(V)))#/n
# })
# plot(1/(h2s+(1-h2s)/nR),n/trs);abline(0,1)
# plot(1-h2s,trs)
# # plot()
# plot((h2s+(1-h2s)/nR)/(h2s+(1-h2s)),1/trs);abline(0,1)
# plot(h2s/(1-h2s)+1/nR,1/trs);abline(0,1)
# plot((h2s*(nR-1)+1)/(nR*(1-h2s))/(1-h2s),1/trs);abline(0,1)

# Simulate the response: fixed effects X %*% beta, one random intercept per
# Group (a single rnorm draw per group, expanded by data$Group) and
# independent residual noise.
y = 1*X %*% beta + sqrt(s2_g)*rnorm(nG)[data$Group] + sqrt(s2_e)*rnorm(n)
# Centre y; the scaling to unit SD is deliberately commented out.
data$y = (y-mean(y))#/sd(y)

# Reference fits with lme4: random intercept only (m0), and random intercept
# plus the first b1 covariates (m1).
library(lme4)
m0 = lmer(y~(1|Group),data)
m1 = lmer(y~X[,1:b1]+(1|Group),data)
# Variance components of m1; v1[1]/sum(v1) is the first component's share.
v1 = as.data.frame(VarCorr(m1))$vcov
v1[1]/sum(v1)
v1
sum(v1)

# Group-level analysis: average y within Group and regress on the Rep == 1
# rows of X.
data2 = aggregate(y~Group,data,FUN=mean)
# NOTE(review): relies on aggregate() naming the mean column `V1`, presumably
# because data$y is a one-column matrix -- confirm.
data2$y = data2$V1#/sd(data2$V1)
X2 = X[data$Rep==1,]
m2 = lm(y~X2[,1:b1],data2)


# Random assignment of every observation to one of 10 cross-validation folds.
foldid = sample(1:10,n,replace = T)

# rsss_s2e = sapply(h2s,function(h2) {
#   V = h2/(1-h2)*K + diag(1,n)
#   Vinv = solve(V)
#   sapply(1:100,function(i) {
#     s2_g = h2;s2_e = 1-h2
#     s2_g = h2/(1-h2);s2_e = 1
#     y = sqrt(s2_g)*rnorm(nG)[data$Group] + sqrt(s2_e)*rnorm(n)
#     sum(sum(y*(Vinv %*% y)))
#   })
# })
# plot(colMeans(rsss_s2e));abline(0,1)
# plot(n*(1-h2s),colMeans(rsss_s2e));abline(0,1)
# 
# 
# rsss_s2 = sapply(h2s,function(h2) {
#   V = h2*K + diag(1-h2,n)
#   Vinv = solve(V)
#   sapply(1:100,function(i) {
#     s2_g = h2;s2_e = 1-h2
#     y = sqrt(s2_g)*rnorm(nG)[data$Group] + sqrt(s2_e)*rnorm(n)
#     sum(sum(y*(Vinv %*% y)))
#   })
# })
# plot(colMeans(rsss_s2));abline(0,1)

Use glmnet to get the lasso solutions without random effects

# sd() divides by n - 1; the sqrt((n - 1) / n) factor converts it to the
# divide-by-n standard deviation. Both X and y are divided by it before the
# lasso (alpha = 1) fit, with glmnet's own standardization switched off.
sd_y <- sd(y) * sqrt((n - 1) / n)
res0 <- glmnet(
  x = X / sd_y,
  y = y / sd_y,
  alpha = 1,
  standardize = FALSE
)

Standard coefficient-by-lambda plot. Red = covariates with a simulated nonzero effect, black = covariates with zero effect.

# Coefficient paths of the plain lasso fit, followed by a hand-rolled
# GCV-style score along the lambda path based on Cholesky updates.
plot(res0,'lambda',col = c(rep(2,b1),rep(1,b2)))
# Gram matrix of X and its Cholesky factor.
k = t(X) %*% X
ck = chol(k)
# Spot check at lambda index 50: weights 1/|beta_j| on the nonzero
# coefficients, then trace(X (k + n*lambda*diag(w))^-1 X') is printed.
ind = res0$beta[,50] != 0
w=rep(0,p)
w[ind] = 1/abs(res0$beta[ind,50])
i1 = solve(k + n*res0$lambda[50]*diag(w))
sum(diag(X %*% i1 %*% t(X)))
# W = matrix(0,p,3)
# diag(W) = sqrt(c(3,2,10))
# ck2 = chol_update_L(t(ck),W,(rep(1,3)))
# i2 = chol2inv(t(ck2))
# max(abs(i1-i2))
# For each lambda: build W with one column per nonzero coefficient
# (entry sqrt(n*lambda/|beta_j|)), update the factor with chol_update_L,
# and form a GCV-like ratio from the trace of the resulting projection.
# chol_update_L is not defined in this file -- presumably from GridLMM; verify.
# NOTE(review): `e` is already a scalar sum of squares, so `sum(e^2)` squares
# it again, and the intercept enters as a0^2 * sd_y -- confirm both are intended.
r1 = sapply(seq_along(res0$lambda),function(i) {
  e = sum((y-res0$a0[i]^2*sd_y - X %*% res0$beta[,i])^2)
  # Local k (number of nonzero coefficients) shadows the Gram matrix k above.
  k = sum(res0$beta[,i] != 0)
  ck2 = t(ck)
  if(k>0) {
    W = matrix(0,p,k)
    # Local counter named `c`; it indexes the next free column of W.
    c=0
    for(j in 1:p) {
      if(res0$beta[j,i] != 0) {
        c=c+1
        W[j,c] = sqrt(n*res0$lambda[i]/abs(res0$beta[j,i]))
      }
    }
    ck2 = chol_update_L(t(ck),W,(rep(1,k)))
  }
  Xti = forwardsolve(ck2,t(X))
  # sum(diag(crossprod(Xti)))
  1/n*sum(e^2)/(1-(sum(diag(crossprod(Xti)))-sum(res0$beta[,i]==0))/n)^2
})
plot(r1)
library(MASS)
# Same GCV-style score as r1, but the penalised Gram matrix
# K + diag(n*lambda/|beta_j|) is formed explicitly for each lambda and solved
# with solve_LDLt (not defined in this file -- presumably from GridLMM; verify).
K = t(X) %*% X
r2 = sapply(seq_along(res0$lambda),function(i) {
  # NOTE(review): `e` is a scalar sum of squares that is squared again below,
  # and the intercept term is a0^2 * sd_y -- confirm both are intended.
  e = sum((y-res0$a0[i]^2*sd_y - X %*% res0$beta[,i])^2)
  k = sum(res0$beta[,i] != 0)
  n0 = sum(res0$beta[,i] == 0)
  K2 = K
  if(k > 0) {
    # Diagonal penalty weights; zero coefficients keep weight 0.
    W = rep(0,p)
    for(j in 1:p) {
      if(res0$beta[j,i] != 0) W[j] = n*res0$lambda[i]/abs(res0$beta[j,i])
    }
    K2 = K + diag(W)
  }
  # m = X %*% ginv(K2) %*% t(X)
  m = X %*% solve_LDLt(K2,t(X))
  # sum(diag(m))-sum(res0$beta[,i]==0)
  # trace(m) minus the number of zero coefficients is used as the
  # degrees-of-freedom term.
  1/n*sum(e^2)/(1-(sum(diag(m))-n0)/n)^2
})
plot(log(res0$lambda),r2,log='y')
# Check that the diagonal penalty can be maintained incrementally along the
# lambda path: at each step only the change in weight (current_w) is added to
# diag(K2), and the result is compared with the directly computed weights W.
K = t(X) %*% X
# |beta| / (n * lambda) per coefficient and lambda (and its square root,
# which is not used again in this section).
nlambda_beta_sqrt = sqrt(1/n) * sqrt(abs(sweep(res0$beta,2,res0$lambda,'/')))
nlambda_beta = (1/n) * abs(sweep(res0$beta,2,res0$lambda,'/'))
l = length(res0$lambda)
results = rep(0,l)
current_w = rep(0,p)
total_w = rep(0,p)
K2 = K
# a, b, d accumulate per-lambda rows of diag(K2), current_w and total_w.
a = c()
b = c()
d = c()
for(i in 1:length(res0$lambda)) {
  # e, k and n0 are computed here; only k is used later in this loop body.
  e = sum((y-res0$a0[i]^2*sd_y - X %*% res0$beta[,i])^2)
  k = sum(res0$beta[,i] != 0)
  n0 = sum(res0$beta[,i] == 0)
  for(j in 1:p) {
    # current_w[j] = -total_w[j]
    # NOTE(review): current_w[j] is only overwritten while the coefficient is
    # nonzero; with the reset above commented out, a stale increment is added
    # again if a coefficient returns to zero -- confirm this is intended.
    if(nlambda_beta[j,i] > 0) current_w[j] = 1/nlambda_beta[j,i]-total_w[j]
    total_w[j] = total_w[j] + current_w[j]
  }
  # Directly computed target weights for comparison.
  W = rep(0,p)
  if(k > 0) {
    for(j in 1:p) {
      if(res0$beta[j,i] != 0) W[j] = n*res0$lambda[i]/abs(res0$beta[j,i])
    }
  }
  diag(K2) = diag(K2) + current_w
  a = rbind(a,diag(K2))
  b = rbind(b,current_w)
  d = rbind(d,total_w)
  # Both printed discrepancies should be ~0 if the incremental update is right.
  print(c(max(abs(diag(K2) - (diag(K) + W))),max(abs(W-total_w))))
  # Sys.sleep(.1)
}
# GCV-style score using the per-lambda quantity returned by Calculate_qt
# (not defined in this file -- presumably a GridLMM helper returning the
# trace term computed explicitly for r2; verify).
# Note: the variable `qt` masks stats::qt for value lookups in this session.
qt = Calculate_qt(X,as.matrix(res0$beta),res0$lambda)
r3 = sapply(seq_along(res0$lambda),function(i) {
  # NOTE(review): same a0^2 and double-squaring pattern as in r1/r2.
  e = sum((y-res0$a0[i]^2*sd_y - X %*% res0$beta[,i])^2)
  k = sum(res0$beta[,i] != 0)
  n0 = sum(res0$beta[,i] == 0)
  1/n*sum(e^2)/(1-(qt[i]-n0)/n)^2
})
plot(log(res0$lambda),r3,log='y')
# Number of nonzero coefficients at each lambda.
plot(log(res0$lambda),colSums(res0$beta!=0))
# plot(r3,r2,log='xy');abline(0,1)
# Same trace term again, now via a sparse LDL' factorization of K that is
# updated with updown() for the active coefficients at each lambda.
# Cholesky()/updown() are presumably the Matrix-package functions -- verify.
K = t(X) %*% X
cK = Cholesky(as(K,'CsparseMatrix'),LDL=T)
qt2 = sapply(seq_along(res0$lambda),function(i) {
  cK2 = cK
  idx = res0$beta[,i] != 0
  if(sum(idx) > 0) {
    # n*lambda/|beta_j| for active coefficients, 0 elsewhere.
    nlambda_betai = n*res0$lambda[i] * (idx)
    nlambda_betai[idx] = nlambda_betai[idx] / abs(res0$beta[idx,i])
    # One update column per active coefficient, holding the square-root weight.
    W = matrix(0,p,p)
    diag(W) = sqrt(nlambda_betai)
    W = W[,idx,drop=F]
    cK2 = updown(T,W,cK)
  }
  S = X %*% solve(cK2,t(X))
  sum(diag(S))
})

# GCV-style score from qt2, in the same form as r3.
r4 = sapply(seq_along(res0$lambda),function(i) {
  # NOTE(review): same a0^2 and double-squaring pattern as in r1/r2/r3.
  e = sum((y-res0$a0[i]^2*sd_y - X %*% res0$beta[,i])^2)
  k = sum(res0$beta[,i] != 0)
  n0 = sum(res0$beta[,i] == 0)
  1/n*sum(e^2)/(1-(qt2[i]-n0)/n)^2
})
plot(log(res0$lambda),r4,log='y')
plot(log(res0$lambda),colSums(res0$beta!=0))
# r4 and r3 should fall on the identity line if both routes agree.
plot(r4,r3,log='xy');abline(0,1)
# Spot check at lambda index 50: compare three ways of obtaining the
# penalised Gram matrix / its inverse (BSFG::LDLt, Matrix updown, and
# LDLt_downdate, which is not defined in this file -- verify its origin).
i = 50
idx = res0$beta[,i] != 0
nlambda_betai = n*res0$lambda[i] * (idx)
nlambda_betai[idx] = nlambda_betai[idx] / abs(res0$beta[idx,i])
K = t(X) %*% X
# Reconstruct K + diag(weights) from its P' L D L' P factors.
l1 = BSFG::LDLt(K + diag(nlambda_betai))
K2 = t(l1$P) %*% (l1$L) %*% diag(l1$d) %*% t(l1$L) %*% l1$P
W = matrix(0,p,p)
diag(W) = sqrt(nlambda_betai)
W = W[,idx,drop=F]
# Order the update columns by decreasing weight.
W = W[,order(-apply(W,2,max))]
cK = Cholesky(as(K,'CsparseMatrix'),LDL=T)
cK2 = updown(T,W,cK)
# The two printed traces should agree.
a=X %*% solve(cK2,t(X))
sum(diag(a))
a2 = X %*% solve(K2,t(X))
sum(diag(a2))

l2 = LDLt_downdate(K,W)
# plot(l1$d,l2$d,log='xy');abline(0,1)
# plot(diag(K2),diag(l2$K2),log='xy');abline(0,1)
plot(diag(K2)-diag(l2$K2))
# Toy check of the incremental-weight bookkeeping: at every step the
# increment is chosen so that the running total equals the current target
# bs[i]; r records the increments.
a <- 3
bs <- rnorm(100)
current_w <- 0
total_w <- 0
r <- NULL
i <- 0L
while (i < 100L) {
  i <- i + 1L
  current_w <- bs[i] - total_w
  r[i] <- current_w
  total_w <- total_w + current_w
}
# Small rank-deficient example (3 x 5 design, so k = X'X is 5 x 5 with rank
# at most 3): check the LDL' reconstruction and that a generalized inverse
# built from the LDL' factors satisfies k G k = k.
# NOTE(review): this overwrites the global X used by the sections above and
# below; rstdnorm_mat/LDLt are called unqualified here although BSFG is only
# attached later in the file -- confirm where they come from.
X = rstdnorm_mat(3,5)
k = t(X) %*% X
lk = LDLt(k)
# Reconstruction error of the factorization (should be ~0).
max(abs(k - t(lk$P) %*% lk$L %*% diag(lk$d) %*% t(lk$L) %*% (lk$P)))
# Reference: MASS::ginv satisfies k G k - k = 0.
k%*% ginv(k) %*% k - k

# Same identity with G assembled from the LDL' factors, written two ways.
k %*% (t(lk$P) %*% solve(t(lk$L)) %*% ginv(diag(lk$d)) %*% solve(lk$L) %*% lk$P) %*% k - k
tpLi = t(lk$P) %*% solve(t(lk$L))
k %*% (tpLi%*% ginv(diag(lk$d)) %*% t(tpLi)) %*% k - k

# if p > n
# GCV-style score along the lasso path working with the n x n matrix X X'
# instead of the p x p Gram matrix.
# chol_update_L is not defined in this file -- presumably from GridLMM; verify.
# NOTE(review): `e` is already a scalar sum of squares, so `sum(e^2)` squares
# it again, and the intercept enters as a0^2 * sd_y -- confirm both are
# intended. The dimensions of W (p rows) versus ck (n x n) also look
# inconsistent unless p == n -- TODO confirm.
k = X %*% t(X)
ck = chol(k)
r1 = sapply(seq_along(res0$lambda),function(i) {
  e = sum((y-res0$a0[i]^2*sd_y - X %*% res0$beta[,i])^2)
  # Local k (number of nonzero coefficients) shadows the matrix k above.
  k = sum(res0$beta[,i] != 0)
  ck2 = t(ck)
  if(k>0) {
    W = matrix(0,p,k)
    # Local counter named `c`; it indexes the next free column of W.
    c=0
    for(j in 1:p) {
      if(res0$beta[j,i] != 0) {
        c=c+1
        W[j,c] = 1/sqrt(n*res0$lambda[i]/abs(res0$beta[j,i]))
      }
    }
    ck2 = chol_update_L(t(ck),W,(rep(1,k)))
  }
  Xti = forwardsolve(ck2,t(X))
  w = rep(0,p)
  # Active set of the i-th solution only. The previous `res0$beta != 0`
  # produced a p x nlambda logical matrix, which cannot index the length-p
  # vector w or the rows of res0$beta.
  ind = res0$beta[,i] != 0
  w[ind] = abs(res0$beta[ind,i])
  A = X %*% (diag(w) - diag(w) %*% crossprod(Xti) %*% diag(w))
  # sum(diag(crossprod(Xti)))
  1/n*sum(e^2)/(1-sum(diag(A))/n)^2
})
plot(r1)
library(MASS)
# At lambda index 50, solve a weighted ridge system with weights built from
# 1/|beta_j| of the lasso solution and compare the result with the lasso
# coefficients themselves.
w = rep(0,p)
ind = res0$beta[,50] != 0
w[ind] = 1/abs(res0$beta[ind,50])
# NOTE(review): the penalty here is diag(2+w), whereas the earlier spot check
# at lambda index 50 uses diag(w) -- confirm the offset of 2 is intended.
# Note: this overwrites b2 (previously the number of zero-effect covariates).
b2 = solve(t(X) %*% X + n*res0$lambda[50]*diag(2+w)) %*% t(X) %*% (y-res0$a0[50])
plot(b2,res0$beta[,50]);abline(0,1)

Run GridLMMnet

# Penalised fit with a random intercept per Group (h2_step = 0.025).
res1 <- GridLMMnet(
  y ~ 1 + (1 | Group),
  data = data,
  X = X,
  h2_step = 0.025
)
# Keep a copy under a second name; res1 is reassigned further down.
res1a <- res1

Standard coefficient-by-lambda plot. Red = covariates with a simulated nonzero effect, black = covariates with zero effect.

# Coefficient paths for the default GridLMMnet fit, then variants:
# res1b (lambdaType = 's2e', scoreType = 'GCV'), res1c (lambdaType = 's2e',
# cross-validated with foldid), a group-mean lasso (res02) and a ridge
# (alpha = 0) GridLMMnet fit. The h2s-vs-log(lambda) plots show the h2 value
# stored for each lambda in each fit.
plot(res1,'lambda',col = c(rep(2,b1),rep(1,b2)))
res1b = GridLMMnet(y~1+(1|Group),data = data,X = X,h2_step = 0.025,lambdaType = 's2e',scoreType = 'GCV')
plot(res1b,'lambda',col = c(rep(2,b1),rep(1,b2)))
plot(log(res1b$lambda),res1b$h2s)
plot(log(res1a$lambda),res1a$h2s)
# With foldid supplied, the path fit is read from $glmnet.fit below.
res1c = GridLMMnet(y~1+(1|Group),data = data,X = X,h2_step = 0.025,lambdaType = 's2e',foldid = foldid)
plot(res1c$glmnet.fit,'lambda',col = c(rep(2,b1),rep(1,b2)))
plot(res1c)
plot(log(res1c$lambda),res1c$glmnet.fit$h2s)
plot(log(res1b$lambda),res1b$h2s)
plot(log(res1a$lambda),res1a$h2s)
# Lasso on the group means, reusing the lambda sequence of res0.
res02 = glmnet(X2,data2$y,alpha=1,lambda = res0$lambda)
plot(res0,'lambda',col = c(rep(2,b1),rep(1,b2)))
plot(res02,'lambda',col = c(rep(2,b1),rep(1,b2)))
# Ridge version; this overwrites res1 (the lasso fit is still held in res1a).
res1 = GridLMMnet(y~1+(1|Group),data = data,X = X,h2_step = 0.025,alpha = 0)
plot(res1,'lambda',col = c(rep(2,b1),rep(1,b2)))
# Second simulation: 10 nonzero effects of 0.3 and a random effect whose
# covariance is the marker-based matrix K = XX' (with a small ridge added to
# the diagonal and K rescaled to mean diagonal 1).
s2_g = 0.3
s2_e = 1-s2_g
b1 = 10
b2 = p-b1
beta = c(rep(.3,b1),rep(0,b2))

K = tcrossprod(X) + diag(1e-8,n)
K = K/mean(diag(K))
# Mean diagonal of V^-1 under the two parameterisations of V (printed).
V = s2_g*K + diag(s2_e,n)
sum(diag(solve(V)))/n
V = s2_g/s2_e*K + diag(1,n)
sum(diag(solve(V)))/n
# t(chol(K)) %*% z has covariance K for standard-normal z.
cK = chol(K)
y = X %*% beta + sqrt(s2_g)*t(cK) %*%rnorm(n) + sqrt(s2_e)*rnorm(n)

Use glmnet to get the lasso solutions without random effects

# Plain lasso (alpha = 1) on the current X and y, with no random effects.
res0 <- glmnet(x = X, y = y, alpha = 1)

Standard coefficient-by-lambda plot. Red = covariates with a simulated nonzero effect, black = covariates with zero effect.

# Coefficient paths of res0: the first b1 covariates in colour 2, the
# remaining b2 in colour 1.
plot(
  res0,
  "lambda",
  col = c(rep(2, b1), rep(1, b2))
)

Run GridLMMnet

# GridLMMnet with one random effect per observation ID whose covariance is
# supplied through relmat = list(ID = K).
res1 <- GridLMMnet(
  y ~ 1 + (1 | ID),
  data = data,
  X = X,
  relmat = list(ID = K),
  h2_step = 0.025
)
# Second handle on this fit; res1 is reassigned below.
res1a <- res1

Standard coefficient-by-lambda plot. Red = covariates with a simulated nonzero effect, black = covariates with zero effect.

# Coefficient paths of the relmat fit, an identical refit, and res1b for
# comparison (first b1 covariates in colour 2, the rest in colour 1).
plot(res1, "lambda", col = c(rep(2, b1), rep(1, b2)))
res1 <- GridLMMnet(
  y ~ 1 + (1 | ID),
  data = data,
  X = X,
  relmat = list(ID = K),
  h2_step = 0.025
)
plot(res1, "lambda", col = c(rep(2, b1), rep(1, b2)))
plot(res1b, "lambda", col = c(rep(2, b1), rep(1, b2)))

Plot the estimated variance-component proportion h^2

# h2s against log(lambda); the horizontal line marks the simulated
# s2_g / (s2_g + s2_e).
plot(res1$h2s ~ log(res1$lambda))
abline(h = s2_g / (s2_g + s2_e))

Plot estimate of s2_g

# h2s * s2s against -log(lambda). Reference lines: the simulated s2_g
# (black) and the first lme4 variance component v1[1] (colour 2).
plot(
  res1$h2s * res1$s2s ~ -log(res1$lambda),
  ylim = c(0, var(y))
)
abline(h = s2_g)
abline(h = v1[1], col = 2)

Plot estimate of s2_e

# (1 - h2s) * s2s against log(lambda).
plot(
  (1 - res1$h2s) * res1$s2s ~ log(res1$lambda),
  ylim = c(0, var(y))
)
abline(h = s2_e)            # value used in the simulation
abline(h = v1[2], col = 2)  # second lme4 variance component (v1)

Plot estimate of s2 = s2_g + s2_e

# s2s against log(lambda).
plot(
  res1$s2s ~ log(res1$lambda),
  ylim = c(0, var(y))
)
abline(h = s2_g + s2_e)       # value used in the simulation
abline(h = sum(v1), col = 2)  # sum of the lme4 variance components (v1)

Compare to ggmix

library(ggmix)
# Group-membership kinship: K[i, j] is 1 when observations i and j share a
# Group and 0 otherwise.
K <- tcrossprod(model.matrix(~ 0 + factor(Group), data))
sK <- svd(K)
# res2 = ggmix(X,y,U = sK$u[,1:nG],D = sK$d[1:nG])
res2 <- ggmix(
  X,
  y,
  kinship = K,
  n_nonzero_eigenvalues = nG
)

Standard coefficient-by-lambda plot. Red = covariates with a simulated nonzero effect, black = covariates with zero effect.

# Coefficient paths for ggmix (res2), GridLMMnet (res1) and the group-mean
# lasso (res02), followed by a cross-validated GridLMMnet fit.
plot(res2,xvar = 'lambda',col = c(rep(2,b1),rep(1,b2)))
plot(res1,xvar = 'lambda',col = c(rep(2,b1),rep(1,b2)))
plot(res02,xvar = 'lambda',col = c(rep(2,b1),rep(1,b2)))
# res1cv = GridLMMnet(y~1+(1|Group),data = data,X = X,h2_step = 0.025,alpha = 1,foldid = foldid) #,foldid = as.numeric(data$Group)
# Cross-validation over the random folds in foldid, with K as the relationship
# matrix for ID.
res1cv = GridLMMnet(y~1+(1|ID),data = data,X = X,h2_step = 0.025,relmat = list(ID=K),alpha = 1,foldid = foldid) #,foldid = as.numeric(data$Group)
plot(res1cv)
# Lambda with the smallest cross-validated error (cvm).
best_lambda = res1cv$lambda[order(res1cv$cvm)[1]]
plot(res1cv$glmnet.fit,'lambda',col = c(rep(2,b1),rep(1,b2)));abline(v=log(best_lambda))
plot(log(res1cv$lambda),res1cv$glmnet.fit$h2s);abline(v=log(best_lambda))
# Manual cross-validation of two lme4 models over the same folds, to get
# reference error levels to overlay on the GridLMMnet CV curve:
#   column 1: y ~ x + (1 | Group), with x the first column of X
#   column 2: y ~ 1 + (1 | Group)
preds = matrix(0,n,2)
data$x = X[,1]
m1 = lmer(y~x+(1|Group),data)
m0 = lmer(y~1+(1|Group),data)
for(i in unique(foldid)) {
  # Fit on all folds except i, predict the held-out fold; groups that are
  # absent from the training folds are allowed in the prediction.
  mi = lmer(y~x+(1|Group),data = data[foldid != i,])
  mi0 = lmer(y~1+(1|Group),data = data[foldid != i,])
  preds[foldid == i,1] = predict(mi,newdata = data[foldid == i,],allow.new.levels = TRUE)
  preds[foldid == i,2] = predict(mi0,newdata = data[foldid == i,],allow.new.levels = TRUE)
}
# Squared prediction errors, summarised per fold the way glmnet does.
cvraw = (data$y[,1]-as.matrix(preds))^2
prediction_weights = rep(1,n)
# NOTE(review): cvcompute may not be exported by every glmnet version -- verify.
cvob = glmnet::cvcompute(cvraw,prediction_weights,foldid,rep(2,max(foldid)))
cvraw = cvob$cvraw
weights = cvob$weights
N = cvob$N
cvm = apply(cvraw, 2, weighted.mean, w = weights, na.rm = TRUE)
# Standard error of the fold means: centre each column on its own mean
# (cvm). The previous code referenced `cvm_lambda`, which is never defined.
cvsd = sqrt(apply(scale(cvraw, cvm, FALSE)^2, 2, weighted.mean,
                  w = weights, na.rm = TRUE)/(N - 1))
cvm
cvsd
plot(res1cv);abline(h=cvm)
# Compare glmnet with GridLMMnet restricted to h2_start = 0 / h2_step = 1,
# both as path fits and under cross-validation with the same folds and the
# same lambda sequence.
# X = scale(X)
# data$y = data$y-mean(data$y)
# data$y = data$y/sd(data$y)*sqrt(n/(n-1))
# NOTE(review): sdy is not used again in this section, and the factor is
# sqrt(n/(n-1)) here versus sqrt((n-1)/n) for sd_y earlier -- confirm.
sdy = sd(data$y)*sqrt(n/(n-1))
res0 = glmnet(X,data$y,standardize = F)
res10 = GridLMMnet(y~1+(1|Group),data = data,X = X,h2_step = 1,h2_start = 0,alpha = 1,lambda = res0$lambda,diagonalize = F)
res0cv = cv.glmnet(X,data$y,foldid = foldid,lambda = res0$lambda,standardize=F)
# Only the first 71 lambdas of res0 are passed to the cross-validated fit.
res1cv0 = GridLMMnet(y~1+(1|Group),data = data,X = X,h2_step = 1,h2_start = 0,alpha = 1,foldid = foldid,lambda = res0$lambda[1:71])
plot(res0,'lambda',col = c(rep(2,b1),rep(1,b2)))
plot(res10,'lambda',col = c(rep(2,b1),rep(1,b2)))
plot(res1cv0$glmnet.fit,'lambda',col = c(rep(2,b1),rep(1,b2)))
plot(res0cv$glmnet.fit,'lambda',col = c(rep(2,b1),rep(1,b2)))
plot(res0cv)
# cvm here is the vector of manual lme4 CV errors, if it was computed earlier.
plot(res1cv0,ylim = c(.49,.53));abline(h=cvm)
plot(res2$eta[1,]~ log(res2$lambda))
# Overlay the path of a single coefficient from the three fits. res1$beta is
# indexed at i+1 -- presumably because its first row is the intercept; verify.
i = 8
plot(res0$lambda,res0$beta[i,],type='l')
lines(res1$lambda,res1$beta[i+1,],col=2)
lines(res2$lambda,res2$beta[i,],col=3)
library(ggmix)
# Sanity check with an identity relationship matrix: X is column-centred and
# scaled, y is simulated without any random effect and standardized, and
# glmnet, GridLMMnet and ggmix are fitted. The fits are then repeated with K
# multiplied by 10 and by 0.1.
X = sweep(X,2,colMeans(X),'-')
X = sweep(X,2,apply(X,2,sd),'/')
y = X %*% beta + rnorm(n)
y = y - mean(y)
y = y/sd(y)
K = diag(1,nrow(data))
rownames(K) = colnames(K) = data$ID
res0 = glmnet(X,y,alpha=1)
res1 = GridLMMnet(y~1+(1|ID),data = data,X = X,h2_step = 0.025,relmat = list(ID=K))
res2 = ggmix(X,y,kinship = K,eta_init = .5,standardize = T)

K2 = 10*K
res1b = GridLMMnet(y~1+(1|ID),data = data,X = X,h2_step = 0.025,relmat = list(ID=K2))
res2b = ggmix(X,y,kinship = K2,eta_init = .5)

K3 = .1*K
res2c = ggmix(X,y,kinship = K3,eta_init = .01)
library(BSFG)
# Same three-way comparison with a dense random K: half a scaled random
# cross-product matrix plus 0.5 on the diagonal (so the mean diagonal is 1).
K = tcrossprod(rstdnorm_mat(nrow(X),nrow(X)))
K = K/mean(diag(K))/2
K = K + diag(.5,nrow(K))
rownames(K) = colnames(K) = data$ID
cK = chol(K)
# Random effect with covariance s2_g * K plus independent noise.
y = X %*% beta + sqrt(s2_g)*t(cK) %*% rnorm(n) + sqrt(s2_e)*rnorm(n)
# y = y - mean(y)
# y = y/sd(y)
res0 = glmnet(X,y,alpha=1)
res1 = GridLMMnet(y~1+(1|ID),data = data,X = X,h2_step = 0.025,relmat = list(ID=K))
res2 = ggmix(X,y,kinship = K,standardize = T)
K2 = K*10
# NOTE(review): this fit uses y~0+... (no intercept), unlike res1 -- confirm.
res1b = GridLMMnet(y~0+(1|ID),data = data,X = X,h2_step = 0.025,relmat = list(ID=K2))
res2b = ggmix(X,y,kinship = K2)
# This assignment is immediately overwritten by the loop variable below.
i = 14
# Paths of the first 12 coefficients: glmnet (black), GridLMMnet (colour 2,
# row i+1 -- presumably offset by an intercept row; verify) and ggmix (colour 3).
for(i in 1:12) {
plot(log(res0$lambda),res0$beta[i,],type='l',main = i,ylim = range(res0$beta))
lines(log(res1$lambda),res1$beta[i+1,],col=2)
lines(log(res2$lambda),res2$beta[i,],col=3)
}
# Ridge (alpha = 0) experiment: the genetic variance is split between the
# ridge coefficients (target share prop_b) and the random effect with
# covariance K = XX', using lambda = s2_e / (n * bs_target).
X = scale(X)/sqrt(p)
K = tcrossprod(X)
K = K#/mean(diag(K))
K = K + diag(1e-10,n)
rownames(K) = colnames(K) = data$ID
# Square root of K from its SVD: t(cK) %*% z has covariance K.
sK = svd(K)
cK = diag(sqrt(sK$d)) %*% t(sK$u)
# cK = chol(K)
# 
# y = X %*% beta + sqrt(s2_g + s2_e)*rnorm(n)
y = sqrt(s2_g)*t(cK) %*% rnorm(n) + sqrt(s2_e)*rnorm(n)
y = scale(y)
library(rrBLUP)
# Variance components from mixed.solve replace the simulated values of
# s2_g and s2_e from here on.
resr = mixed.solve(y = y,K=K)
s2_g = resr$Vu
s2_e = resr$Ve
# Target shares 0.01, ..., 0.99 of s2_g assigned to the ridge coefficients.
prop_b = seq(0,1,length=101)[-c(1,101)]
bs_target = s2_g*prop_b
lambdas = (s2_e) / (n*bs_target) #+s2_g - bs_target
res1 = GridLMMnet(y~0+(1|ID),data = data,X = X,h2_step = 0.0025,relmat = list(ID=K),alpha = 0,lambda = lambdas)
plot(lambdas)
plot(res1$h2s)
plot(res1$s2s);abline(h=resr$Ve)
# as = s2_g - s2_e/(n*lambdas)
# as = s2_g - res1$s2s/(n*lambdas)
# as = res1$h2s * (1-res1$h2s) * s2_e
# bs = s2_e/(n*lambdas)
# Implied variances: `as` from h2/(1-h2)*s2s and `bs` from s2s/(n*lambda).
as = res1$h2s / (1-res1$h2s) * res1$s2s
bs = res1$s2s/(n*lambdas)
plot(as,s2_g*(1-prop_b));abline(0,1)

# as = res1$h2s*res1$s2s
# bs = res1$s2s/(n*lambdas)

plot(bs,bs_target);abline(0,1)
plot(as,col=2,ylim = c(0,1));points(bs,col=3);points(as+bs,col=1);abline(h=resr$Vu)

plot(res1,'lambda');abline(h=-sort(-res1$beta[,1])[2])
# Repeat of the ridge partition experiment with n replaced by trace(V^-1) in
# the lambda formula, where V = a/s2_e * K + I and a = s2_g - b is the
# variance left to the random effect.
bs_target = s2_g*prop_b
tr = sapply(bs_target,function(b) {
  a = s2_g - b
  V = a/s2_e*K + diag(1,n)
  sum(diag(solve(V)))
})
lambdas = (s2_e) / (tr*bs_target) #+s2_g - bs_target
res1 = GridLMMnet(y~0+(1|ID),data = data,X = X,h2_step = 0.0025,relmat = list(ID=K),alpha = 0,lambda = lambdas)
plot(lambdas)
plot(res1$h2s)
plot(res1$s2s);abline(h=resr$Ve)
# as = s2_g - s2_e/(n*lambdas)
# as = s2_g - res1$s2s/(n*lambdas)
# as = res1$h2s * (1-res1$h2s) * s2_e
# bs = s2_e/(n*lambdas)
# Trace recomputed at the h2 value stored for each lambda.
tr2 = sapply(res1$h2s, function(h2) {
  V = h2/(1-h2)*K + diag(1,n)
  sum(diag(solve(V)))
})
as = res1$h2s / (1-res1$h2s) * res1$s2s
bs = res1$s2s/(tr2*lambdas)
plot(as,s2_g*(1-prop_b));abline(0,1)

# as = res1$h2s*res1$s2s
# bs = res1$s2s/(n*lambdas)

plot(bs,bs_target);abline(0,1)
plot(as,col=2,ylim = c(0,1));points(bs,col=3);points(as+bs,col=1);abline(h=resr$Vu)

plot(res1,'lambda');abline(h=-sort(-res1$beta[,1])[2])
# Ridge path over a lambda grid spanning the values that correspond to
# giving 1% (l1) and 99% (l2) of Vu to the ridge coefficients; the grid is
# equally spaced in 1/lambda.
res0 = glmnet(X,y,alpha= 0,intercept = F,standardize=F)
# res0b = glmnet(X*5,y*5,alpha= 0,intercept = F,standardize=F,lambda = res0$lambda*25,family = 'mgaussian',standardize.response = F)
library(rrBLUP)
resr = mixed.solve(y = y,K=K)
l1 = (resr$Ve+.99*resr$Vu)/((.01*resr$Vu)*sum(diag(solve(.99*K + (1-.99)*diag(1,n)))))
l2 = (resr$Ve+.01*resr$Vu)/((.99*resr$Vu)*sum(diag(solve(.01*K + (1-.01)*diag(1,n)))))
lambdas = 1/seq(1/(l1),1/(l2),length=100)
# The lambdas passed in are rescaled by trace((s2_g*K + (1-s2_g)*I)^-1) / n.
res1 = GridLMMnet(y~0+(1|ID),data = data,X = X,h2_step = 0.0025,relmat = list(ID=K),alpha = 0,lambda = lambdas * sum(diag(solve(s2_g*K + (1-s2_g)*diag(1,n))))/n)
plot(res1$h2s)
# res2 = ggmix(X,y,kinship = K,standardize = T,alpha = 0,lambda = res0$lambda)
# trace(V^-1) at the h2 stored for each lambda.
tr = sapply(res1$h2s,function(h2) sum(diag(solve(h2*K + (1-h2)*diag(1,n)))))
s2hat = res1$s2s# - 2*res1$lambda*colSums(res1$beta^2)/2
# Logical mask of lambdas whose h2 is above 0.001.
i = res1$h2s>0.001 #& res1$h2s < .99
# a2 = apply(res1$beta^2,2,mean)
# a1: variance attributed to the random effect; a2: variance attributed to
# the ridge coefficients, first with the trace and then with n as divisor.
a1 = res1$h2s*s2hat;a2 = s2hat/(res1$lambda*tr)#*tr/n/2)
plot((res1$lambda*tr/n)[i],a1[i]+a2[i],ylim=c(0,1.5),log='x');points((res1$lambda*tr/n)[i],a1[i],col=3);points((res1$lambda*tr/n)[i],a2[i],col=2);abline(h=(a1[i]+a2[i])[1]);abline(h=resr$Vu,col=2)
a2 = s2hat/(res1$lambda*n)#*tr/n/2)
plot(res1$lambda[i],a1[i]+a2[i],ylim=c(0,1.5),log='x');points(res1$lambda[i],a1[i],col=3);points(res1$lambda[i],a2[i],col=2);abline(h=(a1[i]+a2[i])[1]);abline(h=resr$Vu,col=2)
# a2 = s2hat/(res1$lambda*tr/2)#*tr/n/2)
# plot(a1[i]+a2[i],ylim=c(0,1));points(a1[i],col=3);points(a2[i],col=2);abline(h=(a1[i]+a2[i])[1]);abline(h=resr$Vu,col=2)
plot(res1,'lambda')
# Path of the coefficient that is smallest at the first lambda.
plot((res1$lambda*tr),res1$beta[order(res1$beta[,1])[1],],log='x')
plot((res1$lambda),res1$beta[order(res1$beta[,1])[1],],log='x')
# Zoom in: refit on 100 lambdas between lambdas[21] and lambdas[22] (equally
# spaced in 1/lambda) and repeat the diagnostics of the previous section.
res1a=res1
res1 = GridLMMnet(y~0+(1|ID),data = data,X = X,h2_step = 0.0025,relmat = list(ID=K),alpha = 0,lambda = 1/seq(1/lambdas[21],1/lambdas[22],length=100))
plot(res1$h2s)
# res2 = ggmix(X,y,kinship = K,standardize = T,alpha = 0,lambda = res0$lambda)
tr = sapply(res1$h2s,function(h2) sum(diag(solve(h2*K + (1-h2)*diag(1,n)))))
s2hat = res1$s2s# - 2*res1$lambda*colSums(res1$beta^2)/2
i = res1$h2s>0.001 #& res1$h2s < .99
# a2 = apply(res1$beta^2,2,mean)
a1 = res1$h2s*s2hat;a2 = s2hat/(res1$lambda*tr)#*tr/n/2)
plot((res1$lambda*tr/n)[i],a1[i]+a2[i],ylim=c(0,1.5),log='x');points((res1$lambda*tr/n)[i],a1[i],col=3);points((res1$lambda*tr/n)[i],a2[i],col=2);abline(h=(a1[i]+a2[i])[1]);abline(h=resr$Vu,col=2)
a2 = s2hat/(res1$lambda*n)#*tr/n/2)
plot(res1$lambda[i],a1[i]+a2[i],ylim=c(0,1.5),log='x');points(res1$lambda[i],a1[i],col=3);points(res1$lambda[i],a2[i],col=2);abline(h=(a1[i]+a2[i])[1]);abline(h=resr$Vu,col=2)
# a2 = s2hat/(res1$lambda*tr/2)#*tr/n/2)
# plot(a1[i]+a2[i],ylim=c(0,1));points(a1[i],col=3);points(a2[i],col=2);abline(h=(a1[i]+a2[i])[1]);abline(h=resr$Vu,col=2)
plot(res1,'lambda')
plot((res1$lambda*tr),res1$beta[order(res1$beta[,1])[1],],log='x')
plot((res1$lambda),res1$beta[order(res1$beta[,1])[1],],log='x')
# Refit with the lambda grid multiplied by 5. Here s2s is treated as the
# residual variance (s2e), so the random-effect variance is h2/(1-h2)*s2e and
# the ridge variance is s2e/(lambda*n).
res1 = GridLMMnet(y~0+(1|ID),data = data,X = X,h2_step = 0.0025,relmat = list(ID=K),alpha = 0,lambda =lambdas*5)
plot(res1$h2s)
plot(res1$s2s)
# res2 = ggmix(X,y,kinship = K,standardize = T,alpha = 0,lambda = res0$lambda)
# tr = sapply(res1$h2s,function(h2) sum(diag(solve(h2*K + (1-h2)*diag(1,n)))))
# Note: this overwrites any earlier s2e variable.
s2e = res1$s2s# - 2*res1$lambda*colSums(res1$beta^2)/2
i = res1$h2s>0.001 #& res1$h2s < .99
# a2 = apply(res1$beta^2,2,mean)
a1 = res1$h2s/(1-res1$h2s)*s2e;a2 = s2e/(res1$lambda*n)#*tr/n/2)
# plot((res1$lambda*tr/n)[i],a1[i]+a2[i],ylim=c(0,1.5),log='x');points((res1$lambda*tr/n)[i],a1[i],col=3);points((res1$lambda*tr/n)[i],a2[i],col=2);abline(h=(a1[i]+a2[i])[1]);abline(h=resr$Vu,col=2)
# a2 = s2hat/(res1$lambda*n)#*tr/n/2)
plot(res1$lambda[i],a1[i]+a2[i],ylim=c(0,1.5),log='x');points(res1$lambda[i],a1[i],col=3);points(res1$lambda[i],a2[i],col=2);abline(h=(a1[i]+a2[i])[1]);abline(h=resr$Vu,col=2)
# a2 = s2hat/(res1$lambda*tr/2)#*tr/n/2)
# plot(a1[i]+a2[i],ylim=c(0,1));points(a1[i],col=3);points(a2[i],col=2);abline(h=(a1[i]+a2[i])[1]);abline(h=resr$Vu,col=2)
plot(res1,'lambda')
# plot((res1$lambda*tr),res1$beta[order(res1$beta[,1])[1],],log='x')
plot((res1$lambda),res1$beta[order(res1$beta[,1])[1],],log='x')
# plot(a1[i]+(a2[i]),log='y')
# plot(a1[i]+(a2[i]),a2[i]/(a1[i]+a2[i]))
# plot(a2[i])

# For every lambda, subtract the ridge fit X %*% beta from y and re-estimate
# the variance components of the remainder with mixed.solve.
# resrs1 columns: 1 = Vu, 2 = Ve, 3 = cor(u, X beta), 4-5 = intercept and
# slope of lm(u ~ X beta).
# Inside the anonymous functions `i` is the lambda index; outside, `i` is
# still the logical mask defined before this section.
resrs1=do.call(rbind,lapply(1:ncol(res1$beta),function(i) {
  resr = mixed.solve(y=y - X %*% res1$beta[,i],K=K,method='ML')
  c(resr$Vu,resr$Ve,cor(resr$u,X %*% res1$beta[,i]),coef(lm(resr$u~X %*% res1$beta[,i])))
}))
# resrs2: one column per lambda holding X beta + BLUP u.
resrs2=do.call(cbind,lapply(1:ncol(res1$beta),function(i) {
  resr = mixed.solve(y=y - X %*% res1$beta[,i],K=K)
  c(X %*% res1$beta[,i]) + resr$u
}))

plot(resrs1[i,1]/(resrs1[i,5])^2+resrs1[i,1])
plot(resrs1[i,1]/(resrs1[i,5])^2+resrs1[i,1],resrs1[i,3])
plot(resrs1[i,3])
plot(a2[i],resrs1[i,1]/(resrs1[i,5])^2);abline(0,1)
# NOTE(review): `tr` is whatever was last computed; its recomputation for the
# current res1 is commented out earlier in the file -- confirm lengths match.
plot((rowSums(resrs1[,1:2])/(res1$lambda*tr))[i],resrs1[i,1]/(resrs1[i,5])^2);abline(0,1)
# Several alternative definitions of a2 are tried in turn; each assignment
# overwrites the previous one.
a2 = resrs1[,1]/(resrs1[,5])^2
a2 = rowSums(resrs1[,1:2])/(res1$lambda*tr)


a2 = rowSums(resrs1[,1:2])/(res1$lambda*tr)
plot(a1/a2,resrs1[,5]);abline(0,1)
a2 = resrs1[,1]/resrs1[,5]
a1 = resrs1[,1]
plot(a1[i]+(a2[i]),log='y')
# Back out the lambda implied by a1 = h2s*s2s and a2 = Vu - a1 and compare
# it with the lambda actually used.
a1 = res1$h2s*res1$s2s
a = resr$Vu
a2 = a-a1
plot(a2)
l = (resr$Ve + a1)/(tr*a2)
plot(l[i]/res1$lambda[i]);abline(0,1)
# Closed-form ridge solutions via the SVD of X'X, for every lambda in res0:
# beta = U (d + lambda*n)^-1 U' X'y, compared with glmnet at lambda index 10.
XtX = crossprod(X)
Xty = t(X) %*% y
sXtX = svd(XtX)
UtXty = t(sXtX$u) %*% Xty
betas=sapply(res0$lambda,function(lambda) {
  # A = XtX
  # diag(A) = diag(A) + lambda*n/2
  # solve(A,t(X) %*% y)
  sXtX$u %*% (1/(sXtX$d + lambda*n) * UtXty)
})
i=10
plot(betas[,i],res0$beta[,i]);abline(0,1)
# Ridge after whitening with V = h2 * XX'/p + (1-h2) * I at h2 = 0.25:
# y and X are pre-multiplied by the inverse transpose Cholesky factor of V
# and the penalty is lambda * trace(V^-1) / (1-h2) / h2.
h2 = 0.25
V = h2 * X %*% t(X)/p + (1-h2)*diag(1,n)
R = chol(V)
tr = sum(diag(solve(V)))
ys = solve(t(R),y)
Xs = solve(t(R),X)
XtX = crossprod(Xs)
Xty = t(Xs) %*% ys
sXtX = svd(XtX)
UtXty = t(sXtX$u) %*% Xty
betas_h2=sapply(res0$lambda,function(lambda) {
  # A = XtX
  # diag(A) = diag(A) + lambda*n/2
  # solve(A,t(X) %*% y)
  # tr = sum(Rinv^2)
  sXtX$u %*% (1/(sXtX$d + lambda*tr/(1-h2)/h2) * UtXty)
})
# Correlation of the whitened solution at lambda index 10 with every
# unwhitened ridge solution.
i = 10
plot(cor(betas,betas_h2[,i]));abline(v=i)
# trace(V^-1) as a function of h2.
h2s = seq(0,.99,length=100)
plot(h2s,sapply(h2s,function(h2) {V = h2 * X %*% t(X)/p + (1-h2)*diag(1,n)
sum(diag(solve(V)))}))
# log|V| computed directly and from the singular values of XX'/p; the two
# printed values should agree.
h2 = 0.25
V = h2 * X %*% t(X)/p + (1-h2)*diag(1,n)
log(det(V))
sX = svd(X %*% t(X)/p)
sum(log((h2*sX$d+(1-h2))))
# Split one ridge effect into two independent components with prior
# variances h2*a and (1-h2)*a by duplicating the design (X2 = [X, X]), then
# check that the first component can also be obtained by absorbing the second
# into the residual covariance V and solving a whitened ridge problem.
X2 = cbind(X,X)
i = 100
lambda = res0$lambda[i]
s2e = mean((y - X %*% res0$beta[,i])^2) + lambda*sum(res0$beta[,i]^2)
a = s2e/lambda
h2 = .05
# Per-coefficient prior variances (despite the name, these are not penalties;
# their reciprocals are added to the diagonal below).
lambdas = c(rep(h2*a,p),rep((1-h2)*a,p))
A = t(X2) %*% X2
diag(A) = diag(A) + 1/lambdas
bs = solve(A,t(X2) %*% y)
yhat = X2 %*% bs

# Whitened route: the second component plus residual define V.
h21 = ((1-h2)*a)/((1-h2)*a + s2e)
V = h21 * X %*% t(X) + (1-h21)*diag(1,n)
cV = chol(V)
cy = solve(t(cV),y)
cX = solve(t(cV),X)
A = t(cX) %*% cX
diag(A) = diag(A) + ((1-h2)*a + s2e)/(h2*a)
# Note: this overwrites b1 (previously the number of nonzero-effect covariates).
b1 = solve(A,t(cX) %*% cy)
# cor(b1,bs[1:p])
plot(b1,bs[1:p])
abline(0,1)
# plot(bs[1:p],bs[p+1:p])
# plot(y - X %*% res0$beta[,i],t(cV) %*% (cy - cX %*% b1));abline(0,1)
# Relating glmnet's ridge penalty to the mixed-model variance ratio.
# The data are doubled by stacking X and data on themselves (so ID values,
# and hence the row/column names of K, are duplicated), y is simulated with
# covariance s2_g*K + s2_e*I and rescaled to SD sd_y.
X = rbind(X,X)
data = rbind(data,data)
n = nrow(data)
X = scale(X)/sqrt(p)
K = tcrossprod(X)
K = K#/mean(diag(K))
rownames(K) = colnames(K) = data$ID
sK = svd(K)
cK = diag(sqrt(sK$d)) %*% t(sK$u)
s2_g = 0.3;s2_e = 1-s2_g
y = sqrt(s2_g)*t(cK) %*% rnorm(n) + sqrt(s2_e)*rnorm(n)
sd_y = 1
y = scale(y)*sd_y



# ML variance components with X as the random-effect design; the matching
# ridge penalty is l = Ve / (Vu * n).
r1 = mixed.solve(y=y,Z=X,method='ML')
l = r1$Ve/(r1$Vu*n)
# r1b = glmnet(x=X,y=y,intercept = F,standardize = F,alpha=0)
# beta = coef(r1b,s=l/sd_y,exact = T,x=X,y=y)[-1,1]
# plot(coef(r1b,s=l/sd_y,exact = T,x=X,y=y)[-1,1],r1$u);abline(0,1)
# Lambda grid 2^seq(1,-1) times the target; the printed comparisons use
# element 500 of the 1001-point grid.
r1b = glmnet(x=X,y=y,intercept = F,standardize = F,alpha=0,lambda = 2^seq(1,-1,length=1001)*l*sd_y)
beta = r1b$beta[,500]
plot(beta,r1$u);abline(0,1)
# plot(coef(r1b,s=l,exact = T,x=X,y=y)[-1,1]-r1$u)

# Same comparison with family = 'mgaussian' and an unstandardized response.
r1c = glmnet(x=X,y=y,intercept = F,standardize = F,alpha=0,family = 'mgaussian',standardize.response = F,lambda = 2^seq(1,-1,length=1001)*l)
plot(r1c$beta[,500],r1$u);abline(0,1)
# plot(coef(r1b,s=l,exact = T,x=X,y=y)[-1,1]-r1$u)

# Same comparison with x and y divided by sd_y and lambda by sd_y^2.
r1d = glmnet(x=X/sd_y,y=y/sd_y,intercept = F,standardize = F,alpha = 0,lambda = 2^seq(1,-1,length=1001)*l/sd_y^2)
plot(r1d$beta[,500],r1$u);abline(0,1)

# Residual-variance estimate that includes the penalty term, printed next to
# r1$Ve and the plain RSS/n.
RSS = sum((y-X %*% beta)^2)
penalties = sum(beta^2/2)
s2_hat = (RSS + 2*n*l*penalties)/n
s2_hat
r1$Ve
RSS/n
# s2_hat is equal to r1$Ve. So it's reasonable to use s2_hat as the estimate of s2_e. It does include the penalty!
# Split Vu into a random-effect share (prop_u) and a ridge share, whiten
# y, X and the intercept column with V = s2_u/s2_e * K + I, and check that
# mixed.solve and several glmnet parameterisations recover the ridge share.
prop_u = 0.1
s2_u = r1$Vu * prop_u
s2_b = r1$Vu * (1-prop_u)
s2_e = r1$Ve
V = s2_u / s2_e * K + diag(1,n)
cV = chol(V)
cy = solve(t(cV),y)
cX = solve(t(cV),X)
c1 = solve(t(cV),matrix(1,n,1))
r2 = mixed.solve(y = cy,Z = cX,X = c1,method='ML')
c(r2$Ve,s2_e)
c(r2$Vu,s2_b)
# conditioning on s2_u, we get the correct s2_b

# Divide-by-n standard deviation of the whitened response.
sd_cy = sd(cy)*sqrt((n-1)/n)
l = r2$Ve/(r2$Vu*n)
r2b = glmnet(x=cbind(c1,cX),y=cy,intercept = F,standardize = F,alpha=0)
# NOTE(review): the exact refit is given x = cX although r2b was fitted on
# cbind(c1,cX); the [-1,1] then drops the first returned row -- confirm.
beta = coef(r2b,s=l*sd_cy,exact = T,x=cX,y=cy)[-1,1]
plot(beta,r2$u);abline(0,1)
r2b2 = glmnet(x=cbind(c1,cX),y=cy,intercept = F,standardize = F,alpha=0,lambda = 2^seq(1,-1,length=1001)*l*sd_cy)
beta2 = r2b2$beta[-1,500]
plot(beta2,r2$u);abline(0,1)

# Penalty-inclusive residual variance, printed next to r2$Ve and RSS/n.
RSS = sum((cy-cX %*% beta)^2)
penalties = sum(beta^2/2)
s2_b_hat = (RSS + 2*n*l*penalties)/n
cbind(s2_b_hat,r2$Ve)
RSS/n

r2c = glmnet(x=cbind(c1,cX),y=cy,intercept = F,standardize = F,alpha=0,family = 'mgaussian',standardize.response = F,lambda = 2^seq(1,-1,length=1001)*l)
betac = r2c$beta[-1,500]
plot(betac,r2$u);abline(0,1)


r2d = glmnet(x=cbind(c1,cX)/sd_cy,y=cy/sd_cy,intercept = F,standardize = F,alpha=0,lambda = 2^seq(1,-1,length=1001)*l/sd_cy^2)
betad = r2d$beta[-1,500]
plot(betad,r2$u);abline(0,1)

# all different ways to the same result!
# given beta, what do we solve for h2?

# Sweep the random-effect share prop_u over 0.01, ..., 0.99: for each share,
# estimate the ridge coefficients conditional on s2_u, remove them from y and
# re-estimate the variance components (r3) from the remainder.
results = c()
prop_us = seq(0,1,length=101)
prop_us = prop_us[-c(1,length(prop_us))]
for(prop_u in prop_us) {
  s2_u = r1$Vu * prop_u
  s2_b = r1$Vu * (1-prop_u)
  s2_e = r1$Ve
  V = s2_u / s2_e * K + diag(1,n)
  cV = chol(V)
  cy = solve(t(cV),y)
  cX = solve(t(cV),X)
  c1 = solve(t(cV),matrix(1,n,1))
  r2 = mixed.solve(y = cy,Z = cX,X = c1,method='ML')
  # These two expressions are evaluated but not printed inside the loop.
  c(r2$Ve,s2_e)
  c(r2$Vu,s2_b)
  # conditioning on s2_u, we get the correct s2_b

  sd_cy = sd(cy)*sqrt((n-1)/n)
  l = r2$Ve/(r2$Vu*n)
  r2b = glmnet(x=cbind(c1,cX),y=cy,intercept = F,standardize = F,alpha=0)
  # NOTE(review): exact refit uses x = cX although r2b was fitted on
  # cbind(c1,cX) -- confirm.
  beta = coef(r2b,s=l*sd_cy,exact = T,x=cX,y=cy)[-1,1]

  y2 = y - X %*% beta
  r3 = mixed.solve(y = y2,Z=X,method='ML')
  # One row per share is appended to `results`.
  results = rbind(results,data.frame(
    s2_u = s2_u,
    s2_e = s2_e,
    prop_u = prop_u,
    Vu = r3$Vu,
    Ve = r3$Ve
  ))
}
# Points: re-estimated components; lines: the values that were conditioned on.
plot(results$prop_u,results$s2_u)
plot(results$prop_u,results$Vu,ylim = c(0,r1$Vu));lines(results$prop_u,results$s2_u)
plot(results$prop_u,results$Ve,ylim = c(0,r1$Ve));lines(results$prop_u,results$s2_e)
# Single-share check (prop_1 = 0.9 of Vu to the ridge part): condition on
# the remaining random-effect variance a_2, whiten with
# V_1 = h2_1*K + (1-h2_1)*I, and compare mixed.solve BLUPs with glmnet ridge
# coefficients under two candidate lambda formulas.
a_hat = r1$Vu
s2e_hat = r1$Ve

prop_1 = 0.9
a_2 = (1-prop_1) * a_hat
h2_1 = a_2 / (a_2 + s2e_hat)

V_1 = h2_1 * K + (1-h2_1) * diag(1,n)
cV_1 = chol(V_1)
# r2 = mixed.solve(y = solve(t(cV_1),y),K = tcrossprod(solve(t(cV_1),X)))
cy = solve(t(cV_1),y)
cX = solve(t(cV_1),X)
c1 = solve(t(cV_1),matrix(1,n,1))
r2 = mixed.solve(y = cy,Z = cX,X = c1,method='ML')

# Printed for comparison: estimated ridge variance, conditioned-on variance,
# total, and their sum.
r2$Vu
a_2
a_hat
a_2 + r2$Vu

# Divide-by-n standard deviation of cy. The sqrt() was missing here, unlike
# every other place in this file that applies the (n-1)/n correction to sd().
sd_cy = sd(cy)*sqrt((n-1)/n)
r2b = glmnet(x = cbind(c1,cX)/sd_cy,y = cy/sd_cy,intercept = F,standardize = F,alpha=0)
tr = sum(diag(solve(V_1)))
l = r2$Ve/(r2$Vu*n)
bhat = r2$u
# NOTE(review): the exact refit is given x = cX/sd_cy although r2b was fitted
# on cbind(c1,cX)/sd_cy -- confirm.
plot(coef(r2b,s=l/sd_cy^2,exact=T,x = cX/sd_cy,y = cy/sd_cy)[-1,1],r2$u);abline(0,1)
# Two alternative penalty-inclusive variance estimates and the lambda each
# implies; the parenthesised assignments print their value.
l2=l
s2hat = (r2$Ve*n + 2*l2*tr*(sum(bhat^2)/2)/n)/(n+p)
(l2 = s2hat*n/(r2$Vu*tr))
l2=r2$Ve/(r2$Vu)/n
s2hat = (r2$Ve*n + 2*l2*(sum(bhat^2)/2))/(n+p)
(l2 = s2hat/(r2$Vu))
plot(coef(r2b,s=l2/sd_cy^2,exact=T,x = cX/sd_cy,y = cy/sd_cy)[-1,1],r2$u);abline(0,1)
# Three equivalent mixed.solve parameterisations (Z = X; K = XX'; Z = left
# singular vectors with diagonal K) and hand-written log-likelihood
# expressions to compare against the reported $LL values.
r1 = mixed.solve(y=y,Z=X,method='ML')
K = tcrossprod(X)
sK = svd(K)
# sum(log(sK$d[sK$d>1e-10]))
sX = svd(X)
# 2*sum(log(sX$d[sX$d>1e-10]))
# Numerical rank of X.
np = sum(sX$d>1e-10)
r2 = mixed.solve(y=y,K=K,method='ML')
r3 = mixed.solve(y,Z=sX$u[,1:np],K = diag(sK$d[1:np]),method='ML')


r1$LL
r2$LL
r3$LL

# Log-determinant terms written two ways (both printed).
-np/2*log(r3$Vu) - 2/2*sum(log(sX$d[sX$d>1e-10]))
- 2/2*sum(log(sqrt(r3$Vu)*sX$d[sX$d>1e-10]))

# Residuals under the r1 and r3 parameterisations (both use r1's intercept).
e = y - X %*% r1$u - r1$beta[1]
e2 = y - sX$u[,1:np] %*% r3$u - r1$beta[1]
# Candidate joint log-density expressions; each line prints one value.
-n/2*log(2*pi) - n/2*log(r3$Ve) - 1/(2*r3$Ve)*sum(e2^2) - 1/(2*r3$Vu)*matrix(r3$u,nr=1) %*% diag(1/sK$d[1:np]) %*% r3$u -np/2*log(r3$Vu) - np/2*log(2*pi) - 2/2*sum(log(sX$d[1:np]))
-n/2*log(2*pi) - n/2*log(r1$Ve) - 1/(2*r1$Ve)*sum(e^2) - 1/(2*r1$Vu)*sum(r1$u^2) -np/2*log(r1$Vu) - np/2*log(2*pi) - 2/2*sum(log(sX$d[sX$d>1e-10])) 
-n/2*log(2*pi) - n/2*log(r1$Ve) - 1/(2*r1$Ve)*sum(e^2) - 1/(2*r1$Vu)*sum(r1$u^2) -1/2*determinant(diag(1,p) + t(X) %*% diag(1/r1$Ve,n) %*% X * r1$Vu)$mod# p/2*log*r1$Vu
# Marginal Gaussian log-likelihood with V = Vu*K + Ve*I via its Cholesky factor.
V = r1$Vu * K + diag(r1$Ve,n)
cV = chol(V)
-n/2*log(2*pi) - 1/(2)*sum(solve(t(cV),y - r1$beta[1])^2) - sum(log(diag(cV)))



# The same log-likelihood in the (s2, h2) parameterisation,
# V = h2*K + (1-h2)*I with s2 = Vu + Ve, followed by penalised
# ("ridge-style") rewrites with l = Ve/Vu. Each expression prints one value.
s2 = r1$Vu+r1$Ve
h2 = r1$Vu/s2
V = h2*K + (1-h2)*diag(1,n)
cV = chol(V)
-n/2*log(2*pi) - n/2*log(s2) - 1/2*determinant(V)$mod - 1/(2*s2)*crossprod(solve(t(cV),(y-r1$beta[1])))
l = r1$Ve/r1$Vu
tr = sum(diag(solve(V)))
-n/2*log(2*pi) - n/2*log(r1$Ve) - 1/(2*r1$Ve)*sum(e^2) - 1/(2*r1$Vu)*sum(r1$u^2) -p/2*log(r1$Vu)  - p/2*log(2*pi)
-n/2*log(2*pi) - n/2*log(r1$Ve) - 1/(2*r1$Ve)*(sum(e^2) + l*sum(r1$u^2)) -p/2*log(r1$Vu) - p/2*log(2*pi)
-n/2*log(2*pi) - n/2*log(r1$Ve) - 1/(2*r1$Ve)*(sum(e^2) + 2*l*sum(r1$u^2)/2) -p/2*log(r1$Vu) - p/2*log(2*pi)
-(n+p)/2*log(2*pi) - (n+p)/2*log(r1$Ve) - 1/(2*r1$Ve)*(crossprod((y-r1$beta[1] - X %*% r1$u)) + 2*l*sum(r1$u^2)/2) - p/2*log(1/(l))

# r2 = mixed.solve(y,K = tcrossprod(X),method='ML')
s2 = r1$Vu+r1$Ve
h2 = r1$Vu/(s2)
V =h2*K + (1-h2)*diag(1,n)
# cV =chol(V)
-n/2*log(2*pi) - n/2*log(s2) -1/2*determinant(V)$mod  - 1/(2*r1$Ve)*sum(e^2) - 1/(2*r1$Vu)*sum(r1$u^2)
# Log-likelihood expressions when Vu is split into a1 (kept in the
# covariance V_1) and a2 (assigned to ridge coefficients estimated on the
# whitened data). Each multi-line expression prints one value.
prop_1 = 0.5999
a1 = prop_1*r1$Vu
a2 = r1$Vu - a1
s2_1 = a1+r1$Ve
h2_1 = a1/s2_1
V_1 = h2_1 * K + (1-h2_1)*diag(1,n)
cV_1 = chol(V_1)
tr = sum(diag(solve(V_1)))
# r2 = mixed.solve(y = solve(t(cV_1),y),K = tcrossprod(solve(t(cV_1),X)))
cy = solve(t(cV_1),y)
cX = solve(t(cV_1),X)
c1 = solve(t(cV_1),matrix(1,n,1))
r2 = mixed.solve(y = cy,Z = cX,X = c1,method='ML')
# Residuals on the whitened scale.
e2 = cy - cX %*% r2$u - c1*r2$beta[1]
-n/2*log(2*pi) - 1/2*determinant(V_1)$mod - 
  1/2*determinant(diag(1,p) + t(X) %*% solve(V_1) %*% X * a2/s2_1)$mod - 
  n/2*log(s2_1) - 1/(2*s2_1)*sum(e2^2) - 1/(2*a2)*sum(r2$u^2)#*tr/n

-n/2*log(2*pi) - #n/2*log(s2_1) -# 1/2*determinant(V_1)$mod - 
  1/2*determinant(s2_1*V_1 + a2*K)$mod -
  1/(2*s2_1)*sum(e2^2) - 1/(2*a2)*sum(r2$u^2)#*tr/n

# Profile form: s2_hat includes the penalty term, with l = s2_1/(n*a2).
l = s2_1/(n*a2)
s2_hat = sum(e2^2)/n + 2*l*sum(r2$u^2)/2 #== s2_1
-n/2*log(2*pi) - n/2*log(s2_hat) -  n/2 - 1/2*determinant(V_1)$mod - 
  1/2*determinant(diag(1,p) + t(X) %*% solve(V_1) %*% X / (n*l))$mod 

-n/2*log(2*pi) - n/2*log(s2_hat) -  n/2 - 1/2*determinant(V_1)$mod - 
  1/2*determinant(diag(1,p) + crossprod(cX) / (n*l))$mod 

s2 = r1$Vu+r1$Ve
h2 = r1$Vu/s2
V = h2*K + (1-h2)*diag(1,n)
-n/2*log(2*pi*s2) - 1/2*determinant(V)$mod- 
  1/(2*s2_1)*sum(e2^2) - 1/(2*a2)*sum(r2$u^2)


# Same quadratic term with the trace-based lambda, written two ways.
l = s2_1/(tr*a2)
1/(2*s2_1)*(sum(e2^2) - l*tr*sum(r2$u^2))
1/(2*s2_1)*sum(diag(solve(V_1,(tcrossprod(y - X %*% r2$u - r2$beta[1]) - l*sum(r2$u^2)*diag(1,n)))))


# determinant(diag(1,p) + t(X) %*% solve(V_1) %*% X * a2/s2_1)$mod
# determinant(diag(1,p) + crossprod(cX) * a2/s2_1)$mod
# determinant(s2_1*V_1 + a2*X %*% t(X))$mod - determinant(s2_1*V_1)$mod
# For each proportion in props: remove sqrt(1 - prop) times the BLUPs from
# y, rescale the remainder to unit SD and re-estimate Vu.
props = seq(0,.99,length=100)
# Inside this function the argument `p` is the proportion and shadows the
# global p.
a1s=sapply(props,function(p) {
  b = r1$u * sqrt(1-p)
  ys = y - X %*% b - r1$beta[1]
  ys = ys/sd(ys)
  r3 = mixed.solve(y = ys,Z=X,method='ML')
  # V and tr are computed but only r3$Vu is returned.
  V = p*K + (1-p)*diag(1,n)
  tr = sum(diag(solve(V)))
  r3$Vu #*tr/n
})
plot(props*r1$Vu,a1s);abline(0,1)

# Recompute the whitened quantities for the current h2_1.
V_1 = h2_1 * K + (1-h2_1)*diag(1,n)
cV_1 = chol(V_1)
tr = sum(diag(solve(V_1)))
# r2 = mixed.solve(y = solve(t(cV_1),y),K = tcrossprod(solve(t(cV_1),X)))
cy = solve(t(cV_1),y)
cX = solve(t(cV_1),X)
c1 = solve(t(cV_1),matrix(1,n,1))
r2 = mixed.solve(y = cy,Z = cX,X = c1,method='ML')
# Fresh simulation with K = XX' (X rescaled by sqrt(p)), then a check of
# whether a ridge solve on y minus sqrt(p)-scaled BLUPs reproduces the
# sqrt(1-p)-scaled BLUPs.
p = ncol(X)
X = scale(X)/sqrt(p)
K = tcrossprod(X)
K = K#/mean(diag(K))
rownames(K) = colnames(K) = data$ID
sK = svd(K)
cK = diag(sqrt(sK$d)) %*% t(sK$u)
y = sqrt(s2_g)*t(cK) %*% rnorm(n) + sqrt(s2_e)*rnorm(n)
# y = scale(y)



r1 = mixed.solve(y=y,Z=X,method='ML')

# From here p is reused as a proportion, no longer the number of columns.
p = .9
y2 = y - sqrt(p)*X %*% r1$u

# NOTE(review): the penalty is Ve/(1 - p*Vu); a similar check later in the
# file uses Ve/((1-p)*Vu) -- confirm the parenthesisation is intended.
b2 = solve(t(X) %*% X + diag(r1$Ve/(1-p*r1$Vu),ncol(X)),t(X) %*% y2)
plot(b2,sqrt(1-p)*r1$u);abline(0,1)

# r2 = mixed.solve(y=y2,Z=X,method='ML')
# 
# r1$Vu
# r1$Ve
# r2$Vu
# r2$Ve
# (1-r2$Vu / r1$Vu)
# plot(r2$u,r1$u)
# Rescale the markers, form K = X X', and simulate a response whose random
# part is drawn through a square root of K obtained from its SVD.
p <- ncol(X)
X <- scale(X) / sqrt(p)
K <- tcrossprod(X)
K <- K # /mean(diag(K))
rownames(K) <- colnames(K) <- data$ID
sK <- svd(K)
cK <- diag(sqrt(sK$d)) %*% t(sK$u)
s2_g <- .9
s2_e <- 1 - s2_g
y <- sqrt(s2_g) * t(cK) %*% rnorm(n) + sqrt(s2_e) * rnorm(n)
# y = scale(y)



# Reference ML fit with all of the effect variance in one component.
r1 <- mixed.solve(y = y, Z = X, method = 'ML')

# Split the fitted effect variance r1$Vu into a share `a` (proportion p) and a
# share `b` (proportion 1 - p); `e` is the residual variance.
p <- .1
a <- p * r1$Vu
b <- (1 - p) * r1$Vu
e <- r1$Ve

# First component: the fitted effects scaled by sqrt(p).
b1 <- sqrt(p) * r1$u

# Remove the first component from the response.
y2 <- y - X %*% b1 # * sqrt(b/(a+e))

# Second component: ridge solution on what is left, with penalty e/b.
b2 <- solve(t(X) %*% X + e / b * diag(1, ncol(X)), t(X) %*% y2)[, 1]

# Compare the two components, and their sum against the reference effects.
plot(b1, b2)
plot(b1 + b2, r1$u)
abline(0, 1)
# Same simulation as the previous sections, with a small ridge (1e-3) added to
# the diagonal of K before it is decomposed.
p <- ncol(X)
X <- scale(X) / sqrt(p)
K <- tcrossprod(X)
K <- K # /mean(diag(K))
K <- K + diag(1e-3, n)
rownames(K) <- colnames(K) <- data$ID
sK <- svd(K)
cK <- diag(sqrt(sK$d)) %*% t(sK$u)
s2_g <- .9
s2_e <- 1 - s2_g
y <- sqrt(s2_g) * t(cK) %*% rnorm(n) + sqrt(s2_e) * rnorm(n)
# y = scale(y)



# Reference ML fit.
r1 <- mixed.solve(y = y, Z = X, method = 'ML')

# Split r1$Vu into shares a (proportion p) and b (proportion 1 - p); e is the
# residual variance.
p <- .4
a <- p * r1$Vu
b <- (1 - p) * r1$Vu
e <- r1$Ve


# Earlier parameterisation, kept for reference:
# V = a/(e) * K + diag((a+e)/(e),n)
# cV = chol(V)
# cy = solve(t(cV),y)
# cX = solve(t(cV),X)
# b2 = solve(t(cX) %*% cX + diag(e/b,ncol(X)),t(cX) %*% cy/e)[,1]
# Step 1: treat the `a` share plus the residual as the covariance V, transform
# y and X by its Cholesky factor, and solve a ridge system for b2 with penalty
# (e + a) / b.
V <- a / (a + e) * K + diag(e / (a + e), n)
cV <- chol(V)
cy <- solve(t(cV), y)
cX <- solve(t(cV), X)
b2 <- solve(t(cX) %*% cX + diag((e + a) / b, ncol(X)), t(cX) %*% cy)[, 1]

# Step 2: remove b2 from y and solve a ridge system for b1 with penalty e / a.
y2 <- y - X %*% b2 # * sqrt(b/(a+e))
# Author's note: this does give the correct b1~N(0,aI)
b1 <- solve(t(X) %*% X + diag(e / a, ncol(X)), t(X) %*% y2)[, 1]

# Author's note: this does not give a = r1b$Vu. But it does still give e == r1b$Ve
r1b <- mixed.solve(y = y2, Z = X, method = 'ML')
# Ridge solution for y2 using the variance ratio estimated by r1b instead.
b1b <- solve(t(X) %*% X + diag(r1b$Ve / r1b$Vu, ncol(X)), t(X) %*% y2)[, 1]

# Compare each piece, and the two sums, against the reference effects.
plot(b2, r1$u)
abline(0, 1)
plot(b1, r1$u)
abline(0, 1)
plot(b1 + b2, r1$u)
abline(0, 1)
plot(b1b + b2, r1$u)
abline(0, 1)

# Hand-written log densities, printed for comparison with the LL values stored
# in the mixed.solve fits.
# Three-component form: residual term for y given b1 + b2, plus separate
# Gaussian terms for b1 (variance a) and b2 (variance b).
# NOTE(review): every normalising term uses n, although b1 and b2 come from
# ncol(X)-dimensional solves -- confirm this is intended.
-3*n/2*log(2*pi)-n/2*log(r1$Ve)-n/2*log(a)-n/2*log(b)-1/(2*r1$Ve)*sum((y-X %*% (b1+b2))^2) - 1/(2*a)*sum(b1^2) - 1/(2*b)*sum(b2^2)
# Two-component form using the single fitted effect vector r1$u and r1$Vu.
-2*n/2*log(2*pi)-n/2*log(r1$Ve)-n/2*log(r1$Vu)-1/(2*r1$Ve)*sum((y-X %*% (r1$u))^2) - 1/(2*r1$Vu)*sum(r1$u^2)
# LL elements of the fits: r1 (fit to y) and r1b (fit to y2). r2 is whatever
# fit was last assigned to r2 earlier in the file; it is reassigned further down.
r1$LL
r1b$LL
r2$LL

# Log density of u1 = X %*% b1 under a rank-deficient Gaussian with covariance
# b * X X', written three ways and printed for comparison.
u1 = X %*% b1
sK = svd(X %*% t(X))
# Numerical rank of X X': number of singular values above 1e-10.
k = sum(sK$d>1e-10)
# Pseudo-inverse restricted to the leading k singular vectors. seq_len() and
# drop = FALSE keep this well-formed when k is 0 or 1.
Kinv = sK$u[,seq_len(k),drop=FALSE] %*% (1/sK$d[seq_len(k)]*t(sK$u[,seq_len(k),drop=FALSE]))
# (1) normalising constant as a sum over the k retained singular values
-1/2*sum(log(2*pi*b*sK$d[seq_len(k)])) - 1/(2*b)*t(u1) %*% Kinv %*% u1
# (2) the same expression with the constant expanded. The multiplier is the
# rank k so that (1) and (2) agree for any rank, not only when k == n-1.
-1/2*k*log(2*pi*b)-1/2*sum(log(sK$d[seq_len(k)])) - 1/(2*b)*t(u1) %*% Kinv %*% u1
# (3) density written directly in terms of b1.
# NOTE(review): uses n as the dimension although b1 has ncol(X) elements -- confirm.
-1/2*n*log(2*pi*b) - 1/(2*b)*t(b1) %*% b1

# Cross-check the fitted random effect X %*% b1 against K-based formulas
# applied to y2: once via GridLMM's make_chol_V_setup and once by hand.
s2 = a+e
h2 = a/s2# = a/(a+e)
V_setup = list(downdate_ratios=1,
               resid_V = diag(1,n),
               ZKZts = list(K))
# h2 is reused here (same value as a/(a+e)); TRUE is spelled out because `T`
# is an ordinary variable that can be reassigned.
# NOTE(review): $invIpKinv is presumably (I + (e/a) K^{-1})^{-1}, matching the
# hand-written version on the b1c line below -- confirm against GridLMM.
chol_V_setup = make_chol_V_setup(V_setup,h2s=h2,invIpKinv = TRUE)
b1b = chol_V_setup$invIpKinv %*% y2
# Three successive ways of writing the same operator; only the last assignment
# to b1c is used by the plot below.
b1c = solve(diag(1,n) + e/a*solve(K)) %*% y2
V = (1-h2)*diag(1,n)+h2*K
cV = chol(V)
b1c = (a/e*(K - e/(a+e)*(a/e)* crossprod(solve(t(cV),K)))) %*% y2
b1c = (h2/(1-h2)*(K - h2* crossprod(solve(t(cV),K)))) %*% y2
plot(X %*% b1,b1c);abline(0,1)
plot(X %*% b1,b1b);abline(0,1)

# Refit by ML on y2 (the response with b2 removed) and print the variance
# components next to those of the reference fit r1.
r2 <- mixed.solve(y2, Z = X, method = 'ML')
# Residual variances: reference fit vs. refit.
c(r1$Ve, r2$Ve)
# Share p of the reference effect variance vs. the refit effect variance.
c(r1$Vu * p, r2$Vu)
# Refit effects plus b2 against the reference effects, with a 1:1 line.
plot(r1$u, r2$u + b2)
abline(0, 1)
# For each split proportion p of r1$Vu into a (share p) and b (share 1 - p),
# recompute the two-step estimates b2 and b1 and tabulate three versions of the
# log-likelihood. sK and k are the SVD and numerical rank computed earlier in
# the file.
p_grid = seq(0,1,length.out=12)[-c(1,12)]
# Preallocate the result table instead of growing it with rbind() in the loop.
r = matrix(NA_real_,nrow = length(p_grid),ncol = 6,
           dimnames = list(NULL,c('p','log_a_plus_b','log_a_log_b',
                                  'll_joint_n','ll_joint_svd','ll_combined')))
e = r1$Ve  # residual variance; does not depend on p
for(i in seq_along(p_grid)) {
  p = p_grid[i]
  a = p*r1$Vu
  b = (1-p)*r1$Vu

  # Step 1: ridge solution for b2 after transforming y and X by the Cholesky
  # factor of V = a/(a+e)*K + e/(a+e)*I.
  V = a/(a+e) * K + diag(e/(a+e),n)
  cV = chol(V)
  cy = solve(t(cV),y)
  cX = solve(t(cV),X)
  b2 = solve(t(cX) %*% cX + diag((e+a)/b,ncol(X)),t(cX) %*% cy)[,1]

  # Step 2: ridge solution for b1 on the response with b2 removed.
  y2 = y - X%*% b2 #* sqrt(b/(a+e))
  b1 = solve(t(X) %*% X + diag(e/a,ncol(X)),t(X) %*% y2)[,1]  # This does give the correct b1~N(0,aI)

  # Data-fit term shared by all three log-likelihood versions.
  fit = 1/(2*r1$Ve)*sum((y-X %*% (b1+b2))^2)

  r[i,] = c(p,log(a+b),log(a)+log(b),
  -3*n/2*log(2*pi)-n/2*log(r1$Ve)-n/2*log(a)-n/2*log(b)-fit - 1/(2*a)*sum(b1^2) - 1/(2*b)*sum(b2^2), # this joint likelihood isn't constant across p
  -3*n/2*log(2*pi)-n/2*log(r1$Ve)-1/2*sum(log(a*sK$d[1:k]))-1/2*sum(log(b*sK$d[1:k]))-fit - 1/(2*a)*sum(b1^2) - 1/(2*b)*sum(b2^2),
  -2*n/2*log(2*pi)-n/2*log(r1$Ve)-n/2*log(a+b)-fit - 1/(2*r1$Vu)*sum((b1+b2)^2) # this is constant
  )
}
# the joint likelihood favors solutions with only one of a or b (i.e. p = 0 or 1) over splits between the two
# even though the fit to the data is the same
# the difference is in the determinant terms. log(a+b) != log(a)+log(b)
# Solve for both effect components at once: stack two copies of X side by
# side and apply a separate ridge penalty to each copy, for a grid of split
# proportions of r1$Vu.
# NOTE(review): nothing is appended to r inside the loop, so only the last
# iteration's bs, b1 and b2 survive -- the loop looks unfinished.
r = c()
X2 = cbind(X,X)
K2 = t(X2) %*% X2
X2y = t(X2) %*% y
p = ncol(X)
for(prop in seq(0,1,length.out=12)[-c(1,12)]) {
  # Penalties: r1$Ve over each component's share of the effect variance.
  l = r1$Ve/c(rep(prop*r1$Vu,p),rep((1-prop)*r1$Vu,p))
  bs = solve(K2 + diag(l),X2y)
  # First p coefficients belong to the first copy of X, the next p to the
  # second. p + seq_len(p) spells out what `p+1:p` meant, which is
  # p + (1:p) rather than (p+1):p.
  b1 = bs[seq_len(p)]
  b2 = bs[p + seq_len(p)]
}


# deruncie/GridLMM documentation built on May 2, 2023, 7:18 p.m.