# R/functions_hierarchical.R

# Defines functions setup_convolved_lm_hierarchical

# Prepare an ARMET object for the hierarchical convolved regression.
# No inference is run here: the function validates the input, builds the design
# matrix and the reference/mixture tables, and returns everything the fitting
# functions need.
#
# Arguments
#   .data                    data frame of the mixture samples
#   .formula                 one-sided formula for the regression; `~ 1` means
#                            intercept only
#   .sample, .transcript, .abundance
#                            columns of `.data` (captured with enquo() and
#                            resolved by get_sample_transcript_counts())
#   approximate_posterior    stored in the result and used later by the fitting step
#   prior_survival_time      third-party survival times; required when the formula
#                            contains a censored term
#   transform_time_function  applied to `prior_survival_time` and to the censored
#                            time column before building the design matrix
#   reference                reference data frame, joined to the cell-type tree
#                            by `cell_type`
#
# Returns a list with two elements: `internals` (state consumed by the fitting
# functions) and `input` (snapshot of the arguments and fixed settings).
#
# NOTE(review): this file contains a second definition of this function with the
# same body; when the file is sourced the later definition is the one that stays
# bound.
setup_convolved_lm_hierarchical = function(.data,
																					 .formula = ~ 1,
																					 .sample = NULL,
																					 .transcript = NULL,
																					 .abundance = NULL,
																					 approximate_posterior = FALSE,
																					 prior_survival_time = c(),
																					 transform_time_function = sqrt,
																					 reference = NULL) {
	
	# Fixed settings, not user-configurable at the moment.
	# They must be declared BEFORE the environment snapshot below and keep these
	# exact names, because they are read back later as input$cores,
	# input$iterations, etc.
	levels = 1
	full_bayesian = FALSE
	# NOTE(review): n_markers is not an argument of this function; it is resolved
	# from the enclosing scope - confirm it is always available.
	.n_markers = n_markers
	do_regression = TRUE
	cores = 4
	shards = cores 
	iterations = 800
	sampling_iterations = 200
	model = stanmodels$ARMET_tc_fix_hierarchical
	
	# Snapshot of arguments + fixed settings
	input = c(as.list(environment()))
	input$.formula = .formula
	
	# Get column names
	.sample = enquo(.sample)
	.transcript = enquo(.transcript)
	.abundance = enquo(.abundance)
	col_names = get_sample_transcript_counts(.data, .sample, .transcript, .abundance)
	.sample = col_names$.sample
	.transcript = col_names$.transcript
	.abundance = col_names$.abundance
	
	# Rename the mixture columns to the fixed names used internally
	.data = .data %>% rename( sample = !!.sample, symbol = !!.transcript ,  count = !!.abundance)
	input$.data = .data
	
	
	# Stop if the input uses column names reserved by ARMET
	names_taken = c("level") 
	if(.data %>% colnames %in% names_taken %>% any) 
		stop(sprintf("ARMET says: your input data frame includes reserved column names: %s", paste(names_taken, collapse = ", ")))
	
	# Check if count is integer
	if(.data %>% select(count) %>% lapply(class) %>% unlist() %>% equals("integer") %>% `!`)
		stop(sprintf("ARMET says: the %s column must be integer as the deconvolution model is Negative Binomial", quo_name(.abundance)))
	
	# Covariate column
	# (`&&`: this is a scalar condition inside `if`)
	if(do_regression && paste(as.character(.formula), collapse="")  != "~1"){
		formula_df = parse_formula(.formula)
		
		# Censoring column
		if(do_regression && length(formula_df$censored_column) == 1) {
			
			# One censoring flag per sample, ordered by sample
			cens = .data %>% select(sample, formula_df$censored_column) %>% distinct %>% arrange(sample) %>% pull(2)
			
			# Check cens right type
			if(typeof(cens) %in% c("integer", "logical") %>% any %>% `!`) stop("ARMET says: censoring variable should be logical or integer (0,1)")
			if(length(prior_survival_time) == 0) stop("ARMET says: you really need to provide third party survival time for your condition/disease")
			
			# Standard deviation of the (untransformed) time column
			sd_survival_months = .data %>%  select(sample, formula_df$censored_value_column) %>% distinct %>% pull(formula_df$censored_value_column) %>% sd
			
			# Shift by one when the minimum is zero, then transform
			prior_survival_time = transform_time_function(prior_survival_time %>% when(min(.)==0 ~ (.) + 1, (.))) 
			
			
			time_column = formula_df$censored_value_column 
			
			# Design matrix, with the time column transformed
			X =
				model.matrix(
					object = 	formula_df$formula_formatted,
					data = 
						.data %>% 
						select(sample, one_of(formula_df$covariates_formatted)) %>% 
						distinct %>% 
						arrange(sample) %>%
						mutate(!!as.symbol(formula_df$censored_value_column ) := transform_time_function(!!as.symbol(formula_df$censored_value_column )) )
				)
			
			# Indices of the design-matrix columns whose name matches the time column
			columns_idx_including_time = 
				which(grepl(time_column, colnames(X))) %>% 
				as.array() %>% 
				
				# Fix if NULL
				when(is.null(.) ~ c(), ~ (.))
			
		}
		else{
			X =
				model.matrix(
					object = 	formula_df$formula_formatted,
					data = 
						.data %>% 
						select(sample, one_of(formula_df$covariates_formatted)) %>% 
						distinct %>% 
						arrange(sample) 
				)
			
			cens = NULL
			columns_idx_including_time = array(0)[0]
			
		} 
		
		
		
	}	else {
		# Intercept-only model
		formula_df = cens  = NULL	
		columns_idx_including_time = array(0)[0]
		X =
			model.matrix(
				object = 	~ 1,
				data = .data %>% select(sample) %>% distinct %>% arrange(sample)
			)
	}
	
	# Do regression
	#if(length(formula_df$covariates_formatted) > 0 & (formula_df$covariates_formatted %>% is.na %>% `!`)) do_regression = T
	
	# distinct_at is not released yet for dplyr, thus we have to use this trick
	# NOTE(review): df_for_edgeR is not used further down in this function; the
	# pipeline is kept because its first two steps stop on invalid input.
	df_for_edgeR <- .data %>%
		
		# Stop if any counts is NA
		error_if_counts_is_na(count) %>%
		
		# Stop if there are duplicated transcripts
		error_if_duplicated_genes(sample,symbol,count) %>%
		
		# Prepare the data frame
		select(symbol,
					 sample,
					 count,
					 one_of(formula_df$covariates_formatted)) %>%
		distinct() %>%
		
		# Check if data rectangular
		ifelse_pipe(
			(.) %>% check_if_data_rectangular(sample,symbol,count, type = "soft") %>% `!` &
				TRUE, #!fill_missing_values,
			~ .x %>% eliminate_sparse_transcripts(symbol)
		) %>%
		
		when(
			do_regression && length(formula_df$censored_column) == 1 ~ 
				mutate(., !!formula_df$censored_value_column := !!as.symbol(formula_df$censored_value_column) / sd_survival_months),
			~ (.)
		)
	
	# Mixture table: one row per sample/symbol plus the covariates
	mix =
		.data %>%
		select(sample, symbol, count, one_of(formula_df$covariates_formatted)) %>%
		distinct() 
	
	
	
	tree = 	data.tree::Clone(ARMET::tree) 
	
	
	# Print overlap descriptive stats
	#get_overlap_descriptive_stats(mix %>% slice(1) %>% gather(symbol, count, -sample), reference)
	
	# Prepare data frames -
	# For Q query first
	# For G house keeing first
	# For GM level 1 first
	
	# Number of mixture samples
	Q = mix %>% distinct(sample) %>% nrow
	
	# Reference annotated with the tree columns, matched on cell type
	reference_filtered =
		reference %>% 
		inner_join(
			tree %>%
				data.tree::ToDataFrameTree("Cell type category", "C", "C1", "C2", "C3", "C4", "isLeaf") %>%
				as_tibble %>%
				rename(cell_type = `Cell type category`) %>% 
				select(-1),
			by = "cell_type"
		)	
	
	tree_properties = get_tree_properties(tree)
	
	# Find normalisation: scale every mixture sample against a single
	# pseudo-sample called "reference" (median over duplicated symbols)
	sample_scaling = 
		reference_filtered %>%
		
		mutate(sample = "reference") %>% 
		tidybulk::aggregate_duplicates(sample, symbol, count, aggregation_function = median) %>%
		bind_rows(mix) %>%
		tidybulk::identify_abundant(sample, symbol, count) %>%
		tidybulk::scale_abundance(sample, symbol, count, reference_sample = "reference", action ="get", .subset_for_scaling = .abundant) %>%
		distinct(sample, multiplier) %>%
		mutate(exposure_rate = -log(multiplier)) %>%
		mutate(exposure_multiplier = exp(exposure_rate)) 
	
	
	
	# Default internals
	list(
		internals = list(
			prop = NULL,
			fit = NULL,
			df = NULL,
			prop_posterior = get_null_prop_posterior(tree_properties$ct_in_nodes),
			alpha = NULL,
			Q = Q,
			reference_filtered = reference_filtered,
			mix = mix,
			X = X,
			cens = cens,
			tree_properties = tree_properties,
			prior_survival_time = prior_survival_time,
			formula_df = formula_df,
			sample_scaling = sample_scaling,
			columns_idx_including_time = columns_idx_including_time,
			approximate_posterior = approximate_posterior,
			transform_time_function = transform_time_function,
			
			shards = shards,
			levels = levels,
			full_bayesian = full_bayesian,
			iterations = iterations,
			sampling_iterations = sampling_iterations	,
			do_regression = do_regression,
			.formula = .formula,
			model = model
		),
		input = input
	)
	
	
}

