R/models-maxvit.R

Defines functions model_maxvit .rename_maxvit_state_dict create_relative_position_params maxvit_conv_norm_act

Documented in model_maxvit

# R/models-maxvit.R - MaxViT implementation built on torch modules
# Two-convolution MaxViT stem.
# The first 3x3-style conv receives `...` (e.g. kernel_size, stride, padding)
# and is followed by batch norm and GELU. The second is a 3x3, stride-1,
# padding-1 conv with bias and no norm/activation.
# The sub-modules are named "0" and "1" rather than "stem.0"/"stem.1". Once
# this is assigned to `self$stem`, parameter keys come out as `stem.0.*` and
# `stem.1.*` whether or not nn_sequential() keeps argument names. A dotted
# name would otherwise add an extra `stem.` level to every key.
#' @importFrom torch nn_module nn_conv2d nn_gelu nn_batch_norm2d nn_sequential
maxvit_conv_norm_act <- function(in_channels, mid_channels = 64, out_channels = 64, ...) {
  nn_sequential(
    "0" = nn_sequential(
      "0" = nn_conv2d(in_channels, mid_channels, ..., bias = FALSE),
      "1" = nn_batch_norm2d(mid_channels, track_running_stats = TRUE),
      "2" = nn_gelu()
    ),
    "1" = nn_sequential(
      "0" = nn_conv2d(
        mid_channels, out_channels,
        kernel_size = 3, stride = 1, padding = 1, bias = TRUE
      )
    )
  )
}

# Squeeze-and-excitation gate used inside MBConv.
# Steps:
#   1. Average over the spatial dims (3, 4).
#   2. 1x1 conv down to max(1, in_channels / squeeze_factor) channels.
#   3. Apply `activation`.
#   4. 1x1 conv back to in_channels, then sigmoid.
#   5. Rescale the input channel-wise.
# `activation` defaults to SiLU, the activation the reference (torchvision)
# MaxViT MBConv uses for its squeeze-excitation. The previous GELU gate did
# not reproduce the pretrained model's activations.
#' @importFrom torch nn_module nn_conv2d nnf_gelu nnf_silu torch_sigmoid
squeeze_excitation <- nn_module(
  initialize = function(in_channels, squeeze_factor = 16, activation = nnf_silu) {
    squeeze_channels <- max(1L, as.integer(in_channels / squeeze_factor))
    self$fc1 <- nn_conv2d(in_channels, squeeze_channels, 1)
    self$fc2 <- nn_conv2d(squeeze_channels, in_channels, 1)
    self$activation <- activation
  },
  forward = function(x) {
    scale <- x$mean(dim = c(3, 4), keepdim = TRUE)
    scale <- self$activation(self$fc1(scale))
    scale <- torch_sigmoid(self$fc2(scale))
    x * scale
  }
)

# Relative-position parameters for self-attention over a
# `window_size` x `window_size` tile. The tile has P = window_size^2 tokens,
# numbered row-major. The layout is intended to match torchvision's
# `_get_relative_position_index` (verify against the checkpoint buffer if
# changing it). Returned list:
# * relative_position_bias_table: zero-initialised tensor of shape
#   ((2 * window_size - 1)^2, heads), with one learnable bias per head for
#   every (row, col) offset.
# * relative_position_index: (P, P) long tensor. Entry [i, j] is the 0-based
#   table row for the offset between tokens i and j.
# * bias_index: the same index flattened row-major and shifted to 1-based,
#   ready for indexing the bias table from R.
#' @importFrom torch torch_arange torch_zeros torch_long torch_tensor
create_relative_position_params <- function(window_size, heads) {
  ws <- as.integer(window_size)
  span <- 2L * ws - 1L
  token <- seq_len(ws * ws) - 1L
  token_row <- token %/% ws
  token_col <- token %% ws
  # Shift both offsets into [0, span - 1], then combine them into one table row.
  index0 <- (outer(token_row, token_row, "-") + ws - 1L) * span +
    (outer(token_col, token_col, "-") + ws - 1L)

  list(
    relative_position_bias_table = torch_zeros(span * span, heads),
    relative_position_index = torch_tensor(index0, dtype = torch_long()),
    bias_index = as.vector(t(index0)) + 1L
  )
}

