# R/assertions.R

# Defines functions is_emb_dim is_dropout_rate is_l2_penalty is_activation is_vectorizable is_n_inputs is_mlp_units is_positive_vector

#' Assert that `x` is a numeric vector of strictly positive values.
#'
#' The checks run in order: `x` must be numeric, then it must contain no
#' missing values and every entry must be `> 0`. A zero-length numeric
#' vector passes, because `all(logical(0))` is `TRUE`.
#'
#' @param x Object to validate.
#' @return `TRUE` on success. Otherwise signals an assertthat error whose
#'   message names the caller's expression for `x`.
#' @noRd
is_positive_vector <- function(x) {
	# The caller's expression for `x`, used to label error messages.
	arg_label <- deparse(match.call()$x)

	numeric_msg <- paste0(arg_label, " must be numeric.")
	assertthat::assert_that(is.numeric(x), msg = numeric_msg)

	# The NA check comes first so that `all(x > 0)` never yields NA.
	positive_msg <- paste0(arg_label, " entries must be positive.")
	assertthat::assert_that(
		assertthat::noNA(x),
		all(x > 0),
		msg = positive_msg
	)

	TRUE
}

#' Assert that `x` is a vector of positive, integer-valued numbers.
#'
#' Positivity is checked by `is_positive_vector()`. Integrality is checked by
#' comparing `x` with `as.integer(x)`.
#'
#' NOTE(review): `as.integer()` returns NA (with a warning) for values outside
#' the integer range, including `Inf`. Such inputs fail with assertthat's
#' "missing values" error rather than the message below.
#'
#' @param x Object to validate.
#' @return `TRUE` on success. Otherwise signals an assertthat error.
#' @noRd
is_mlp_units <- function(x) {
	assertthat::assert_that(is_positive_vector(x))

	integer_valued <- all(x == as.integer(x))
	assertthat::assert_that(
		integer_valued,
		msg = paste(deparse(match.call()$x), "values must be integer.")
	)

	TRUE
}

#' Assert that `x` is a length-2 vector of positive, integer-valued numbers.
#'
#' The checks run in order: positivity (via `is_positive_vector()`), then
#' integrality, then length exactly 2.
#'
#' @param x Object to validate.
#' @return `TRUE` on success. Otherwise signals an assertthat error whose
#'   message names the caller's expression for `x`.
#' @noRd
is_n_inputs <- function(x) {
	assertthat::assert_that(is_positive_vector(x))

	integer_valued <- all(x == as.integer(x))
	assertthat::assert_that(
		integer_valued,
		msg = paste(deparse(match.call()$x), "values must be integer.")
	)

	has_two <- length(x) == 2
	assertthat::assert_that(
		has_two,
		msg = paste(deparse(match.call()$x), "must have length 2.")
	)

	TRUE
}


#' Assert that `x` has length 1 or length `len`.
#'
#' A length-1 value is accepted alongside a full-length one. Presumably the
#' single value is recycled by callers; that is not shown here.
#'
#' @param x Object whose length is checked.
#' @param len Accepted length other than 1.
#' @return `TRUE` on success. Otherwise signals an assertthat error whose
#'   message names the caller's expression for `x`.
#' @noRd
is_vectorizable <- function(x, len) {
	allowed_lengths <- c(1, len)
	length_msg <- paste0(
		deparse(match.call()$x), " must have length either 1 or ", len, "."
	)
	assertthat::assert_that(length(x) %in% allowed_lengths, msg = length_msg)

	TRUE
}