# Summarise the regression draws of one level into one row per cell type.
#
# Arguments
#   .data  proportions table with columns `level`, `.variable`,
#          `Cell type category` and a list-column `draws`
#   lev    level to keep
#   X      design matrix; its column names label the coefficients
#
# Returns the rows of `.data` for level `lev` with non-missing `.variable`,
# where `draws` is replaced by wide columns holding the median and the standard
# deviation of `.value` for every coefficient (named after the columns of `X`).
get_estimates = function(.data, lev, X) {
	
	# Coefficient index -> design-matrix column name.
	# Built once instead of once per row; seq_len() avoids the 1:0 trap of
	# 1:ncol(X).
	coefficient_names = tibble(A = seq_len(ncol(X)), A_name = colnames(X))
	
	.data %>% 
		filter(level == lev) %>%
		filter(.variable %>% is.na %>% `!`) %>%
		select(level, `Cell type category`, draws) %>% 
		mutate(regression = map(draws,
														~ .x %>%
															
															# Posterior summary per coefficient
															group_by(A) %>%
															summarise(.median = median(.value), .sd = sd(.value)) %>% 
															#tidybayes::median_qi(.width = credible_interval) %>%
															
															# Replace the index with the coefficient name
															left_join(coefficient_names, by = "A") %>% 
															select(-A) %>% 
															
															# One column per statistic/coefficient pair
															pivot_wider(
																names_from = A_name,
																values_from = c(.median, .sd)
															))) %>% 
		select(-draws) %>% 
		unnest(regression)
	
}


