# R/VAExprs.R
#
# Defines functions: plot_aug, gen_exprs, plot_vae, fit_vae
#
# Documented in: fit_vae, gen_exprs, plot_aug, plot_vae

# Build, compile and train a (conditional) variational autoencoder.
#
# Arguments:
#   object            Optional experiment object. When supplied (and
#                     `preprocessing` is not fully populated) it is passed to
#                     scater::logNormCounts() and the transposed "logcounts"
#                     assay becomes the data matrix; the first colData column,
#                     if there is one, becomes the label vector.
#   x_train, x_val    Optional training / validation data (rows are samples).
#   y_train, y_val    Optional labels. A non-NULL `y_train` switches the model
#                     to a conditional VAE with a one-unit "condition" input.
#   encoder_layers    List of layers; the first element is used as the model
#                     input, the remaining ones are applied in order.
#   decoder_layers    List of layers applied in order to the latent vector.
#   latent_dim        Number of units of the "z_mean" / "z_log_stddev" layers.
#   regularization    Non-negative weight of the KL term in the loss.
#   epochs            Forwarded to keras::fit().
#   batch_size        Forwarded to keras::fit() and to the batch generators.
#   preprocessing     List with elements x_train, x_val, y_train, y_val,
#                     minmax and lenc. When every element is non-NULL the list
#                     is reused as is and splitting, normalization and label
#                     encoding are skipped.
#   use_generator     If TRUE, batches are produced by
#                     DeepPINCS::multiple_sampling_generator().
#   optimizer         Forwarded to keras::compile().
#   validation_split  Fraction of rows held out for validation.
#   ...               Forwarded to keras::fit().
#
# Returns a list with elements `preprocessing`, `model`, `encoder` and
# `decoder`.
#
# Side effects: disables TensorFlow eager execution if it is enabled, and
# temporarily creates `z_mean` and `z_log_stddev` in the global environment
# (removed again when the function exits, including on error).
fit_vae <- function(object = NULL,
                    x_train = NULL,
                    x_val = NULL,
                    y_train = NULL,
                    y_val = NULL,
                    encoder_layers,
                    decoder_layers,
                    latent_dim = 2,
                    regularization = 1,
                    epochs,
                    batch_size,
                    preprocessing = list(
                        x_train = NULL,
                        x_val = NULL,
                        y_train = NULL,
                        y_val = NULL,
                        minmax = NULL,
                        lenc = NULL),
                    use_generator = FALSE,
                    optimizer = "adam",
                    validation_split = 0, ...) {
    if (tensorflow::tf$executing_eagerly()) {
        tensorflow::tf$compat$v1$disable_eager_execution()
    }
    
    result <- list()
    if (regularization < 0) {
        stop("regularization parameter should be nonnegative")
    }
    
    # TRUE only when the caller handed over a complete preprocessing list;
    # in that case nothing below re-derives splits, scaling or label encoding.
    reuse_preprocessing <- all(!vapply(preprocessing, is.null, logical(1)))
    
    # Random train/validation split: TRUE marks rows kept for training.
    sample_train_mask <- function(n_rows) {
        shuffled <- sample(seq_len(n_rows))
        n_train <- round(n_rows * (1 - validation_split))
        seq_len(n_rows) %in% shuffled[seq_len(n_train)]
    }
    
    
    ### pre-processing
    if (reuse_preprocessing) {
        result$preprocessing <- preprocessing
        x_train <- preprocessing$x_train
        if (!is.null(preprocessing$x_val)) x_val <- preprocessing$x_val
        if (!is.null(preprocessing$y_train)) y_train <- preprocessing$y_train
        if (!is.null(preprocessing$y_val)) y_val <- preprocessing$y_val
    } else if (!is.null(object)) {
        message("pre-processing...")
        x <- t(SummarizedExperiment::assay(scater::logNormCounts(object), "logcounts"))
        y <- NULL
        if (length(SummarizedExperiment::colData(object)) > 0) {
            y <- SummarizedExperiment::colData(object)[, 1]
        }
        
        if (validation_split > 0) {
            train_idx <- sample_train_mask(nrow(x))
            # drop = FALSE keeps a matrix even if one side has a single row
            x_train <- x[train_idx, , drop = FALSE]
            x_val <- x[!train_idx, , drop = FALSE]
            
            if (!is.null(y)) {
                y_train <- y[train_idx]
                y_val <- y[!train_idx]
            }
        } else {
            x_train <- x
            if (!is.null(y)) {
                y_train <- y
            }
        }
    }
    
    # Matrix input without an explicit validation set: split it here.
    if (!reuse_preprocessing && is.null(object) && is.null(x_val) &&
        validation_split > 0) {
        message("pre-processing...")
        x <- x_train
        train_idx <- sample_train_mask(nrow(x))
        x_train <- x[train_idx, , drop = FALSE]
        x_val <- x[!train_idx, , drop = FALSE]
        if (!is.null(y_train) && is.null(y_val)) {
            y <- y_train
            y_train <- y[train_idx]
            y_val <- y[!train_idx]
        }
    }
    
    
    ### normalization
    if (!reuse_preprocessing) {
        message("normalizing...")
        minmax <- gradDescent::minmaxScaling(data.frame(x_train))
        x_train <- as.matrix(minmax$scaledDataSet)
        if (!is.null(x_val)) {
            # NOTE(review): x_val is rescaled with its own min/max, not with
            # the training parameters stored in `minmax` -- confirm intended.
            x_val <- as.matrix(gradDescent::minmaxScaling(data.frame(x_val))$scaledDataSet)
        }
        result$preprocessing$minmax <- minmax
    }
    
    # Number of input features; scales the reconstruction term of the loss.
    # Derived from the data instead of relying on a global `original_dim`.
    original_dim <- ncol(x_train)
    
    
    ### building model
    # VAE
    encoded <- NULL
    for (i in seq_along(encoder_layers)) {
        if (is.null(encoded)) {
            # first element of encoder_layers is the model input
            input_layer <- encoder_layers[[i]]
            if (is.null(y_train)) {
                encoded <- input_layer
            } else {
                condition <- layer_input(shape = c(1), name = "condition")
                encoded <- layer_concatenate(c(input_layer, condition))
            }
        } else {
            encoded <- encoder_layers[[i]](encoded)
        }
    }
    
    # z_mean / z_log_stddev are published in the global environment while the
    # model is built and trained; make sure they are removed on every exit
    # path, not only after a successful fit.
    on.exit({
        stale <- intersect(c("z_mean", "z_log_stddev"), ls(envir = globalenv()))
        rm(list = stale, envir = globalenv())
    }, add = TRUE)
    assign("z_mean", layer_dense(encoded, latent_dim, name = "z_mean"),
        envir = globalenv())
    assign("z_log_stddev", layer_dense(encoded, latent_dim, name = "z_log_stddev"),
        envir = globalenv())
    
    # sampling layer: z = z_mean + exp(z_log_stddev) * epsilon
    z <- layer_lambda(list(z_mean, z_log_stddev),
            function(arg) {
                z_mean <- arg[[1]]
                z_log_stddev <- arg[[2]]
                
                epsilon <- k_random_normal(shape = c(k_shape(z_mean)[[1]]),
                                            mean = 0.0, stddev = 1)
                
                z_mean + k_exp(z_log_stddev)*epsilon
            }, name = "latent")
    
    if (is.null(y_train)) {
        decoded <- z
    } else {
        decoded <- layer_concatenate(c(z, condition))
    }
    for (i in seq_along(decoder_layers)) {
        decoded <- decoder_layers[[i]](decoded)
    }
    
    if (is.null(y_train)) {
        model_inputs <- input_layer
    } else {
        model_inputs <- c(input_layer, condition)
    }
    model <- keras_model(inputs = model_inputs, outputs = decoded)
    
    # encoder
    encoder <- keras_model(inputs = model_inputs, outputs = z_mean)
    
    # decoder (in the conditional case the label is an extra input column)
    if (is.null(y_train)) {
        decoder_input <- layer_input(shape = latent_dim)
    } else {
        decoder_input <- layer_input(shape = latent_dim + 1)
    }
    decoder_output <- decoder_input
    for (i in seq_along(decoder_layers)) {
        decoder_output <- decoder_layers[[i]](decoder_output)
    }
    decoder <- keras_model(decoder_input, decoder_output)
    
    # loss: scaled binary cross-entropy plus weighted KL term
    vae_loss <- function(x_true, x_pred) {
        xent_loss <- (original_dim / 1.0) * loss_binary_crossentropy(x_true, x_pred)
        kl_loss <- -0.5 * k_mean(1 + z_log_stddev - k_square(z_mean) - k_exp(z_log_stddev), axis = -1L)
        xent_loss + regularization * kl_loss
    }
    
    model %>% keras::compile(optimizer = optimizer, loss = vae_loss)
    
    
    ### training
    message("training...")
    if (is.null(y_train)) {
        # vanilla vae
        if (!use_generator) {
            # without generator
            validation_data <- NULL
            if (!is.null(x_val)) {
                validation_data <- list(x_val, x_val)
            }
            
            model %>% keras::fit(
                x_train, x_train,
                epochs = epochs, 
                batch_size = batch_size, 
                validation_data = validation_data,
                validation_split = validation_split, ...)
        } else {
            # with generator
            validation_data <- NULL
            validation_steps <- NULL
            if (!is.null(x_val)) {
                validation_data <- DeepPINCS::multiple_sampling_generator(
                    list(x_val), x_val,
                    batch_size = batch_size)
                validation_steps <- ceiling(nrow(x_val)/batch_size)
            }
            
            model %>% keras::fit(
                DeepPINCS::multiple_sampling_generator(
                    list(x_train), x_train,
                    batch_size = batch_size),
                steps_per_epoch = ceiling(nrow(x_train)/batch_size),
                epochs = epochs,
                validation_data = validation_data,
                validation_steps = validation_steps, ...)
        }
    } else {
        # conditional vae
        if (!reuse_preprocessing) {
            lenc <- CatEncoders::LabelEncoder.fit(y_train)
            y_train <- CatEncoders::transform(lenc, y_train)
            # validation labels must go through the same encoder as the
            # training labels before they are fed to the "condition" input
            if (!is.null(y_val)) {
                y_val <- CatEncoders::transform(lenc, y_val)
            }
            result$preprocessing$lenc <- lenc
        }
        
        if (!use_generator) {
            # without generator
            validation_data <- NULL
            if (!is.null(x_val)) {
                validation_data <- list(list(x_val, y_val), x_val)
            }
            
            model %>% keras::fit(
                list(x_train, y_train), x_train,
                shuffle = TRUE, 
                epochs = epochs, 
                batch_size = batch_size, 
                validation_data = validation_data,
                validation_split = validation_split, ...)
        } else {
            # with generator
            validation_data <- NULL
            validation_steps <- NULL
            if (!is.null(x_val)) {
                validation_data <- DeepPINCS::multiple_sampling_generator(
                    list(x_val, cbind(y_val)), cbind(x_val),
                    batch_size = batch_size)
                validation_steps <- ceiling(nrow(x_val)/batch_size)
            }
            
            model %>% keras::fit(
                DeepPINCS::multiple_sampling_generator(
                    list(x_train, cbind(y_train)), cbind(x_train),
                    batch_size = batch_size),
                steps_per_epoch = ceiling(nrow(x_train)/batch_size),
                epochs = epochs,
                validation_data = validation_data,
                validation_steps = validation_steps, ...)
        }
    }
    
    result$model <- model
    result$encoder <- encoder
    result$decoder <- decoder
    result$preprocessing$x_train <- x_train
    if (!is.null(y_train)) result$preprocessing$y_train <- y_train
    if (!is.null(x_val)) result$preprocessing$x_val <- x_val
    if (!is.null(y_val)) result$preprocessing$y_val <- y_val
    result
}