#' Assert that `x` holds valid activation specifications.
#'
#' `NULL` is accepted as-is. Otherwise `x` must be a character vector or a
#' list. If `len` is given, `x` must have length 1 or `len`. Each element
#' `x[[i]]` is passed to `tensorflow::tf$keras$activations$get()`, and any
#' error it raises is re-signalled as an assertthat error naming the element.
#'
#' @param x `NULL`, a character vector, or a list.
#' @param len Optional accepted length other than 1.
#' @return `TRUE` on success. Otherwise signals an assertthat error.
#' @noRd
is_activation <- function(x, len = NULL) {
	if (is.null(x))
		return(TRUE)

	# Capture the caller's expression here, in is_activation's own frame.
	# Calling match.call() inside the tryCatch() handler below would match
	# the handler's own call, which has no `x`, and the label would be "NULL".
	x_name <- deparse(match.call()$x)

	assertthat::assert_that(
		is.character(x) || is.list(x),
		msg = paste0(x_name, " must be either a character vector or a list.")
	)
	if (!is.null(len))
		assertthat::assert_that(is_vectorizable(x, len))

	for (i in seq_along(x)) {
		tryCatch(
			tensorflow::tf$keras$activations$get(x[[i]]),
			error = function(cnd) {
				msg <- paste0(x_name, "[[", i,
					      "]] is not a valid activation ",
					      "function.")
				assertthat::assert_that(FALSE, msg = msg)
			}
		)
	}

	return(TRUE)
}

#' Assert that `x` is a valid L2 penalty specification.
#'
#' `NULL` is accepted as-is. Otherwise `x` must pass `is_positive_vector()`
#' (strictly positive, no NA). If `len` is given, `x` must also have length 1
#' or `len`.
#'
#' @param x `NULL` or a numeric vector.
#' @param len Optional accepted length other than 1.
#' @return `TRUE` on success. Otherwise signals an assertthat error.
#' @noRd
is_l2_penalty <- function(x, len = NULL) {
	if (!is.null(x)) {
		assertthat::assert_that(is_positive_vector(x))
		if (!is.null(len)) {
			assertthat::assert_that(is_vectorizable(x, len))
		}
	}

	TRUE
}

#' Assert that `x` is a valid dropout-rate specification.
#'
#' `NULL` is accepted as-is. Otherwise `x` must be numeric, with no missing
#' values, and every entry must lie in the closed interval [0, 1]. If `len`
#' is given, `x` must also have length 1 or `len`.
#'
#' The previous implementation delegated to `is_positive_vector()`, which
#' rejects 0. That contradicted the `0 <= x` range check below and its
#' "between 0 and 1" message, so a rate of 0 can no longer be rejected.
#'
#' @param x `NULL` or a numeric vector.
#' @param len Optional accepted length other than 1.
#' @return `TRUE` on success. Otherwise signals an assertthat error whose
#'   message names the caller's expression for `x`.
#' @noRd
is_dropout_rate <- function(x, len = NULL) {
	if (is.null(x))
		return(TRUE)

	x_name <- deparse(match.call()$x)

	assertthat::assert_that(
		is.numeric(x),
		msg = paste0(x_name, " must be numeric.")
	)
	if (!is.null(len))
		assertthat::assert_that(is_vectorizable(x, len))

	# The NA check comes first so the range comparison never yields NA.
	assertthat::assert_that(
		assertthat::noNA(x),
		all(0 <= x & x <= 1),
		msg = paste(x_name, "must be between 0 and 1.")
	)

	return(TRUE)
}

#' Assert that `x` is a vector of positive, integer-valued numbers.
#'
#' The checks run in order: positivity (via `is_positive_vector()`), then,
#' if `len` is given, length 1 or `len`, then integrality against
#' `as.integer(x)`. Unlike the other optional checks in this file, `NULL` is
#' not accepted.
#'
#' @param x Object to validate.
#' @param len Optional accepted length other than 1.
#' @return `TRUE` on success. Otherwise signals an assertthat error.
#' @noRd
is_emb_dim <- function(x, len = NULL) {
	assertthat::assert_that(is_positive_vector(x))
	if (!is.null(len)) {
		assertthat::assert_that(is_vectorizable(x, len))
	}

	integer_valued <- all(x == as.integer(x))
	assertthat::assert_that(
		integer_valued,
		msg = paste(deparse(match.call()$x), "values must be integer.")
	)

	TRUE
}
# vgherard/neuralcf documentation built on Dec. 23, 2021, 3:08 p.m.