# Drop every result stored for level `my_level` and above, so that the level
# can be (re)fitted from a clean state.
#
# Arguments
#   .data     ARMET result list with elements `internals` and `proportions`
#   my_level  first level to discard
#
# Returns `.data` with fits, draws, proportions and alpha restricted to levels
# below `my_level`, and with the stored posteriors of the discarded nodes
# replaced by zero-row tibbles of the same width.
clear_previous_levels = function(.data, my_level){
	
	# Positions of the levels to keep.
	# seq_len() is used instead of 1:(my_level - 1): for my_level == 1 the latter
	# is c(1, 0) and would keep the first element instead of none.
	levels_to_keep = seq_len(my_level - 1)
	
	# Eliminate previous results
	.data$internals$fit = .data$internals$fit[levels_to_keep]
	.data$internals$prop = .data$internals$prop %>% filter(level < !!my_level) 
	if(.data$internals$alpha %>% is.null %>% `!`) .data$internals$alpha = .data$internals$alpha  %>% filter(level < !!my_level)
	.data$internals$draws = .data$internals$draws[levels_to_keep]
	
	# Names of the prop_posterior entries belonging to the discarded levels,
	# derived from the `.variable` values ("alpha_<node>" -> "prop_<node>_prior")
	nodes_to_eliminate = 
		.data$proportions  %>% 
		filter(level >= !!my_level & (.variable %>% is.na %>% `!`)) %>%
		distinct(.variable) %>% 
		pull(.variable) %>%
		gsub("alpha_", "", .) %>%
		sprintf("prop_%s_prior", .)
	
	# Reset each of them to an empty tibble with the same number of columns
	for(i in nodes_to_eliminate) {
		.data$internals$prop_posterior[[i]] = seq_len(ncol(.data$internals$prop_posterior[[i]])) %>% matrix(nrow=1) %>% as_tibble() %>% slice(0)
	}
	
	.data$proportions = .data$proportions  %>% filter(level < !!my_level)
	
	.data
}

#' ARMET_tc_continue
#' 
#' @description This function does inference for higher levels of the hierarchy.
#'   Results already stored for `level` and above are discarded first, then the
#'   model is fitted for `level` using the proportions of `level - 1` as prior.
#' 
#' @param armet_obj An ARMET object: a list with elements `proportions`, `input` and `internals`
#' @param level An integer, the level of the hierarchy to fit
#' @param model A stan model
#' 
#' @return A list with elements `proportions`, `input` and `internals`
#' 
#' @export
ARMET_tc_continue = function(armet_obj, level, model = stanmodels$ARMET_tc_fix_hierarchical){
	
	armet_obj= clear_previous_levels(armet_obj,  level)
	internals = armet_obj$internals
	input = armet_obj$input
	
	# Regression output exists only when a non-trivial formula was given.
	# Computed once; it gates both the alpha update and the final join.
	has_regression = input$do_regression && paste(as.character(input$.formula), collapse="")  != "~1"
	
	res = run_model(
		reference_filtered = internals$reference_filtered,
		mix = internals$mix,
		shards = input$cores,
		lv = level,
		full_bayesian = input$full_bayesian,
		
		# Passed by name: as a bare positional argument it was bound to
		# prop_posterior only because every other formal was already matched
		prop_posterior = internals$prop_posterior,
		iterations = input$iterations,
		sampling_iterations = input$sampling_iterations	,
		X = internals$X,
		do_regression = input$do_regression,
		cens = internals$cens,
		tree_properties = internals$tree_properties,
		Q = internals$Q,
		model = model,
		prior_survival_time = internals$prior_survival_time,
		sample_scaling = internals$sample_scaling,
		
		# Proportions of the previous level as a Q x C matrix
		prior_prop = 
			internals$prop %>% 
			filter(level == !!level -1) %>% 
			distinct(Q, C, .value) %>%
			spread(C, .value) %>% 
			as_matrix(rownames = Q),
		columns_idx_including_time = internals$columns_idx_including_time,
		approximate_posterior = internals$approximate_posterior
	)
	
	df = res[[1]]
	fit = res[[2]]
	
	# Draws of the proportion parameters, excluding auxiliary variables
	fit_prop_parsed = 
		fit %>%
		draws_to_tibble("prop_", "Q", "C") %>%
		filter(!grepl("_UFO|_rng|_logit", .variable))  %>%
		mutate(Q = Q %>% as.integer)
	
	draws = get_draws(fit_prop_parsed, level, internals)	
	
	prop = get_props(draws, level, df, input$approximate_posterior)	
	
	# Append the results of this level to the accumulated state
	internals$prop = bind_rows(internals$prop , prop) 
	internals$fit = internals$fit %>% c(list(fit))
	internals$df = internals$df %>% c(list(df))
	internals$prop_posterior[sprintf("%s_prior", fit_prop_parsed %>% distinct(.variable) %>% pull())] = fit_prop_parsed %>% group_by(Q, C, .variable) %>% prop_to_list
	internals$draws = internals$draws %>% c(list(draws))
	
	if (has_regression) {
		internals$alpha = internals$alpha  %>% bind_rows( 
			get_alpha(fit, level) %>% 
				left_join(
					get_generated_quantities_standalone(fit, level, internals),
					by = c("node", "C")
					
				)
		)
	}
	
	# Return
	list(
		# Matrix of proportions
		proportions =	
			input$.data %>%
			select(c(sample, (.) %>% get_specific_annotation_columns(sample))) %>%
			distinct() %>%
			left_join(internals$prop) %>%
			
			# Attach alpha if regression
			ifelse_pipe(
				has_regression,
				~ .x %>%
					nest(proportions = -c(`Cell type category`, C, level)) %>%
					left_join(
						internals$alpha %>%	select(`Cell type category`, contains("alpha"), level, draws, rng_prop, rng_mu, .variable, one_of("Rhat")),
						by = c("Cell type category", "level")
					)
			),
		
		# # Return the input itself
		input = input,
		
		# Return the fitted object
		internals = internals
	)
	
}