# Stop with a readable message when a (h, w) feature map cannot be tiled into
# `size` x `size` partitions.
.maxvit_check_partition <- function(h, w, size) {
  if (h %% size != 0 || w %% size != 0) {
    stop(
      sprintf(
        "MaxViT attention needs height and width divisible by %d, got %d x %d.",
        as.integer(size), as.integer(h), as.integer(w)
      ),
      call. = FALSE
    )
  }
  invisible(NULL)
}

# Multi-head self-attention with a learned relative-position bias, applied
# independently to every group of tokens.
# Arguments:
# * attn: the `attn_layer[["1"]]` dict (to_qkv, merge, bias table).
# * x: a (groups, tokens, dim) tensor.
# * heads: number of attention heads; it must divide dim.
# * bias_index: the 1-based index from create_relative_position_params().
# Keys are scaled by dim^-0.5 (the full feature size), mirroring torchvision's
# `scale_factor = feat_dim**-0.5` that the pretrained weights were fitted with.
.maxvit_relative_attention <- function(attn, x, heads, bias_index) {
  n_groups <- x$size(1)
  n_tokens <- x$size(2)
  feat_dim <- x$size(3)
  head_dim <- feat_dim %/% heads

  # (groups, tokens, heads * head_dim) -> (groups, heads, tokens, head_dim)
  split_heads <- function(t) {
    t$reshape(c(n_groups, n_tokens, heads, head_dim))$permute(c(1, 3, 2, 4))
  }
  qkv <- torch_chunk(attn[["to_qkv"]](x), 3, dim = -1)
  q <- split_heads(qkv[[1]])
  k <- split_heads(qkv[[2]]) * feat_dim^-0.5
  v <- split_heads(qkv[[3]])

  # Gather (tokens * tokens, heads) biases -> (1, heads, tokens, tokens).
  bias <- attn[["relative_position_bias_table"]][bias_index, ]
  bias <- bias$reshape(c(n_tokens, n_tokens, heads))$permute(c(3, 1, 2))$unsqueeze(1)

  weights <- nnf_softmax(torch_matmul(q, k$transpose(-2, -1)) + bias, dim = -1)
  out <- torch_matmul(weights, v)$permute(c(1, 3, 2, 4))
  attn[["merge"]](out$reshape(c(n_groups, n_tokens, feat_dim)))
}

# Pre-norm transformer update shared by the window and grid layers.
# It computes x + attention(norm1(x)), then x + mlp(norm2(x)).
# `layer` is a window_attention / grid_attention instance and `x` is a
# (groups, tokens, dim) tensor.
.maxvit_partition_update <- function(layer, x) {
  attn <- layer$attn_layer
  x <- x + .maxvit_relative_attention(attn[["1"]], attn[["0"]](x), layer$heads, layer$bias_index)
  mlp <- layer$mlp_layer
  hidden <- mlp[["1"]](mlp[["0"]](x))
  x + mlp[["3"]](mlp[["2"]](hidden))
}

