View source: R/build_vae_correlated.R
build_vae_correlated (R Documentation)
Build a VAE that fits to a normal, full covariance N(m,S) latent distribution
build_vae_correlated(
  num_items,
  num_skills,
  Q_matrix,
  mean_vector = rep(0, num_skills),
  covariance_matrix = diag(num_skills),
  model_type = 2,
  enc_hid_arch = c(ceiling((num_items + num_skills)/2)),
  hid_enc_activations = rep("sigmoid", length(enc_hid_arch)),
  output_activation = "sigmoid",
  kl_weight = 1,
  learning_rate = 0.001
)
num_items
an integer giving the number of items on the assessment; also the number of nodes in the input/output layers of the VAE
num_skills
an integer giving the number of skills being evaluated; also the dimensionality of the distribution learned by the VAE
Q_matrix
a binary, num_skills by num_items matrix relating the assessment items with the skills they require
mean_vector
a vector of length num_skills specifying the mean of the latent distribution
covariance_matrix
a symmetric, positive definite, num_skills by num_skills covariance matrix for the latent distribution (a quick check of these constraints is sketched after this argument list)
model_type
either 1 or 2, specifying a 1 parameter (1PL) or 2 parameter (2PL) model; if 1PL, then all decoder weights are fixed to be equal to one
enc_hid_arch
a vector detailing the size of each hidden layer in the encoder; the number of hidden layers is determined by the length of this vector
hid_enc_activations
a vector specifying the activation function in each hidden layer of the encoder; must be the same length as enc_hid_arch
output_activation
a string specifying the activation function in the output layer of the decoder; the ML2P model always uses 'sigmoid'
kl_weight
an optional weight for the KL divergence term in the loss function
learning_rate
an optional learning rate for the Adam optimizer
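As a quick, hedged illustration of the constraints described above (variable names are only illustrative; the Q_matrix orientation of num_skills rows by num_items columns follows the example at the end of this page), the following sketch builds valid inputs and checks the key properties:

# Hypothetical assessment with 4 items measuring 2 skills
num_items  <- 4
num_skills <- 2

# Binary Q matrix: num_skills rows by num_items columns
Q <- matrix(c(1, 0, 1, 1, 0, 1, 1, 0), nrow = num_skills, ncol = num_items)
stopifnot(all(Q %in% c(0, 1)), nrow(Q) == num_skills, ncol(Q) == num_items)

# Latent mean of length num_skills and a symmetric, positive definite covariance
mu  <- rep(0, num_skills)
cov <- matrix(c(0.7, 0.3, 0.3, 1), nrow = num_skills)
stopifnot(length(mu) == num_skills,
          isSymmetric(cov),
          all(eigen(cov, only.values = TRUE)$values > 0))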
returns three keras models: the encoder, decoder, and vae
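The list ordering used below (encoder first, then decoder, then vae) follows the value description above and matches the models[[3]] indexing in the example that follows; the variable names are only illustrative:

# A 2-skill by 4-item binary Q matrix, as in the sketch above
Q <- matrix(c(1, 0, 1, 1, 0, 1, 1, 0), nrow = 2, ncol = 4)

# Build with the default N(0, I) latent distribution and default architecture
models  <- build_vae_correlated(num_items = 4, num_skills = 2, Q_matrix = Q)

encoder <- models[[1]]
decoder <- models[[2]]
vae     <- models[[3]]

summary(vae)  # inspect the assembled keras model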
Q <- matrix(c(1, 0, 1, 1, 0, 1, 1, 0), nrow = 2, ncol = 4)
cov <- matrix(c(0.7, 0.3, 0.3, 1), nrow = 2, ncol = 2)
models <- build_vae_correlated(
  4, 2, Q,
  mean_vector = c(-0.5, 0),
  covariance_matrix = cov,
  enc_hid_arch = c(6, 3),
  hid_enc_activations = c("sigmoid", "relu"),
  output_activation = "tanh",
  kl_weight = 0.1
)
vae <- models[[3]]
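A VAE is trained to reconstruct its input, so a minimal training sketch feeds the binary response matrix as both input and target. This assumes the returned vae is already compiled (the learning_rate argument suggests an Adam optimizer is attached at build time); if the package provides its own training helper, prefer that over the raw keras call below. The simulated responses are purely for illustration.

library(keras)

# Simulated binary responses for 500 examinees on the 4 items (illustration only)
set.seed(123)
responses <- matrix(rbinom(500 * 4, size = 1, prob = 0.6), nrow = 500, ncol = 4)

# Reconstruct the responses from themselves: x and y are the same matrix
vae %>% fit(
  x = responses,
  y = responses,
  epochs = 10,
  batch_size = 32,
  verbose = 0
)

# Latent trait estimates can then be read from the trained encoder; whether it
# returns the latent mean alone or the mean and variance is an assumption here
latent <- predict(models[[1]], responses)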