# Fit the Stan model for one level of the hierarchy.
#
# Arguments
#   reference_filtered         reference table with (at least) the columns level,
#                              is_marker, isLeaf and symbol
#   mix                        mixture table, combined with the reference by
#                              ref_mix_format()
#   shards                     written to the STAN_NUM_THREADS environment variable
#   lv                         level being fitted
#   full_bayesian              not referenced in this function
#   approximate_posterior      TRUE: fit with vb_iterative(); otherwise sampling()
#   prop_posterior             list passed (together with tree_properties) as the
#                              `data` argument of the fit
#   iterations                 total iterations per chain for sampling()
#   sampling_iterations        post-warmup iterations (warmup = iterations - this)
#   X                          design matrix
#   do_regression              not referenced in this function
#   cens                       censoring indicator per sample, or NULL
#   tree_properties            list; `ct_in_levels` is read here
#   Q                          number of mixture samples
#   model                      compiled Stan model
#   prior_survival_time        vector of prior survival times (may be empty)
#   sample_scaling             table with columns sample and exposure_multiplier
#   prior_prop                 proportions of the previous level
#   columns_idx_including_time indices of the time-related columns of X
#
# Returns list(df, fit): the combined reference/mixture table and the fit.
#
# Side effect: sets the STAN_NUM_THREADS environment variable.
#
# NOTE(review): many locals below (GM, A, y, ref, exposure_multiplier, max_y,
# which_cens, spt, CIT, the prior constants, ...) are never referenced explicitly
# in this function. Presumably they are picked up BY NAME as Stan data when the
# model is fitted - confirm before renaming or removing any of them.
run_model = function(reference_filtered,
										 mix,
										 shards,
										 lv,
										 full_bayesian,
										 approximate_posterior,
										 prop_posterior,
										 iterations = 250,
										 sampling_iterations = 100,
										 X,
										 do_regression,
										 cens,
										 tree_properties,
										 Q,
										 model = stanmodels$ARMET_tc_fix_hierarchical,
										 prior_survival_time = c(),
										 sample_scaling,
										 prior_prop = matrix(1:Q)[,0, drop=FALSE],
										 columns_idx_including_time) {
	
	
	# Global properties - derived by previous analyses of the whole reference dataset
	sigma_intercept = 1.3420415
	sigma_slope = -0.3386389
	sigma_sigma = 1.1720851
	lambda_mu_mu = 5.612671
	lambda_sigma = 7.131593
	
	# Non centred
	lambda_mu_prior = c(6.2, 1)
	lambda_sigma_prior =  c(3.3 , 1)
	lambda_skew_prior =  c(-2.7, 1)
	sigma_intercept_prior = c(1.9 , 0.1)
	
	# Filter on level considered:
	# marker genes of this level, over the reference rows of this level plus the
	# leaves of the levels above
	my_genes = reference_filtered %>% filter(level == lv ) %>% filter(is_marker) %>% pull(symbol) %>% unique()
	reference_filtered = reference_filtered %>% filter(level == lv | (level < lv & isLeaf)) %>% filter(symbol %in% my_genes)
	
	df = ref_mix_format(reference_filtered, mix)
	
	# Number of distinct genes
	GM = df  %>% distinct(symbol) %>% nrow()
	
	# Query (mixture) part of the combined table
	y_source =
		df %>%
		filter(`query`) %>%
		select(S, Q, symbol, count, GM, sample) 
	# left_join(
	# 	df %>% filter(!query) %>% distinct(
	# 		symbol,
	# 		G,
	# 		`Cell type category`,
	# 		level,
	# 		lambda_log,
	# 		sigma_inv_log,
	# 		GM,
	# 		C
	# 	),
	# 	by = c("symbol", "GM")
	# ) %>%
	#arrange(C, Q, symbol) %>%
	#mutate(`Cell type category` = factor(`Cell type category`, unique(`Cell type category`)))
	
	# Dirichlet regression
	# Number of columns of the design matrix
	A = X %>% ncol
	
	# library(rstan)
	# fileConn<-file("~/.R/Makevars")
	# writeLines(c( "CXX14FLAGS += -O2","CXX14FLAGS += -DSTAN_THREADS", "CXX14FLAGS += -pthread"), fileConn)
	# close(fileConn)
	# ARMET_tc_model = rstan::stan_model("~/PhD/deconvolution/ARMET/inst/stan/ARMET_tc_fix.stan", auto_write = F)
	
	# Exposure multiplier of the query samples, ordered by sample
	exposure_multiplier = 
		sample_scaling %>% 
		filter(sample %in% (y_source %>% pull(sample))) %>% 
		arrange(sample) %>% 
		pull(exposure_multiplier) %>%
		as.array()
	
	
	# Setup for exposure inference
	# 
	# df_for_exposure = 
	# 	df %>%
	# 	filter(`query` &  `house keeping`)  %>%
	# 	distinct(Q, symbol, count) %>%
	# 	left_join(
	# 		df %>%
	# 			filter(!`query` & `house keeping`)  %>%
	# 			mutate(reference_count = exp(lambda_log)) %>%
	# 			distinct(symbol, reference_count )
	# 	)
	# 
	# nrow_for_exposure = nrow(df_for_exposure)
	# Q_for_exposure = df_for_exposure$Q
	# reference_for_exposure = df_for_exposure %>% pull(reference_count)
	# counts_for_exposure = df_for_exposure %>% pull(count)
	
	# Initial values handed to the fit: lambda_UFO starts at 6.2 for every gene
	init_list = list(	lambda_UFO = rep(6.2, GM)	) 
	
	# Reference counts as a matrix: one row per C, one column per GM
	ref = 
		df %>%
		
		# Eliminate the query part, not the house keeping of the query
		filter(!`query`)  %>%
		
		select(C, GM, count ) %>% 
		distinct() %>%
		arrange(C, GM) %>% 
		spread(GM, count) %>% 
		as_matrix(rownames = "C") 
	
	# Query counts as a matrix: one row per Q, one column per GM
	y = 
		y_source %>%
		select(Q, GM, count) %>% 
		distinct() %>%
		arrange(Q, GM) %>% 
		spread(GM, count) %>% 
		as_matrix(rownames = "Q") 
	
	max_y = max(y)
	
	# Number of cell types at the previous level (0 at the first level)
	ct_in_ancestor_level = ifelse(lv == 1, 0, tree_properties$ct_in_levels[lv-1])
	
	Sys.setenv("STAN_NUM_THREADS" = shards)
	
	# Without censoring information every sample is treated as not censored
	if(cens %>% is.null) cens =  rep(0, Q)
	which_cens = which(cens == 1)  %>% as.array()
	which_not_cens = which(cens == 0) %>% as.array()
	how_many_cens = length(which_cens)
	
	# Maximum of the second design-matrix column when any sample is censored
	# NOTE(review): assumes the second column of X is the time covariate - confirm
	max_unseen = ifelse(how_many_cens>0, max(X[,2]), 0 )
	if(is.null(prior_survival_time)) prior_survival_time = array(1)[0]
	spt = length(prior_survival_time)
	
	CIT = length(columns_idx_including_time)
	
	# Variational approximation or MCMC, depending on approximate_posterior
	fit = 
		approximate_posterior %>%
		when(
			(.) ~ vb_iterative(model,
												 # rstan::stan_model("~/PhD/deconvolution/ARMET/inst/stan/ARMET_tc_fix_hierarchical.stan", auto_write = F),
												 iter = 50000,
												 tol_rel_obj = 0.0005,
												 data = prop_posterior %>% c(tree_properties),
												 init = function () init_list
			),
			
			~ 	sampling(
				model,
				#rstan::stan_model("~/PhD/deconvolution/ARMET/inst/stan/ARMET_tc_fix_hierarchical.stan", auto_write = F),
				chains = 3,
				cores = 3,
				iter = iterations,
				warmup = iterations - sampling_iterations,
				data = prop_posterior %>% c(tree_properties),
				# pars=
				# 	c("prop_1", "prop_2", "prop_3", sprintf("prop_%s", letters[1:9])) %>%
				# 	c("alpha_1", sprintf("alpha_%s", letters[1:9])) %>%
				# 	c("exposure_rate") %>%
				# 	c("lambda_UFO") %>%
				# 	c("prop_UFO") %>%
				# 	c(additional_par_to_save),
				init = function ()	init_list,
				save_warmup = FALSE
				# ,
				# control=list( adapt_delta=0.9,stepsize = 0.01,  max_treedepth =10  )
			) %>%
				{
					# Print the parameters with Rhat > 1.5 (if any), then pass the fit on unchanged
					(.)  %>% rstan::summary() %$% summary %>% as_tibble(rownames = "par") %>% arrange(Rhat %>% desc) %>% filter(Rhat > 1.5) %>% ifelse_pipe(nrow(.) > 0, ~ .x %>% print)
					(.)
				}
		)
	
	list(df, fit)
	
}