# Block (local) attention. The (b, ch, h, w) map is cut into non-overlapping
# contiguous `window_size` x `window_size` windows, and attention runs inside
# each window. h and w must be multiples of window_size.
#' @importFrom torch nn_module nn_module_dict nn_layer_norm nn_linear nn_gelu nn_parameter torch_chunk
#' @importFrom torch torch_matmul nnf_softmax
window_attention <- nn_module(
  initialize = function(dim, heads = 2, window_size = 7) {
    self$dim <- dim
    self$heads <- heads
    self$window_size <- window_size

    # Key layout: attn_layer.0 = LayerNorm, attn_layer.1 = attention weights.
    self$attn_layer <- nn_module_dict(list(
      "0" = nn_layer_norm(dim),
      "1" = nn_module_dict(list(
        "to_qkv" = nn_linear(dim, dim * 3, bias = TRUE),
        "merge" = nn_linear(dim, dim, bias = TRUE)
      ))
    ))

    rel_pos_params <- create_relative_position_params(window_size, heads)
    self$attn_layer[["1"]][["relative_position_bias_table"]] <- nn_parameter(rel_pos_params$relative_position_bias_table)
    self$register_buffer("attn_layer.1.relative_position_index", rel_pos_params$relative_position_index)
    # 1-based, row-major copy of the index that forward() uses to gather biases.
    self$bias_index <- rel_pos_params$bias_index

    self$mlp_layer <- nn_module_dict(list(
      "0" = nn_layer_norm(dim),
      "1" = nn_linear(dim, dim * 4),
      "2" = nn_gelu(),
      "3" = nn_linear(dim * 4, dim)
    ))
  },

  forward = function(x) {
    b <- x$size(1)
    ch <- x$size(2)
    h <- x$size(3)
    w <- x$size(4)
    p <- self$window_size
    .maxvit_check_partition(h, w, p)
    nh <- h %/% p
    nw <- w %/% p

    # Row r = i * p + a: tokens sharing the window index (i, j) form one group.
    tiles <- x$reshape(c(b, ch, nh, p, nw, p))
    tokens <- tiles$permute(c(1, 3, 5, 4, 6, 2))$reshape(c(b * nh * nw, p * p, ch))

    tokens <- .maxvit_partition_update(self, tokens)

    tiles <- tokens$reshape(c(b, nh, nw, p, p, ch))
    tiles$permute(c(1, 6, 2, 4, 3, 5))$reshape(c(b, ch, h, w))
  }
)

# Sparse (global) attention. Each group is a `grid_size` x `grid_size` set of
# tokens spaced h / grid_size rows and w / grid_size columns apart, so every
# group spans the whole feature map. h and w must be multiples of grid_size.
#' @importFrom torch nn_module nn_module_dict nn_layer_norm nn_linear nn_gelu nn_parameter torch_chunk
grid_attention <- nn_module(
  initialize = function(dim, heads = 2, grid_size = 7) {
    self$dim <- dim
    self$heads <- heads
    self$grid_size <- grid_size

    # Key layout: attn_layer.0 = LayerNorm, attn_layer.1 = attention weights.
    self$attn_layer <- nn_module_dict(list(
      "0" = nn_layer_norm(dim),
      "1" = nn_module_dict(list(
        "to_qkv" = nn_linear(dim, dim * 3, bias = TRUE),
        "merge" = nn_linear(dim, dim, bias = TRUE)
      ))
    ))

    rel_pos_params <- create_relative_position_params(grid_size, heads)
    self$attn_layer[["1"]][["relative_position_bias_table"]] <- nn_parameter(rel_pos_params$relative_position_bias_table)
    self$register_buffer("attn_layer.1.relative_position_index", rel_pos_params$relative_position_index)
    # 1-based, row-major copy of the index that forward() uses to gather biases.
    self$bias_index <- rel_pos_params$bias_index

    self$mlp_layer <- nn_module_dict(list(
      "0" = nn_layer_norm(dim),
      "1" = nn_linear(dim, dim * 4),
      "2" = nn_gelu(),
      "3" = nn_linear(dim * 4, dim)
    ))
  },

  forward = function(x) {
    b <- x$size(1)
    ch <- x$size(2)
    h <- x$size(3)
    w <- x$size(4)
    g <- self$grid_size
    .maxvit_check_partition(h, w, g)
    nh <- h %/% g
    nw <- w %/% g

    # Row r = a * nh + i: tokens sharing the offset (i, j) form one grid group.
    tiles <- x$reshape(c(b, ch, g, nh, g, nw))
    tokens <- tiles$permute(c(1, 4, 6, 3, 5, 2))$reshape(c(b * nh * nw, g * g, ch))

    tokens <- .maxvit_partition_update(self, tokens)

    tiles <- tokens$reshape(c(b, nh, nw, g, g, ch))
    tiles$permute(c(1, 6, 4, 2, 5, 3))$reshape(c(b, ch, h, w))
  }
)

