fit_vae: Variational autoencoder model fitting


View source: R/VAExprs.R

Description

A fundamental problem in biomedical research is the low number of available observations. Augmenting a few real observations with samples generated in silico can lead to more robust analyses. Here, a variational autoencoder (VAE) is used for realistic generation of single-cell RNA-seq data; a conditional variational autoencoder (CVAE) can be used instead when sample labels are available. This function fits a variational autoencoder with the standard Gaussian prior to expression data. It is assumed that the latent-space representation of the variational autoencoder will likely show no clusters.

Usage

fit_vae(object = NULL,
        x_train = NULL,
        x_val = NULL,
        y_train = NULL,
        y_val = NULL,
        encoder_layers,
        decoder_layers,
        latent_dim = 2,
        regularization = 1,
        epochs,
        batch_size,
        preprocessing = list(
            x_train = NULL,
            x_val = NULL,
            y_train = NULL,
            y_val = NULL,
            minmax = NULL,
            lenc = NULL),
        use_generator = FALSE,
        optimizer = "adam",
        validation_split = 0, ...)

Arguments

object

SummarizedExperiment object

x_train

training expression data, where each row is a cell and each column is a gene

x_val

validation expression data, where each row is a cell and each column is a gene

y_train

labels for the training data

y_val

labels for the validation data

encoder_layers

list of layers for encoder

decoder_layers

list of layers for decoder

latent_dim

dimension of latent vector (default: 2)

regularization

regularization parameter, which is nonnegative (default: 1)

epochs

number of epochs

batch_size

batch size

preprocessing

list of preprocessed results; all elements default to NULL


  • x_train : training expression data

  • x_val : validation expression data

  • y_train : labels for the training data

  • y_val : labels for the validation data

  • minmax : result of min-max normalization

  • lenc : encoded labels
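
The minmax element corresponds to min-max normalization of the expression matrix (compare gradDescent::minmaxScaling in See Also). A minimal base-R sketch of the idea; the helper name and return shape here are illustrative, not the package's internal API:

```r
# min-max scale each gene (column) to [0, 1], keeping the per-gene
# minima and maxima so the transform can later be inverted;
# illustrative helper, not the internal implementation of fit_vae
minmax_scale <- function(x) {
    mins <- apply(x, 2, min)
    maxs <- apply(x, 2, max)
    scaled <- sweep(x, 2, mins, "-")
    scaled <- sweep(scaled, 2, maxs - mins, "/")  # assumes no constant column
    list(scaled = scaled, min = mins, max = maxs)
}

m <- matrix(c(0, 5, 10, 2, 4, 6), nrow = 3)
res <- minmax_scale(m)
range(res$scaled)  # all values lie in [0, 1]
```

Keeping the per-gene minima and maxima allows validation data, or generated samples, to be mapped back to the original expression scale.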

use_generator

use data generator if TRUE (default: FALSE)

optimizer

name of optimizer (default: adam)

validation_split

proportion of the training data used for validation; ignored when a validation set is provided (default: 0)

...

additional parameters passed to "fit" or "fit_generator"

Value

model

trained VAE model

encoder

trained encoder model

decoder

trained decoder model

preprocessing

preprocessed results
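
Since the decoder maps latent vectors back to expression space, new cells can be generated by decoding draws from the standard Gaussian prior. A hedged sketch, assuming vae_result is the list returned by fit_vae with latent_dim = 2 (as in the Examples); predict is the generic keras predict method for models:

```r
## sketch: generate new cells from a trained decoder
## assumes vae_result was returned by fit_vae with latent_dim = 2
latent_dim <- 2
n_new <- 50
# draws from the standard Gaussian prior over the latent space
z <- matrix(rnorm(n_new * latent_dim), nrow = n_new, ncol = latent_dim)
if (exists("vae_result")) {
    # decode to expression space; values are on the normalized scale
    generated <- predict(vae_result$decoder, z)
}
```

The decoded values are on the normalized scale used for training; the preprocessing element of the returned list carries what is needed to map them back to the original scale.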

Author(s)

Dongmin Jung

References

Marouf, M., Machart, P., Bansal, V., Kilian, C., Magruder, D. S., Krebs, C. F., & Bonn, S. (2020). Realistic in silico generation and augmentation of single-cell RNA-seq data using generative adversarial networks. Nature communications, 11(1), 1-12.

See Also

SummarizedExperiment::assay, SummarizedExperiment::colData, scater::logNormCounts, gradDescent::minmaxScaling, keras::fit, keras::fit_generator, keras::compile, CatEncoders::LabelEncoder.fit, CatEncoders::transform, DeepPINCS::multiple_sampling_generator

Examples

### simulate differentially expressed genes
set.seed(1)
g <- 3
n <- 100
m <- 1000
mu <- 5
sigma <- 5
mat <- matrix(rnorm(n*m*g, mu, sigma), m, n*g)
rownames(mat) <- paste0("gene", seq_len(m))
colnames(mat) <- paste0("cell", seq_len(n*g))
group <- factor(sapply(seq_len(g), function(x) { 
    rep(paste0("group", x), n)
}))
names(group) <- colnames(mat)
mu_upreg <- 6
sigma_upreg <- 10
deg <- 100
for (i in seq_len(g)) {
    mat[(deg*(i-1) + 1):(deg*i), group == paste0("group", i)] <- 
        mat[1:deg, group==paste0("group", i)] + rnorm(deg, mu_upreg, sigma_upreg)
}
# positive expression only
mat[mat < 0] <- 0
x_train <- as.matrix(t(mat)) 


### model
batch_size <- 32
original_dim <- 1000
intermediate_dim <- 512
epochs <- 2
# VAE
vae_result <- fit_vae(x_train = x_train,
                    encoder_layers = list(layer_input(shape = c(original_dim)),
                                        layer_dense(units = intermediate_dim,
                                                    activation = "relu")),
                    decoder_layers = list(layer_dense(units = intermediate_dim,
                                                    activation = "relu"),
                                        layer_dense(units = original_dim,
                                                    activation = "sigmoid")),
                    epochs = epochs, batch_size = batch_size,
                    validation_split = 0.5,
                    use_generator = FALSE,
                    callbacks = keras::callback_early_stopping(
                        monitor = "val_loss",
                        patience = 10,
                        restore_best_weights = TRUE))


### from preprocessing
vae_result_preprocessing <- fit_vae(preprocessing = vae_result$preprocessing,
                                    encoder_layers = list(layer_input(shape = c(original_dim)),
                                                        layer_dense(units = intermediate_dim,
                                                                    activation = "relu")),
                                    decoder_layers = list(layer_dense(units = intermediate_dim,
                                                                    activation = "relu"),
                                                        layer_dense(units = original_dim,
                                                                    activation = "sigmoid")),
                                    epochs = epochs, batch_size = batch_size,
                                    validation_split = 0.5,
                                    use_generator = FALSE,
                                    callbacks = keras::callback_early_stopping(
                                        monitor = "val_loss",
                                        patience = 10,
                                        restore_best_weights = TRUE))
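
When sample labels are available, the same interface can fit a CVAE by supplying y_train. A hedged sketch reusing the simulated x_train, group labels, and hyperparameters from above; passing the labels this way is inferred from the Arguments section rather than taken from the package's own examples:

```r
### CVAE with labels (sketch; reuses x_train, group, and the settings above)
if (exists("x_train") && exists("group")) {
    cvae_result <- fit_vae(x_train = x_train,
                           y_train = as.character(group),
                           encoder_layers = list(layer_input(shape = c(original_dim)),
                                                 layer_dense(units = intermediate_dim,
                                                             activation = "relu")),
                           decoder_layers = list(layer_dense(units = intermediate_dim,
                                                             activation = "relu"),
                                                 layer_dense(units = original_dim,
                                                             activation = "sigmoid")),
                           epochs = epochs, batch_size = batch_size,
                           validation_split = 0.5)
}
```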

dongminjung/VAExprs documentation built on Dec. 20, 2021, 12:13 a.m.