#' add_cox_test
#' 
#' @description Runs `censored_regression_joint()` on the unnested proportions
#'   and joins its output, nested into a `draws_cens` list-column, to the rows
#'   of the proportions table that have a non-missing `.variable`
#' 
#' @param .data An ARMET result: a list holding `proportions` and `internals`
#' @param relative A boolean, forwarded to `censored_regression_joint()`
#' 
#' @export
add_cox_test <- function(.data, relative = TRUE){
	
	fit_internals <- .data$internals
	
	# Long format: one row per unnested proportion, `.variable` exposed as `node`
	unnested_proportions <-
		.data$proportions %>% 
		select(-draws, -contains("rng")) %>%
		rename(node = .variable) %>% 
		unnest(proportions)
	
	# Censored regression, nested back to one row per level / .variable / C
	cens_alpha <-
		censored_regression_joint(
			unnested_proportions,
			formula_df = fit_internals$formula_df,
			filter_how_many = Inf,
			relative = relative,
			transform_time_function = fit_internals$transform_time_function
		) %>% 
		rename(.variable = node) %>%
		nest(draws_cens = -c(level, .variable, C)) 
	
	# Only rows carrying a regression variable receive the test results
	regression_rows <- filter(.data$proportions, !is.na(.variable))
	
	left_join(regression_rows, cens_alpha, by = c("level", "C", ".variable"))
}




#------------------------------------#
# Fit the first level of the hierarchy and store the results in `internals`.
#
# Arguments
#   internals              list built by the setup step; the elements read here
#                          are reference_filtered, mix, prop_posterior, X, cens,
#                          tree_properties, Q, prior_survival_time,
#                          sample_scaling and columns_idx_including_time
#   shards, level, full_bayesian, approximate_posterior
#                          forwarded positionally to run_model()
#   iterations, sampling_iterations, do_regression, model
#                          forwarded to run_model()
#   .formula               currently unused (the formula check is commented out)
#
# Returns `internals` with prop, fit, df, draws, the first prop_posterior entry
# and (when do_regression is TRUE) alpha filled in.
#
# The defaults of iterations, sampling_iterations, do_regression and .formula
# used to be self-referential (e.g. `iterations = iterations`), which raises
# "promise already under evaluation" as soon as an omitted argument is forced.
# They now carry real values: 250 / 100 as in run_model(), TRUE and `~ 1`.
run_lv_1 = function(internals,
										shards,
										level = 1,
										full_bayesian,
										approximate_posterior,
										iterations = 250,
										sampling_iterations = 100,
										do_regression = TRUE,
										.formula = ~ 1, model = stanmodels$ARMET_tc_fix_hierarchical){
	res1 = run_model( 
		internals$reference_filtered,
		internals$mix,
		shards,
		level,
		full_bayesian,
		approximate_posterior,
		internals$prop_posterior,
		iterations = iterations,
		sampling_iterations = sampling_iterations,
		X = internals$X,
		do_regression = do_regression,
		cens = internals$cens,
		tree_properties = internals$tree_properties,
		Q = internals$Q,
		model = model,
		prior_survival_time = internals$prior_survival_time,
		sample_scaling = internals$sample_scaling,
		columns_idx_including_time = internals$columns_idx_including_time
		
	)
	
	df = res1[[1]]
	fit = res1[[2]]
	
	# Draws of the proportion parameters, excluding auxiliary variables
	fit_prop_parsed = 
		fit %>%
		draws_to_tibble("prop_", "Q", "C") %>%
		filter(!grepl("_UFO|_rng", .variable))  %>%
		mutate(Q = Q %>% as.integer)
	
	# At the first level the relative proportion equals the absolute one
	draws =
		fit_prop_parsed %>%
		ungroup() %>%
		select(-.variable) %>%
		mutate(.value_relative = .value)
	
	prop =
		fit %>%
		draws_to_tibble("prop_1", "Q", "C") %>%
		#tidybayes::gather_draws(`prop_[1]`[Q, C], regex = T) %>%
		drop_na  %>%
		ungroup() %>%
		
		# Add relative proportions
		mutate(.value_relative = .value) %>%
		
		# Add tree information
		# (ARMET::tree, as in setup_convolved_lm_hierarchical, instead of relying
		# on a free variable `tree` being visible on the search path)
		left_join(
			ARMET::tree %>% data.tree::ToDataFrameTree("name", "C1", "C2", "C3", "C4") %>%
				as_tibble %>%
				select(-1) %>%
				rename(`Cell type category` = name) %>%
				gather(level, C, -`Cell type category`) %>%
				mutate(level = gsub("C", "", level)) %>%
				filter(level == 1) %>%
				drop_na %>%
				mutate(C = C %>% as.integer, level = level %>% as.integer)
		) %>%
		
		# add sample annotation
		left_join(df %>% distinct(Q, sample), by = "Q")	%>%
		
		# If MCMC is used check divergences as well
		ifelse_pipe(
			!approximate_posterior,
			~ .x %>% parse_summary_check_divergence(),
			~ .x %>% parse_summary() %>% rename(.value = mean)
		) %>%
		
		# Parse
		separate(.variable, c(".variable", "level"), convert = TRUE) %>%
		
		# Add sample information
		left_join(df %>%
								filter(`query`) %>%
								distinct(Q, sample))
	
	
	
	
	if (do_regression) { # && paste(as.character(.formula), collapse="")  != "~1" ) 
		internals$alpha = 
			get_alpha(fit, level) %>% 
			left_join(
				get_generated_quantities_standalone(fit, level, internals),
				by = c("node", "C")
				
			)
	}
	
	
	internals$prop = prop
	internals$fit = list(fit)
	internals$df = list(df)
	internals$draws = list(draws)
	internals$prop_posterior[[1]] = fit_prop_parsed %>% group_by(.variable, Q, C) %>% prop_to_list %>% `[[` ("prop_1") 
	
	internals
	
	
}