# Draw the layer graph of a keras model as a Graphviz record diagram.
#
# Arguments:
#   x           A model object exposing `$layers`; each layer is read for its
#               name, class, input_shape, output_shape and inbound_nodes.
#   node_color  Named list of fill colours with entries encoder_col,
#               mean_vector_col, stddev_vector_col, latent_vector_col,
#               decoder_col and condition_col.
#
# Colouring relies on layers named "z_mean" and "latent" being present; a
# layer named "condition", if any, gets `condition_col`.
#
# Returns the value of DiagrammeR::grViz().
plot_vae <- function(x, node_color = list(encoder_col = "tomato",
                                        mean_vector_col = "orange",
                                        stddev_vector_col = "lavender",
                                        latent_vector_col = "lightblue",
                                        decoder_col = "palegreen",
                                        condition_col = "gray")) {
    # Render a shape as "(a, b)", or as "[(a, b), (c, d)]" when the first
    # element is itself a non-empty shape (several inputs/outputs).
    # NULL dimensions are printed as "None".
    format_shape <- function(shape) {
        if (length(purrr::pluck(shape, 1)) > 0) {
            parts <- unlist(lapply(shape, function(s) {
                paste0("(", toString(paste(s)), ")")
            }))
            txt <- paste0("[", paste(parts, collapse = ", "), "]")
        } else {
            txt <- paste0("(", toString(paste(shape)), ")")
        }
        gsub("NULL", "None", txt)
    }
    
    # node information
    layer_name <- x$layers %>%
        purrr::map_chr(~(purrr::`%||%`(purrr::pluck(., "name"), "")))
    # last dot-separated component of each layer's first class name
    layer_type <- purrr::map(x$layers, c(class, 1)) %>%
        unlist() %>%
        strsplit(split = "\\.") %>%
        purrr::map(c(rev, 1)) %>%
        unlist()
    layer_input_shape <- unlist(lapply(x$layers, function(layer) {
        format_shape(layer$input_shape)
    }))
    layer_output_shape <- unlist(lapply(x$layers, function(layer) {
        format_shape(layer$output_shape)
    }))
    
    # everything before "z_mean" is coloured as encoder, then three fixed
    # slots (mean, stddev, latent), everything after "latent" as decoder
    layer_color <- c(rep(node_color[["encoder_col"]], which(layer_name == "z_mean") - 1),
                    node_color[["mean_vector_col"]],
                    node_color[["stddev_vector_col"]],
                    node_color[["latent_vector_col"]],
                    rep(node_color[["decoder_col"]],
                        length(x$layers) - which(layer_name == "latent")))
    if (any(layer_name == "condition")) {
        layer_color[which(layer_name == "condition")] <- node_color[["condition_col"]]
    }
    node_info <- data.frame(layer_name = layer_name,
                            layer_type = layer_type,
                            layer_input_shape = layer_input_shape,
                            layer_output_shape = layer_output_shape)
    
    # edge information: names of the layer(s) feeding each layer
    inbound <- lapply(x$layers %>%
                        purrr::map("inbound_nodes") %>%
                        purrr::map(1) %>%
                        purrr::map("inbound_layers"),
                    function(src) switch(length(src), src$name,
                                        unlist(src %>% purrr::map("name"))))
    
    if (length(Filter(Negate(is.null), inbound)) == 0) {
        # no inbound information: chain consecutive layers
        edge_info <- embed(rownames(node_info), dimension = 2)
        from <- edge_info[, 2]
        to <- edge_info[, 1]
        edge_info <- data.frame(from, to, stringsAsFactors = FALSE)
    } else {
        names(inbound) <- layer_name
        inbound <- Filter(Negate(is.null), inbound)
        edge_info <- purrr::imap_dfr(inbound, ~data.frame(from = .x, to = .y, stringsAsFactors = FALSE))
        # translate layer names to node ids (row names of node_info)
        edge_info$from <- rownames(node_info)[match(edge_info$from, node_info$layer_name)]
        edge_info$to <- rownames(node_info)[match(edge_info$to, node_info$layer_name)]
    }
    
    # plot
    node_ids <- seq_len(nrow(node_info))
    nodes <- paste0(node_ids, " [label = '@@", node_ids,
                    "', fillcolor = ", layer_color[node_ids], "]",
                    collapse = "")
    # Footnote definitions substituted for the '@@<id>' labels. The spacing
    # ("name  : type |...") is kept exactly as previously emitted.
    node_labels <- paste0(" [", node_ids, "]: '",
                        node_info$layer_name, "  : ",
                        node_info$layer_type, " ",
                        "|{input: | output:}|{",
                        node_info$layer_input_shape, "|",
                        node_info$layer_output_shape, "}'",
                        collapse = "\n")
    edges <- gsub(",", "->", paste(apply(edge_info, 1, toString), collapse = " "))
    DiagrammeR::grViz(paste("digraph{
                            graph [layout = dot]
                            node [shape = record, color = black, style = filled]", 
                            nodes, edges, "} \n", node_labels))
}