# MBConv block. The main branch runs, in order:
#   1. BatchNorm (pre_norm).
#   2. 1x1 expand conv + BN + GELU.
#   3. Depthwise 3x3 conv with `stride` + BN + GELU.
#   4. Squeeze-excitation.
#   5. 1x1 conv to `out_channels`.
# The result is added to a shortcut:
# * Shape unchanged (stride 1, same channels): the shortcut is the input.
# * Otherwise: `proj`, i.e. a 3x3 / padding-1 average pool with the block's
#   stride (identity when stride is 1), then a 1x1 conv. Pooling before the
#   conv follows the reference MBConv. Because the pool uses the same
#   kernel/stride/padding as conv_b, both branches keep the same spatial size
#   for even and odd inputs. The old conv-then-2x2-pool shortcut broke on odd
#   sizes (e.g. 25 -> 12 vs 13).
# Parameter keys (`proj.1.*`) are unchanged.
#' @importFrom torch nn_module nn_module_dict nn_batch_norm2d nn_conv2d nn_identity nn_sequential
#' @importFrom torch nn_gelu nn_avg_pool2d
mbconv <- nn_module(
  initialize = function(in_channels, out_channels, expansion = 4, stride = 1) {
    hidden_dim <- in_channels * expansion

    self$layers <- nn_module_dict(list(
      pre_norm = nn_batch_norm2d(in_channels, track_running_stats = TRUE),
      conv_a = nn_sequential(
        "0" = nn_conv2d(in_channels, hidden_dim, 1, bias = FALSE),
        "1" = nn_batch_norm2d(hidden_dim, track_running_stats = TRUE),
        "2" = nn_gelu()
      ),
      conv_b = nn_sequential(
        "0" = nn_conv2d(
          hidden_dim, hidden_dim,
          kernel_size = 3, stride = stride, padding = 1,
          groups = hidden_dim, bias = FALSE
        ),
        "1" = nn_batch_norm2d(hidden_dim, track_running_stats = TRUE),
        "2" = nn_gelu()
      ),
      squeeze_excitation = squeeze_excitation(hidden_dim),
      conv_c = nn_conv2d(hidden_dim, out_channels, 1, bias = TRUE)
    ))

    self$has_proj <- stride != 1 || in_channels != out_channels
    self$use_res_connect <- stride == 1 && in_channels == out_channels

    if (self$has_proj) {
      downsample <- if (stride != 1) {
        nn_avg_pool2d(kernel_size = 3, stride = stride, padding = 1)
      } else {
        nn_identity()
      }
      self$proj <- nn_sequential(
        "0" = downsample,
        "1" = nn_conv2d(in_channels, out_channels, 1, bias = TRUE)
      )
    } else {
      self$proj <- nn_sequential(
        "0" = nn_identity()
      )
    }
  },

  forward = function(x) {
    out <- self$layers$pre_norm(x)
    out <- self$layers$conv_a(out)
    out <- self$layers$conv_b(out)
    out <- self$layers$squeeze_excitation(out)
    out <- self$layers$conv_c(out)
    # proj is a pure identity when the shape is unchanged.
    out + self$proj(x)
  }
)

# One MaxViT layer. It runs an MBConv block (which may change channels and
# stride), then the window-attention and grid-attention modules, both at
# `out_channels` with `num_heads` heads. The three sub-modules live in a
# module dict named `layers` under MBconv / window_attention / grid_attention.
#' @importFrom torch nn_module nn_module_dict nn_sequential
maxvit_block <- nn_module(
  initialize = function(in_channels, out_channels, expansion = 4, stride = 1, num_heads = 2) {
    conv_stage <- mbconv(in_channels, out_channels, expansion, stride)
    window_stage <- window_attention(out_channels, heads = num_heads)
    grid_stage <- grid_attention(out_channels, heads = num_heads)
    self$layers <- nn_module_dict(list(
      MBconv = conv_stage,
      window_attention = window_stage,
      grid_attention = grid_stage
    ))
  },

  forward = function(x) {
    for (stage_name in c("MBconv", "window_attention", "grid_attention")) {
      x <- self$layers[[stage_name]](x)
    }
    x
  }
)

