Nothing
# R/model-maxvit.R - torch compatible MaxViT implementation
#' @importFrom torch nn_module nn_conv2d nn_gelu nn_batch_norm2d nn_sequential
maxvit_conv_norm_act <- function(in_channels, mid_channels = 64, out_channels = 64, ...) {
nn_sequential(
"stem.0" = nn_sequential(
"0" = nn_conv2d(in_channels, mid_channels, ..., bias = FALSE),
"1" = nn_batch_norm2d(mid_channels, track_running_stats = TRUE),
"2" = nn_gelu()
),
"stem.1" = nn_sequential(
"0" = nn_conv2d(mid_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = TRUE)
)
)
}
#' @importFrom torch nn_module nn_conv2d nnf_gelu torch_sigmoid
squeeze_excitation <- nn_module(
initialize = function(in_channels, squeeze_factor = 16) {
squeeze_channels <- max(1L, as.integer(in_channels / squeeze_factor))
self$fc1 <- nn_conv2d(in_channels, squeeze_channels, 1)
self$fc2 <- nn_conv2d(squeeze_channels, in_channels, 1)
},
forward = function(x) {
scale <- x$mean(dim = c(3, 4), keepdim = TRUE)
scale <- self$fc1(scale)
scale <- nnf_gelu(scale)
scale <- self$fc2(scale)
scale <- torch_sigmoid(scale)
x * scale
}
)
# Simplified relative position implementation
#' @importFrom torch torch_arange torch_zeros torch_long
create_relative_position_params <- function(window_size, heads) {
# Create a simple relative position index
coords <- torch_arange(window_size * window_size)
relative_position_index <- torch_zeros(window_size * window_size, window_size * window_size, dtype = torch_long())
# Create bias table
num_relative_distance <- (2 * window_size - 1) * (2 * window_size - 1)
relative_position_bias_table <- torch_zeros(num_relative_distance, heads)
list(
relative_position_bias_table = relative_position_bias_table,
relative_position_index = relative_position_index
)
}
#' @importFrom torch nn_module nn_module_dict nn_layer_norm nn_linear nn_gelu nn_parameter torch_chunk
#' @importFrom torch torch_matmul nnf_softmax
window_attention <- nn_module(
initialize = function(dim, heads = 2, window_size = 7) {
self$dim <- dim
self$heads <- heads
self$window_size <- window_size
# Layer structure to match expected keys
self$attn_layer <- nn_module_dict(list(
"0" = nn_layer_norm(dim), # norm1
"1" = nn_module_dict(list(
# Custom attention implementation
"to_qkv" = nn_linear(dim, dim * 3, bias = TRUE),
"merge" = nn_linear(dim, dim, bias = TRUE)
))
))
# Add relative position parameters as buffers/parameters
rel_pos_params <- create_relative_position_params(window_size, heads)
self$attn_layer[["1"]][["relative_position_bias_table"]] <- nn_parameter(rel_pos_params$relative_position_bias_table)
self$register_buffer("attn_layer.1.relative_position_index", rel_pos_params$relative_position_index)
self$mlp_layer <- nn_module_dict(list(
"0" = nn_layer_norm(dim), # norm2
"1" = nn_linear(dim, dim * 4),
"2" = nn_gelu(),
"3" = nn_linear(dim * 4, dim)
))
},
forward = function(x) {
b <- x$size(1); c <- x$size(2); h <- x$size(3); w <- x$size(4)
x <- x$permute(c(1, 3, 4, 2))$reshape(c(b, h * w, c))
# Attention path
normed <- self$attn_layer[["0"]](x)
qkv <- self$attn_layer[["1"]][["to_qkv"]](normed)
qkv_chunks <- torch_chunk(qkv, 3, dim = -1)
q <- qkv_chunks[[1]]
k <- qkv_chunks[[2]]
v <- qkv_chunks[[3]]
# Simplified attention computation
scale <- 1.0 / sqrt(c / self$heads)
attn <- torch_matmul(q, k$transpose(-2, -1)) * scale
attn <- nnf_softmax(attn, dim = -1)
out <- torch_matmul(attn, v)
out <- self$attn_layer[["1"]][["merge"]](out)
x <- x + out
# MLP path
normed <- self$mlp_layer[["0"]](x)
mlp_out <- self$mlp_layer[["1"]](normed)
mlp_out <- self$mlp_layer[["2"]](mlp_out)
mlp_out <- self$mlp_layer[["3"]](mlp_out)
x <- x + mlp_out
x <- x$reshape(c(b, h, w, c))$permute(c(1, 4, 2, 3))
x
}
)
#' @importFrom torch nn_module nn_module_dict nn_layer_norm nn_linear nn_gelu nn_parameter torch_chunk
grid_attention <- nn_module(
initialize = function(dim, heads = 2, grid_size = 7) {
self$dim <- dim
self$heads <- heads
self$grid_size <- grid_size
# Layer structure to match expected keys
self$attn_layer <- nn_module_dict(list(
"0" = nn_layer_norm(dim), # norm1
"1" = nn_module_dict(list(
# Custom attention implementation
"to_qkv" = nn_linear(dim, dim * 3, bias = TRUE),
"merge" = nn_linear(dim, dim, bias = TRUE)
))
))
# Add relative position parameters as buffers/parameters
rel_pos_params <- create_relative_position_params(grid_size, heads)
self$attn_layer[["1"]][["relative_position_bias_table"]] <- nn_parameter(rel_pos_params$relative_position_bias_table)
self$register_buffer("attn_layer.1.relative_position_index", rel_pos_params$relative_position_index)
self$mlp_layer <- nn_module_dict(list(
"0" = nn_layer_norm(dim), # norm2
"1" = nn_linear(dim, dim * 4),
"2" = nn_gelu(),
"3" = nn_linear(dim * 4, dim)
))
},
forward = function(x) {
b <- x$size(1); c <- x$size(2); h <- x$size(3); w <- x$size(4)
gs <- self$grid_size
gh <- h %/% gs; gw <- w %/% gs
x <- x$reshape(c(b, c, gh, gs, gw, gs))$permute(c(1, 3, 5, 4, 6, 2))$reshape(c(b * gh * gw, gs * gs, c))
# Attention path
normed <- self$attn_layer[["0"]](x)
qkv <- self$attn_layer[["1"]][["to_qkv"]](normed)
qkv_chunks <- torch_chunk(qkv, 3, dim = -1)
q <- qkv_chunks[[1]]
k <- qkv_chunks[[2]]
v <- qkv_chunks[[3]]
# Simplified attention computation
scale <- 1.0 / sqrt(c / self$heads)
attn <- torch_matmul(q, k$transpose(-2, -1)) * scale
attn <- nnf_softmax(attn, dim = -1)
out <- torch_matmul(attn, v)
out <- self$attn_layer[["1"]][["merge"]](out)
x <- x + out
# MLP path
normed <- self$mlp_layer[["0"]](x)
mlp_out <- self$mlp_layer[["1"]](normed)
mlp_out <- self$mlp_layer[["2"]](mlp_out)
mlp_out <- self$mlp_layer[["3"]](mlp_out)
x <- x + mlp_out
x <- x$reshape(c(b, gh, gw, gs, gs, c))$permute(c(1, 6, 2, 4, 3, 5))$reshape(c(b, c, h, w))
x
}
)
#' @importFrom torch nn_module nn_module_dict nn_batch_norm2d nn_conv2d nn_identity nn_sequential
mbconv <- nn_module(
initialize = function(in_channels, out_channels, expansion = 4, stride = 1) {
hidden_dim <- in_channels * expansion
self$layers <- nn_module_dict(list(
pre_norm = nn_batch_norm2d(in_channels, track_running_stats = TRUE),
conv_a = nn_sequential(
"0" = nn_conv2d(in_channels, hidden_dim, 1, bias = FALSE),
"1" = nn_batch_norm2d(hidden_dim, track_running_stats = TRUE),
"2" = nn_gelu()
),
conv_b = nn_sequential(
"0" = nn_conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups = hidden_dim, bias = FALSE),
"1" = nn_batch_norm2d(hidden_dim, track_running_stats = TRUE),
"2" = nn_gelu()
),
squeeze_excitation = squeeze_excitation(hidden_dim),
conv_c = nn_conv2d(hidden_dim, out_channels, 1, bias = TRUE)
))
# Only add projection layer for blocks that change dimensions
self$has_proj <- (stride != 1 || in_channels != out_channels)
if (self$has_proj) {
self$proj <- nn_sequential(
"0" = nn_identity(),
"1" = nn_conv2d(in_channels, out_channels, 1, bias = TRUE)
)
} else {
self$proj <- nn_sequential(
"0" = nn_identity()
)
}
self$use_res_connect <- stride == 1 && in_channels == out_channels
},
forward = function(x) {
identity <- x
out <- self$layers$pre_norm(x)
out <- self$layers$conv_a(out)
out <- self$layers$conv_b(out)
out <- self$layers$squeeze_excitation(out)
out <- self$layers$conv_c(out)
if (self$use_res_connect) {
out <- out + identity
} else if (self$has_proj) {
identity <- self$proj(identity)
if (identity$size(3) != out$size(3) || identity$size(4) != out$size(4)) {
identity <- nnf_avg_pool2d(identity, kernel_size = 2, stride = 2)
}
out <- out + identity
}
out
}
)
#' @importFrom torch nn_module nn_module_dict nn_sequential
maxvit_block <- nn_module(
initialize = function(in_channels, out_channels, expansion = 4, stride = 1, num_heads = 2) {
self$layers <- nn_module_dict(list(
MBconv = mbconv(in_channels, out_channels, expansion, stride),
window_attention = window_attention(out_channels, heads = num_heads),
grid_attention = grid_attention(out_channels, heads = num_heads)
))
},
forward = function(x) {
x <- self$layers$MBconv(x)
x <- self$layers$window_attention(x)
x <- self$layers$grid_attention(x)
x
}
)
#' @importFrom torch nn_module nn_module_dict nn_sequential
maxvit_stage <- nn_module(
initialize = function(...) {
self$layers <- nn_sequential(...)
},
forward = function(x) {
self$layers(x)
}
)
#' @importFrom torch nn_module nn_module_list nn_adaptive_avg_pool2d nn_identity nn_layer_norm nn_linear
maxvit_impl <- nn_module(
initialize = function(num_classes = 1000) {
self$stem <- maxvit_conv_norm_act(
3, mid_channels = 64, out_channels = 64,
kernel_size = 3, stride = 2, padding = 1
)
self$blocks <- nn_module_list(list(
# stage 0 - 2 heads (default)
maxvit_stage(
"0" = maxvit_block(64, 64, expansion = 4, stride = 2, num_heads = 2),
"1" = maxvit_block(64, 64, expansion = 4, stride = 1, num_heads = 2)
),
# stage 1 - 4 heads
maxvit_stage(
"0" = maxvit_block(64, 128, expansion = 8, stride = 2, num_heads = 4),
"1" = maxvit_block(128, 128, expansion = 4, stride = 1, num_heads = 4)
),
# stage 2 - 8 heads
maxvit_stage(
"0" = maxvit_block(128, 256, expansion = 8, stride = 2, num_heads = 8),
"1" = maxvit_block(256, 256, expansion = 4, stride = 1, num_heads = 8),
"2" = maxvit_block(256, 256, expansion = 4, stride = 1, num_heads = 8),
"3" = maxvit_block(256, 256, expansion = 4, stride = 1, num_heads = 8),
"4" = maxvit_block(256, 256, expansion = 4, stride = 1, num_heads = 8)
),
# stage 3 - 16 heads
maxvit_stage(
"0" = maxvit_block(256, 512, expansion = 8, stride = 2, num_heads = 16),
"1" = maxvit_block(512, 512, expansion = 4, stride = 1, num_heads = 16)
)
))
self$pool <- nn_adaptive_avg_pool2d(c(1, 1))
# Fixed classifier structure to match original naming
self$classifier <- nn_sequential(
"0" = nn_identity(),
"1" = nn_identity(),
"2" = nn_layer_norm(512),
"3" = nn_linear(512, 512, bias = TRUE),
"4" = nn_identity(),
"5" = nn_linear(512, num_classes, bias = FALSE) # Additional classifier layer
)
},
forward = function(x) {
x <- self$stem(x)
for (i in seq_along(self$blocks)) {
stage <- self$blocks[[i]]
x <- stage(x)
}
x <- self$pool(x)
x <- x$flatten(start_dim = 2)
self$classifier(x)
}
)
.rename_maxvit_state_dict <- function(state_dict) {
renamed <- list()
for (nm in names(state_dict)) {
new_nm <- nm
# MBConv projection layers - fix the renaming
if (grepl("\\.proj\\.0\\.(weight|bias)$", new_nm)) {
# Skip identity layer renaming
next
}
# Don't rename the relative position bias and index - keep original names
if (grepl("relative_position_bias_table|relative_position_index", new_nm)) {
renamed[[new_nm]] <- state_dict[[nm]]
next
}
# Attention + MLP layer renaming
new_nm <- sub("attn_layer\\.0\\.", "attn_layer.0.", new_nm)
new_nm <- sub("attn_layer\\.1\\.to_qkv\\.", "attn_layer.1.to_qkv.", new_nm)
new_nm <- sub("attn_layer\\.1\\.merge\\.", "attn_layer.1.merge.", new_nm)
new_nm <- sub("mlp_layer\\.0\\.", "mlp_layer.0.", new_nm)
new_nm <- sub("mlp_layer\\.1\\.", "mlp_layer.1.", new_nm)
new_nm <- sub("mlp_layer\\.3\\.", "mlp_layer.3.", new_nm)
# Keep classifier names exactly as they are in the .pth file
# Don't rename classifier.2, classifier.3, classifier.5
renamed[[new_nm]] <- state_dict[[nm]]
}
renamed
}
#' MaxViT Model
#'
#' Implementation of the MaxViT architecture described in
#' [MaxViT: Multi-Axis Vision Transformer](https://arxiv.org/abs/2204.01697).
#' The model performs image classification and by default returns logits for
#' 1000 ImageNet classes.
#'
#' @inheritParams model_resnet18
#' @param num_classes (integer) Number of output classes.
#' @param ... Additional parameters passed to the model initializer.
#'
#' @family classification_model
#'
#' @examples
#' \dontrun{
#' library(magrittr)
#' # 1. Load the basketball image
#' img_url <- "https://upload.wikimedia.org/wikipedia/commons/7/7a/Basketball.png"
#' img <- base_loader(img_url)
#'
#' # 2. Define normalization (ImageNet)
#' norm_mean <- c(0.485, 0.456, 0.406)
#' norm_std <- c(0.229, 0.224, 0.225)
#'
#' # 3. Preprocess: convert to tensor, resize, Normalize
#' input <- img %>%
#' transform_to_tensor() %>%
#' transform_resize(c(400, 400)) %>%
#' transform_normalize(norm_mean, norm_std)
#' batch <- input$unsqueeze(1) # Add batch dimension (1, 3, H, W)
#'
#' # 4. Display the image before normalization
#' tensor_image_browse(input)
#'
#' # 5. Load MaxViT model
#' model <- model_maxvit(pretrained = TRUE)
#' model$eval()
#'
#' # 6. Run inference
#' output <- model(batch)
#' topk <- output$topk(k = 5, dim = 2)
#' indices <- as.integer(topk[[2]][1, ])
#' scores <- as.numeric(topk[[1]][1, ])
#'
#' # 7. Show Top-5 predictions
#' glue::glue("{seq_along(indices)}. {imagenet_label(indices)} ({round(scores, 2)}%)")
#' }
#'
#' @export
model_maxvit <- function(pretrained = FALSE, progress = TRUE, num_classes = 1000, ...) {
model <- maxvit_impl(num_classes = num_classes)
if (pretrained) {
path <- download_and_cache("https://torch-cdn.mlverse.org/models/vision/v2/models/maxvit.pth")
state_dict <- torch::load_state_dict(path)
state_dict <- .rename_maxvit_state_dict(state_dict)
state_dict <- state_dict[!grepl("num_batches_tracked$", names(state_dict))]
model_state <- model$state_dict()
model_state <- model_state[!grepl("num_batches_tracked$", names(model_state))]
model$load_state_dict(state_dict, strict = FALSE)
}
model
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.