# Compile the level-1 "generated quantities" Stan program and return the
# compiled model invisibly.
# `fit` is accepted but not used by the active code; it is only referenced by
# the disabled block kept below.
get_theoretical_data_disrtibution <- function(fit){
	
	# The path is relative, so it depends on the current working directory
	generated_quantities_model <- rstan::stan_model(file = "inst/stan/generated_quantities_lv1.stan")
	
	# Disabled code, kept for reference:
	#
	# # If those nodes are not in fit add them otherwise generate quantities fails
	# missing_columns = 
	# 	c("1", letters[1:11]) %>%
	# 	imap(
	# 		~ sprintf("alpha_%s", .x) %>%
	# 			grep(colnames(as.matrix(fit))) %>%
	# 			when(length(.)== 0 ~ {
	# 				add_col = matrix(rep(0, nrow(as.matrix(fit)) * (tree_properties$ct_in_nodes[.y]-1) * A  ), nrow = nrow(as.matrix(fit)) )
	# 				colnames(add_col) = sprintf("alpha_%s[%s]", .x, apply(expand.grid( 1:A, 1:(tree_properties$ct_in_nodes[.y]-1)), 1, paste, collapse=","))
	# 					 
	# 					 add_col
	# 			})
	# 	) %>%
	# 	do.call(cbind,.)
	
	invisible(generated_quantities_model)
}