# Generate synthetic samples from a fitted VAE.
#
# The training data are encoded, a Gaussian mixture is fitted to the latent
# vectors with mclust (per label when labels are present), new latent vectors
# are simulated from it and pushed through the decoder.
#
# Arguments:
#   x              List with elements `encoder`, `decoder` and `preprocessing`
#                  (`x_train`, optionally `y_train`, `minmax`, `lenc`).
#   num_samples    Number of samples to generate.
#   batch_size     Batch size of the generator (used when `use_generator`).
#   use_generator  If TRUE, decoder input is fed through
#                  DeepPINCS::multiple_sampling_generator().
#
# Returns a list with `x_gen`, `y_gen` (absent without labels), `x_train`,
# `y_train` (absent without labels) and `latent_vector`.
gen_exprs <- function(x,
                    num_samples,
                    batch_size,
                    use_generator = FALSE) {
    result <- NULL
    encoder <- x$encoder
    decoder <- x$decoder
    x_train <- x$preprocessing$x_train
    y_train <- x$preprocessing$y_train
    minmax <- x$preprocessing$minmax
    is_conditional <- !is.null(y_train)
    
    # Without labels only the expression matrix is passed to the encoder;
    # a list with a NULL second element is never handed to predict().
    if (is_conditional) {
        encoder_input <- list(x_train, y_train)
    } else {
        encoder_input <- x_train
    }
    x_train_encoded <- predict(encoder, encoder_input)
    rownames(x_train_encoded) <- rownames(x_train)
    colnames(x_train_encoded) <- paste("latent", seq_len(ncol(x_train_encoded)))
    
    # Run the decoder on a numeric matrix, directly or through the generator.
    decode <- function(decoder_input) {
        if (!use_generator) {
            predict(decoder, decoder_input)
        } else {
            predict(decoder,
                    DeepPINCS::multiple_sampling_generator(
                        list(decoder_input),
                        batch_size = batch_size,
                        shuffle = FALSE),
                    steps = ceiling(nrow(decoder_input)/batch_size))
        }
    }
    
    
    ### generating
    message("generating...")
    if (!is_conditional) {
        # vanilla vae
        y_gen <- NULL
        model_BIC <- mclust::mclustBIC(x_train_encoded, verbose = FALSE)
        mod <- mclust::mclustModel(x_train_encoded, model_BIC)
        z_sample <- mclust::sim(modelName = mod$modelName, 
                                parameters = mod$parameters, 
                                n = num_samples)
        
        # first column is dropped, the rest are the simulated latent
        # coordinates; drop = FALSE keeps a matrix when only one remains
        x_gen <- decode(z_sample[, -1, drop = FALSE])
        
    } else {
        # conditional vae
        label_freq <- table(y_train)
        all_labels <- names(label_freq)
        y_gen <- sample(all_labels, num_samples, replace = TRUE,
                        prob = as.vector(label_freq/length(y_train)))
        z_sample <- matrix(0, num_samples, ncol(x_train_encoded))
        for (label in all_labels) {
            gen_rows <- y_gen == label
            # Count by label value: table(y_gen) omits labels that were not
            # drawn and orders character labels lexically, so positional
            # indexing could pick up another label's count.
            n_label <- sum(gen_rows)
            if (n_label == 0) next
            
            latent_label <- x_train_encoded[y_train == label, , drop = FALSE]
            model_BIC <- mclust::mclustBIC(latent_label, verbose = FALSE)
            mod <- mclust::mclustModel(latent_label, model_BIC)
            temp_z <- mclust::sim(modelName = mod$modelName, 
                                parameters = mod$parameters, 
                                n = n_label)
            z_sample[gen_rows, ] <- temp_z[, -1, drop = FALSE]
        }
        
        # y_gen holds the labels as character strings; convert so the decoder
        # receives a numeric matrix rather than a character one
        x_gen <- decode(cbind(z_sample, as.numeric(y_gen)))
    }
    
    
    ### post-processing
    message("post-processing")
    # denormalization
    if (!is.null(minmax)) {
        x_gen <- gradDescent::minmaxDescaling(data.frame(x_gen), minmax$scalingParameter)
        x_train <- gradDescent::minmaxDescaling(data.frame(x_train), minmax$scalingParameter)
    }
    
    # original label
    if (is_conditional) {
        lenc <- x$preprocessing$lenc
        y_gen <- CatEncoders::inverse.transform(lenc, as.numeric(y_gen))
        y_train <- CatEncoders::inverse.transform(lenc, y_train)
    }
    
    result$x_gen <- x_gen
    result$y_gen <- y_gen
    result$x_train <- x_train
    result$y_train <- y_train
    result$latent_vector <- x_train_encoded
    result
}