# One stage of the network. All arguments (in this file, named
# maxvit_block()s) are handed unchanged to nn_sequential(), stored as
# `layers`, and the input is threaded through them in order.
#' @importFrom torch nn_module nn_module_dict nn_sequential
maxvit_stage <- nn_module(
  initialize = function(...) {
    stage_blocks <- nn_sequential(...)
    self$layers <- stage_blocks
  },
  forward = function(x) {
    y <- self$layers(x)
    y
  }
)

# Full MaxViT classifier.
# * stem: 3 -> 64 channels, stride 2.
# * blocks: four stages of maxvit_block()s with channels 64/128/256/512,
#   depths 2/2/5/2 and heads 2/4/8/16. The first block of each stage uses
#   stride 2.
# * head: global average pool and flatten, then
#   LayerNorm -> Linear -> Tanh -> Linear(bias = FALSE).
# Classifier slots "0", "1" and "4" hold parameter-free layers. This keeps the
# parameterised slots at the classifier.2 / .3 / .5 names used by the
# pretrained weights.
#' @importFrom torch nn_module nn_module_list nn_adaptive_avg_pool2d nn_identity nn_layer_norm nn_linear
#' @importFrom torch nn_tanh
maxvit_impl <- nn_module(
  initialize = function(num_classes = 1000) {
    self$stem <- maxvit_conv_norm_act(
      3, mid_channels = 64, out_channels = 64,
      kernel_size = 3, stride = 2, padding = 1
    )

    self$blocks <- nn_module_list(list(
      # stage 0 - 2 heads
      maxvit_stage(
        "0" = maxvit_block(64, 64, expansion = 4, stride = 2, num_heads = 2),
        "1" = maxvit_block(64, 64, expansion = 4, stride = 1, num_heads = 2)
      ),
      # stage 1 - 4 heads
      maxvit_stage(
        "0" = maxvit_block(64, 128, expansion = 8, stride = 2, num_heads = 4),
        "1" = maxvit_block(128, 128, expansion = 4, stride = 1, num_heads = 4)
      ),
      # stage 2 - 8 heads
      maxvit_stage(
        "0" = maxvit_block(128, 256, expansion = 8, stride = 2, num_heads = 8),
        "1" = maxvit_block(256, 256, expansion = 4, stride = 1, num_heads = 8),
        "2" = maxvit_block(256, 256, expansion = 4, stride = 1, num_heads = 8),
        "3" = maxvit_block(256, 256, expansion = 4, stride = 1, num_heads = 8),
        "4" = maxvit_block(256, 256, expansion = 4, stride = 1, num_heads = 8)
      ),
      # stage 3 - 16 heads
      maxvit_stage(
        "0" = maxvit_block(256, 512, expansion = 8, stride = 2, num_heads = 16),
        "1" = maxvit_block(512, 512, expansion = 4, stride = 1, num_heads = 16)
      )
    ))

    self$pool <- nn_adaptive_avg_pool2d(c(1, 1))

    # Pooling and flattening happen in forward(), so slots 0 and 1 are
    # identities. Slot 4 is the Tanh between the two linear layers of the
    # reference head; it was previously an identity, which changed the logits.
    self$classifier <- nn_sequential(
      "0" = nn_identity(),
      "1" = nn_identity(),
      "2" = nn_layer_norm(512),
      "3" = nn_linear(512, 512, bias = TRUE),
      "4" = nn_tanh(),
      "5" = nn_linear(512, num_classes, bias = FALSE)
    )
  },

  forward = function(x) {
    x <- self$stem(x)

    for (i in seq_along(self$blocks)) {
      stage <- self$blocks[[i]]
      x <- stage(x)
    }

    x <- self$pool(x)
    x <- x$flatten(start_dim = 2)
    self$classifier(x)
  }
)