# Prepare an ARMET object for the hierarchical convolved regression.
# No inference is run here: the function validates the input, builds the design
# matrix and the reference/mixture tables, and returns everything the fitting
# functions need.
#
# Arguments
#   .data                    data frame of the mixture samples
#   .formula                 one-sided formula for the regression; `~ 1` means
#                            intercept only
#   .sample, .transcript, .abundance
#                            columns of `.data` (captured with enquo() and
#                            resolved by get_sample_transcript_counts())
#   approximate_posterior    stored in the result and used later by the fitting step
#   prior_survival_time      third-party survival times; required when the formula
#                            contains a censored term
#   transform_time_function  applied to `prior_survival_time` and to the censored
#                            time column before building the design matrix
#   reference                reference data frame, joined to the cell-type tree
#                            by `cell_type`
#
# Returns a list with two elements: `internals` (state consumed by the fitting
# functions) and `input` (snapshot of the arguments and fixed settings).
#
# NOTE(review): this file contains an earlier definition of this function with
# the same body; when the file is sourced this later definition is the one that
# stays bound.
setup_convolved_lm_hierarchical = function(.data,
																					 .formula = ~ 1,
																					 .sample = NULL,
																					 .transcript = NULL,
																					 .abundance = NULL,
																					 approximate_posterior = FALSE,
																					 prior_survival_time = c(),
																					 transform_time_function = sqrt,
																					 reference = NULL) {
	
	# Fixed settings, not user-configurable at the moment.
	# They must be declared BEFORE the environment snapshot below and keep these
	# exact names, because they are read back later as input$cores,
	# input$iterations, etc.
	levels = 1
	full_bayesian = FALSE
	# NOTE(review): n_markers is not an argument of this function; it is resolved
	# from the enclosing scope - confirm it is always available.
	.n_markers = n_markers
	do_regression = TRUE
	cores = 4
	shards = cores 
	iterations = 800
	sampling_iterations = 200
	model = stanmodels$ARMET_tc_fix_hierarchical
	
	# Snapshot of arguments + fixed settings
	input = c(as.list(environment()))
	input$.formula = .formula
	
	# Get column names
	.sample = enquo(.sample)
	.transcript = enquo(.transcript)
	.abundance = enquo(.abundance)
	col_names = get_sample_transcript_counts(.data, .sample, .transcript, .abundance)
	.sample = col_names$.sample
	.transcript = col_names$.transcript
	.abundance = col_names$.abundance
	
	# Rename the mixture columns to the fixed names used internally
	.data = .data %>% rename( sample = !!.sample, symbol = !!.transcript ,  count = !!.abundance)
	input$.data = .data
	
	
	# Stop if the input uses column names reserved by ARMET
	names_taken = c("level") 
	if(.data %>% colnames %in% names_taken %>% any) 
		stop(sprintf("ARMET says: your input data frame includes reserved column names: %s", paste(names_taken, collapse = ", ")))
	
	# Check if count is integer
	if(.data %>% select(count) %>% lapply(class) %>% unlist() %>% equals("integer") %>% `!`)
		stop(sprintf("ARMET says: the %s column must be integer as the deconvolution model is Negative Binomial", quo_name(.abundance)))
	
	# Covariate column
	# (`&&`: this is a scalar condition inside `if`)
	if(do_regression && paste(as.character(.formula), collapse="")  != "~1"){
		formula_df = parse_formula(.formula)
		
		# Censoring column
		if(do_regression && length(formula_df$censored_column) == 1) {
			
			# One censoring flag per sample, ordered by sample
			cens = .data %>% select(sample, formula_df$censored_column) %>% distinct %>% arrange(sample) %>% pull(2)
			
			# Check cens right type
			if(typeof(cens) %in% c("integer", "logical") %>% any %>% `!`) stop("ARMET says: censoring variable should be logical or integer (0,1)")
			if(length(prior_survival_time) == 0) stop("ARMET says: you really need to provide third party survival time for your condition/disease")
			
			# Standard deviation of the (untransformed) time column
			sd_survival_months = .data %>%  select(sample, formula_df$censored_value_column) %>% distinct %>% pull(formula_df$censored_value_column) %>% sd
			
			# Shift by one when the minimum is zero, then transform
			prior_survival_time = transform_time_function(prior_survival_time %>% when(min(.)==0 ~ (.) + 1, (.))) 
			
			
			time_column = formula_df$censored_value_column 
			
			# Design matrix, with the time column transformed
			X =
				model.matrix(
					object = 	formula_df$formula_formatted,
					data = 
						.data %>% 
						select(sample, one_of(formula_df$covariates_formatted)) %>% 
						distinct %>% 
						arrange(sample) %>%
						mutate(!!as.symbol(formula_df$censored_value_column ) := transform_time_function(!!as.symbol(formula_df$censored_value_column )) )
				)
			
			# Indices of the design-matrix columns whose name matches the time column
			columns_idx_including_time = 
				which(grepl(time_column, colnames(X))) %>% 
				as.array() %>% 
				
				# Fix if NULL
				when(is.null(.) ~ c(), ~ (.))
			
		}
		else{
			X =
				model.matrix(
					object = 	formula_df$formula_formatted,
					data = 
						.data %>% 
						select(sample, one_of(formula_df$covariates_formatted)) %>% 
						distinct %>% 
						arrange(sample) 
				)
			
			cens = NULL
			columns_idx_including_time = array(0)[0]
			
		} 
		
		
		
	}	else {
		# Intercept-only model
		formula_df = cens  = NULL	
		columns_idx_including_time = array(0)[0]
		X =
			model.matrix(
				object = 	~ 1,
				data = .data %>% select(sample) %>% distinct %>% arrange(sample)
			)
	}
	
	# Do regression
	#if(length(formula_df$covariates_formatted) > 0 & (formula_df$covariates_formatted %>% is.na %>% `!`)) do_regression = T
	
	# distinct_at is not released yet for dplyr, thus we have to use this trick
	# NOTE(review): df_for_edgeR is not used further down in this function; the
	# pipeline is kept because its first two steps stop on invalid input.
	df_for_edgeR <- .data %>%
		
		# Stop if any counts is NA
		error_if_counts_is_na(count) %>%
		
		# Stop if there are duplicated transcripts
		error_if_duplicated_genes(sample,symbol,count) %>%
		
		# Prepare the data frame
		select(symbol,
					 sample,
					 count,
					 one_of(formula_df$covariates_formatted)) %>%
		distinct() %>%
		
		# Check if data rectangular
		ifelse_pipe(
			(.) %>% check_if_data_rectangular(sample,symbol,count, type = "soft") %>% `!` &
				TRUE, #!fill_missing_values,
			~ .x %>% eliminate_sparse_transcripts(symbol)
		) %>%
		
		when(
			do_regression && length(formula_df$censored_column) == 1 ~ 
				mutate(., !!formula_df$censored_value_column := !!as.symbol(formula_df$censored_value_column) / sd_survival_months),
			~ (.)
		)
	
	# Mixture table: one row per sample/symbol plus the covariates
	mix =
		.data %>%
		select(sample, symbol, count, one_of(formula_df$covariates_formatted)) %>%
		distinct() 
	
	
	
	tree = 	data.tree::Clone(ARMET::tree) 
	
	
	# Print overlap descriptive stats
	#get_overlap_descriptive_stats(mix %>% slice(1) %>% gather(symbol, count, -sample), reference)
	
	# Prepare data frames -
	# For Q query first
	# For G house keeing first
	# For GM level 1 first
	
	# Number of mixture samples
	Q = mix %>% distinct(sample) %>% nrow
	
	# Reference annotated with the tree columns, matched on cell type
	reference_filtered =
		reference %>% 
		inner_join(
			tree %>%
				data.tree::ToDataFrameTree("Cell type category", "C", "C1", "C2", "C3", "C4", "isLeaf") %>%
				as_tibble %>%
				rename(cell_type = `Cell type category`) %>% 
				select(-1),
			by = "cell_type"
		)	
	
	tree_properties = get_tree_properties(tree)
	
	# Find normalisation: scale every mixture sample against a single
	# pseudo-sample called "reference" (median over duplicated symbols)
	sample_scaling = 
		reference_filtered %>%
		
		mutate(sample = "reference") %>% 
		tidybulk::aggregate_duplicates(sample, symbol, count, aggregation_function = median) %>%
		bind_rows(mix) %>%
		tidybulk::identify_abundant(sample, symbol, count) %>%
		tidybulk::scale_abundance(sample, symbol, count, reference_sample = "reference", action ="get", .subset_for_scaling = .abundant) %>%
		distinct(sample, multiplier) %>%
		mutate(exposure_rate = -log(multiplier)) %>%
		mutate(exposure_multiplier = exp(exposure_rate)) 
	
	
	
	# Default internals
	list(
		internals = list(
			prop = NULL,
			fit = NULL,
			df = NULL,
			prop_posterior = get_null_prop_posterior(tree_properties$ct_in_nodes),
			alpha = NULL,
			Q = Q,
			reference_filtered = reference_filtered,
			mix = mix,
			X = X,
			cens = cens,
			tree_properties = tree_properties,
			prior_survival_time = prior_survival_time,
			formula_df = formula_df,
			sample_scaling = sample_scaling,
			columns_idx_including_time = columns_idx_including_time,
			approximate_posterior = approximate_posterior,
			transform_time_function = transform_time_function,
			
			shards = shards,
			levels = levels,
			full_bayesian = full_bayesian,
			iterations = iterations,
			sampling_iterations = sampling_iterations	,
			do_regression = do_regression,
			.formula = .formula,
			model = model
		),
		input = input
	)
	
	
}