# Plot real and generated samples together in a reduced dimension.
#
# Arguments:
#   x         List with `x_train`, `x_gen` and, optionally, `y_train` and
#             `y_gen`.
#   plot_fun  Single string completing the names of two functions exported by
#             scater: "run<plot_fun>" is applied to the combined data, and
#             "plot<plot_fun>" draws the result.
#   ...       Forwarded to the scater plot function.
#
# Points are coloured by origin ("real" / "gen"); when `x$y_gen` is present
# they are additionally shaped by label.
#
# Returns the value of the scater plot function.
plot_aug <- function(x, plot_fun, ...) {
    if (!is.character(plot_fun) || length(plot_fun) != 1L || is.na(plot_fun)) {
        stop("plot_fun should be a single character string")
    }
    
    x_gen <- x$x_gen
    x_train <- x$x_train
    has_labels <- !is.null(x$y_gen)
    
    # single cell object
    origin <- c(rep("real", nrow(x_train)), rep("gen", nrow(x_gen)))
    if (has_labels) {
        group <- data.frame(label = c(as.vector(x$y_train), as.vector(x$y_gen)),
                            group = origin)
    } else {
        group <- data.frame(group = origin)
    }
    rownames(group) <- c(rownames(x_train), paste("gen", seq_len(nrow(x_gen)), sep = ""))
    
    sce <- SingleCellExperiment::SingleCellExperiment(
        assays = list(
            counts = t(as.matrix(rbind(as.matrix(x_train), as.matrix(x_gen))))
        ), 
        colData = group
    )
    sce <- scater::logNormCounts(sce)
    
    # Look the functions up by name among scater's exports (what `scater::`
    # does) instead of parsing and evaluating a constructed code string.
    run_reduction <- getExportedValue("scater", paste0("run", plot_fun))
    draw_reduction <- getExportedValue("scater", paste0("plot", plot_fun))
    
    # plot
    plot_args <- list(run_reduction(sce), colour_by = "group")
    if (has_labels) {
        plot_args$shape_by <- "label"
    }
    do.call(draw_reduction, c(plot_args, list(...)))
}
# dongminjung/VAExprs documentation built on Dec. 20, 2021, 12:13 a.m.