# Prepare a pretrained MaxViT state dict for maxvit_impl().
# Key names already match the module layout, so nothing is renamed. Every
# `*.proj.0.weight` / `*.proj.0.bias` entry is dropped, because slot 0 of the
# MBConv shortcut holds a parameter-free layer in this implementation. All
# other entries are returned unchanged, in their original order. The previous
# version ran six sub() calls that replaced each pattern with itself; they
# had no effect and are gone.
.rename_maxvit_state_dict <- function(state_dict) {
  keep <- !grepl("\\.proj\\.0\\.(weight|bias)$", names(state_dict))
  state_dict[keep]
}

#' MaxViT Model
#'
#' Implementation of the MaxViT architecture described in
#' [MaxViT: Multi-Axis Vision Transformer](https://arxiv.org/abs/2204.01697).
#' The model performs image classification and by default returns logits for
#' 1000 ImageNet classes.
#'
#' The attention layers split each stage's feature map into 7 x 7 partitions,
#' so the feature-map height and width must be divisible by 7 at every stage.
#' With the stride-2 stem and one stride-2 block per stage this holds for
#' 224 x 224 inputs (and any multiple of 224).
#'
#' @inheritParams model_resnet18
#' @param num_classes (integer) Number of output classes.
#' @param ... Currently ignored.
#'
#' @return An `nn_module` that maps a `(batch, 3, H, W)` image tensor to
#'   `(batch, num_classes)` logits.
#'
#' @family classification_model
#'
#' @examples
#' \dontrun{
#' library(magrittr)
#' # 1. Load the basketball image
#' img_url <- "https://upload.wikimedia.org/wikipedia/commons/7/7a/Basketball.png"
#' img <- base_loader(img_url)
#'
#' # 2. Define normalization (ImageNet)
#' norm_mean <- c(0.485, 0.456, 0.406)
#' norm_std <- c(0.229, 0.224, 0.225)
#'
#' # 3. Preprocess: convert to tensor and resize to 224 x 224
#' resized <- img %>%
#'   transform_to_tensor() %>%
#'   transform_resize(c(224, 224))
#'
#' # 4. Display the image before normalization
#' tensor_image_browse(resized)
#'
#' # 5. Normalize and add a batch dimension (1, 3, H, W)
#' batch <- transform_normalize(resized, norm_mean, norm_std)$unsqueeze(1)
#'
#' # 6. Load MaxViT model
#' model <- model_maxvit(pretrained = TRUE)
#' model$eval()
#'
#' # 7. Run inference and turn logits into probabilities
#' output <- model(batch)
#' probs <- output$softmax(dim = 2)
#' topk <- probs$topk(k = 5, dim = 2)
#' indices <- as.integer(topk[[2]][1, ])
#' scores <- as.numeric(topk[[1]][1, ]) * 100
#'
#' # 8. Show Top-5 predictions
#' glue::glue("{seq_along(indices)}. {imagenet_label(indices)} ({round(scores, 2)}%)")
#' }
#'
#' @export
model_maxvit <- function(pretrained = FALSE, progress = TRUE, num_classes = 1000, ...) {
  model <- maxvit_impl(num_classes = num_classes)

  if (pretrained) {
    path <- download_and_cache("https://torch-cdn.mlverse.org/models/vision/v2/models/maxvit.pth")
    state_dict <- .rename_maxvit_state_dict(torch::load_state_dict(path))
    # BatchNorm `num_batches_tracked` counters in the checkpoint are not loaded.
    state_dict <- state_dict[!grepl("num_batches_tracked$", names(state_dict))]
    model$load_state_dict(state_dict, strict = FALSE)
  }

  model
}

Try the torchvision package in your browser

Any scripts or data that you put into this service are public.

torchvision documentation built on Nov. 6, 2025, 9:07 a.m.