#' estimate_convoluted_lm_1
#' 
#' @description This function does inference for the first level of the
#'   hierarchy. It returns the level-1 estimates and carries the complete
#'   results (proportions, input, internals) in the `full_results` attribute
#' 
#' @rdname estimate_convoluted_lm
#' @name estimate_convoluted_lm_1
#' 
#' @param armet_obj An ARMET object
#' 
#' @export
estimate_convoluted_lm_1 <- function(armet_obj){
	
	level <- 1 
	setup_internals <- armet_obj$internals
	
	# Fit level 1 with the settings stored at setup time
	internals <-
		run_lv_1(
			setup_internals,
			setup_internals$shards,
			setup_internals$levels,
			setup_internals$full_bayesian,
			setup_internals$approximate_posterior,
			iterations = setup_internals$iterations,
			sampling_iterations = setup_internals$sampling_iterations,
			do_regression = setup_internals$do_regression,
			.formula = setup_internals$.formula,
			model = setup_internals$model
		)
	
	# Sample-level annotation of the input data
	sample_annotation <-
		armet_obj$input$.data %>%
		select(c(sample, (.) %>% get_specific_annotation_columns(sample))) %>%
		distinct()
	
	# Annotation joined with the estimated proportions; when regression is
	# active, nest per cell type and attach the regression output
	proportions <-
		sample_annotation %>%
		left_join(internals$prop) %>%
		ifelse_pipe(
			internals$do_regression, # && paste(as.character(internals$.formula), collapse="")  != "~1" ,
			~ .x %>%
				nest(proportions = -c(`Cell type category`, C, level)) %>%
				left_join(
					internals$alpha %>%	select(`Cell type category`, contains("alpha"), level, draws, rng_prop, rng_mu, .variable, one_of("Rhat")),
					by = c("Cell type category", "level")
				)
		)
	
	# Everything later levels need: proportions, the original input, the fit state
	attrib <- list(
		proportions = proportions,
		input = armet_obj$input,
		internals = internals
	)
	
	level_estimates <- get_estimates(proportions, level, X = internals$X)
	
	add_attr(level_estimates, attrib, "full_results")
}

#' estimate_convoluted_lm_2
#' 
#' @description This function does inference for the second level of the
#'   hierarchy, continuing from the results stored in the `full_results`
#'   attribute of `armet_obj`
#' 
#' 
#' @rdname estimate_convoluted_lm
#' @name estimate_convoluted_lm_2
#' 
#' @param armet_obj An ARMET object
#' 
#' @export
estimate_convoluted_lm_2 <- function(armet_obj){
	
	# State left by the previous level
	previous_results <- attr(armet_obj, "full_results")
	
	# Fit level 2 with the model stored in that state
	updated_results <- ARMET_tc_continue(
		previous_results,
		2,
		model = previous_results$internals$model
	)
	
	# Append the new proportions to the existing table
	#   (get_estimates(level, X = attrib$internals$X) is deliberately not applied)
	combined <- bind_rows(armet_obj, updated_results$proportions)
	
	# Add back update attributes
	add_attr(combined, updated_results, "full_results")
	
}

#' estimate_convoluted_lm_3
#' 
#' @description This function does inference for the third level of the
#'   hierarchy, continuing from the results stored in the `full_results`
#'   attribute of `armet_obj`
#' 
#' 
#' @rdname estimate_convoluted_lm
#' @name estimate_convoluted_lm_3
#' 
#' @param armet_obj An ARMET object
#' 
#' @export
estimate_convoluted_lm_3 <- function(armet_obj){
	
	# State left by the previous level
	previous_results <- attr(armet_obj, "full_results")
	
	# Fit level 3 with the model stored in that state
	updated_results <- ARMET_tc_continue(
		previous_results,
		3,
		model = previous_results$internals$model
	)
	
	# Append the new proportions to the existing table
	#   (get_estimates(level, X = attrib$internals$X) is deliberately not applied)
	combined <- bind_rows(armet_obj, updated_results$proportions)
	
	# Add back update attributes
	add_attr(combined, updated_results, "full_results")
}

#' estimate_convoluted_lm_4
#' 
#' @description This function does inference for the fourth level of the
#'   hierarchy, continuing from the results stored in the `full_results`
#'   attribute of `armet_obj`
#' 
#' 
#' @rdname estimate_convoluted_lm
#' @name estimate_convoluted_lm_4
#' 
#' @param armet_obj An ARMET object
#' 
#' @export
estimate_convoluted_lm_4 <- function(armet_obj){
	
	# State left by the previous level
	previous_results <- attr(armet_obj, "full_results")
	
	# Fit level 4 with the model stored in that state
	updated_results <- ARMET_tc_continue(
		previous_results,
		4,
		model = previous_results$internals$model
	)
	
	# Append the new proportions to the existing table
	#   (get_estimates(level, X = attrib$internals$X) is deliberately not applied)
	combined <- bind_rows(armet_obj, updated_results$proportions)
	
	# Add back update attributes
	add_attr(combined, updated_results, "full_results")
}
# stemangiola/ARMET documentation built on July 9, 2022, 1:25 a.m.