Download the data

if (!dir.exists("datasets/images")) {
  if (!dir.exists("datasets")) dir.create("datasets")

  options(timeout = 5000)
  download.file(
    "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
    "datasets/images.tar.gz"
  )
  download.file(
    "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
    "datasets/annotations.tar.gz"
  )

  untar("datasets/images.tar.gz", exdir = "datasets")
  untar("datasets/annotations.tar.gz", exdir = "datasets")
}

Prepare paths of input images and target segmentation masks

library(keras3)
input_dir <- "datasets/images/"
target_dir <- "datasets/annotations/trimaps/"
img_size <- c(160, 160)
num_classes <- 3
batch_size <- 32

input_img_paths <- fs::dir_ls(input_dir, glob = "*.jpg") |> sort()
target_img_paths <- fs::dir_ls(target_dir, glob = "*.png") |> sort()

cat("Number of samples:", length(input_img_paths), "\n")
for (i in 1:10) {
  cat(input_img_paths[i], "|", target_img_paths[i], "\n")
}
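
Since the inputs and targets are matched purely by sort order, it is worth verifying that the two listings line up. This quick check is an addition, not part of the original example:

# Added sanity check: every image should have a trimap with the same base name.
stopifnot(
  length(input_img_paths) == length(target_img_paths),
  identical(
    tools::file_path_sans_ext(basename(input_img_paths)),
    tools::file_path_sans_ext(basename(target_img_paths))
  )
)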

What does one input image and corresponding segmentation mask look like?

# Display input image #10
input_img_paths[10] |>
  jpeg::readJPEG() |>
  as.raster() |>
  plot()

# Display trimap #10. The PNG encodes labels 1, 2, 3 as 1/255, 2/255, 3/255,
# so rescale to 1:3 and let as.raster() map the classes onto a gray ramp.
target_img_paths[10] |>
  png::readPNG() |>
  magrittr::multiply_by(255) |>
  as.raster(max = 3) |>
  plot()

Prepare dataset to load & vectorize batches of data

library(tensorflow, exclude = c("shape", "set_random_seed"))
library(tfdatasets, exclude = "shape")


# Returns a tf_dataset
get_dataset <- function(batch_size, img_size, input_img_paths, target_img_paths,
                        max_dataset_len = NULL) {

  img_size <- as.integer(img_size)

  load_img_masks <- function(input_img_path, target_img_path) {
    input_img <- input_img_path |>
      tf$io$read_file() |>
      tf$io$decode_jpeg(channels = 3) |>
      tf$image$resize(img_size) |>
      tf$image$convert_image_dtype("float32")

    target_img <- target_img_path |>
      tf$io$read_file() |>
      tf$io$decode_png(channels = 1) |>
      tf$image$resize(img_size, method = "nearest") |>
      tf$image$convert_image_dtype("uint8")

    # Ground truth labels are 1, 2, 3. Subtract one to make them 0, 1, 2:
    target_img <- target_img - 1L

    list(input_img, target_img)
  }

  if (!is.null(max_dataset_len)) {
    input_img_paths <- input_img_paths[1:max_dataset_len]
    target_img_paths <- target_img_paths[1:max_dataset_len]
  }

  list(input_img_paths, target_img_paths) |>
    tensor_slices_dataset() |>
    dataset_map(load_img_masks, num_parallel_calls = tf$data$AUTOTUNE) |>
    dataset_batch(batch_size)
}
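
As a quick sanity check (an addition, not in the original example), you can materialize one batch from the pipeline and confirm that the tensor shapes match `img_size`:

# Added check: pull a single batch and inspect its shapes.
peek <- get_dataset(batch_size, img_size, input_img_paths, target_img_paths,
                    max_dataset_len = batch_size) |>
  reticulate::as_iterator() |>
  reticulate::iter_next()

dim(peek[[1]])  # images: (32, 160, 160, 3)
dim(peek[[2]])  # masks:  (32, 160, 160, 1), integer labels 0, 1, 2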

Prepare U-Net Xception-style model

get_model <- function(img_size, num_classes) {

  inputs <- keras_input(shape = c(img_size, 3))

  ### [First half of the network: downsampling inputs] ###

  # Entry block
  x <- inputs |>
    layer_conv_2d(filters = 32, kernel_size = 3, strides = 2, padding = "same") |>
    layer_batch_normalization() |>
    layer_activation("relu")

  previous_block_activation <- x  # Set aside residual

  for (filters in c(64, 128, 256)) {
    x <- x |>
      layer_activation("relu") |>
      layer_separable_conv_2d(filters = filters, kernel_size = 3, padding = "same") |>
      layer_batch_normalization() |>

      layer_activation("relu") |>
      layer_separable_conv_2d(filters = filters, kernel_size = 3, padding = "same") |>
      layer_batch_normalization() |>

      layer_max_pooling_2d(pool_size = 3, strides = 2, padding = "same")

    residual <- previous_block_activation |>
      layer_conv_2d(filters = filters, kernel_size = 1, strides = 2, padding = "same")

    x <- layer_add(x, residual)  # Add back residual
    previous_block_activation <- x  # Set aside next residual
  }

  ### [Second half of the network: upsampling inputs] ###

  for (filters in c(256, 128, 64, 32)) {
    x <- x |>
      layer_activation("relu") |>
      layer_conv_2d_transpose(filters = filters, kernel_size = 3, padding = "same") |>
      layer_batch_normalization() |>

      layer_activation("relu") |>
      layer_conv_2d_transpose(filters = filters, kernel_size = 3, padding = "same") |>
      layer_batch_normalization() |>

      layer_upsampling_2d(size = 2)

    # Project residual
    residual <- previous_block_activation |>
      layer_upsampling_2d(size = 2) |>
      layer_conv_2d(filters = filters, kernel_size = 1, padding = "same")

    x <- layer_add(x, residual)     # Add back residual
    previous_block_activation <- x  # Set aside next residual
  }

  # Add a per-pixel classification layer
  outputs <- x |>
    layer_conv_2d(num_classes, 3, activation = "softmax", padding = "same")

  # Define the model
  keras_model(inputs, outputs)
}

# Build model
model <- get_model(img_size, num_classes)
summary(model)
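
If the Python pydot package and graphviz are available (optional dependencies, not needed for the rest of this example), you can also render the architecture as a diagram:

# Optional: draw the network graph; requires pydot and graphviz.
plot(model, show_shapes = TRUE)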

Set aside a validation split

# Split our img paths into a training and a validation set
set.seed(1337)  # make the random split reproducible
num_val_samples <- 1000
val_indices <- sample.int(length(input_img_paths), num_val_samples)

train_input_img_paths <- input_img_paths[-val_indices]
train_target_img_paths <- target_img_paths[-val_indices]

val_input_img_paths <- input_img_paths[val_indices]
val_target_img_paths <- target_img_paths[val_indices]

# Instantiate dataset for each split
# Limit input files in `max_dataset_len` for faster epoch training time.
# Remove the `max_dataset_len` arg when running with full dataset.
train_dataset <- get_dataset(
  batch_size,
  img_size,
  train_input_img_paths,
  train_target_img_paths,
  max_dataset_len = 1000
)
valid_dataset <- get_dataset(
  batch_size, img_size, val_input_img_paths, val_target_img_paths
)

Train the model

# Configure the model for training.
# We use the "sparse" version of categorical_crossentropy
# because our target data is integers.
model |> compile(
  optimizer = optimizer_adam(1e-4), 
  loss = "sparse_categorical_crossentropy"
)
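
If you also want a per-pixel accuracy readout during training, compile() accepts a metrics argument. This variant is an optional tweak, not part of the original recipe:

# Optional variant: with integer targets and sparse categorical crossentropy,
# Keras resolves "accuracy" to its sparse categorical variant.
model |> compile(
  optimizer = optimizer_adam(1e-4),
  loss = "sparse_categorical_crossentropy",
  metrics = "accuracy"
)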

fs::dir_create("models")  # make sure the checkpoint directory exists

callbacks <- list(
  callback_model_checkpoint(
    "models/oxford_segmentation.keras", save_best_only = TRUE
  )
)

# Train the model, doing validation at the end of each epoch.
epochs <- 50
model |> fit(
  train_dataset,
  epochs = epochs,
  validation_data = valid_dataset,
  callbacks = callbacks,
  verbose = 2
)
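
Optionally, you can capture the value fit() returns to plot the training curves afterwards. This sketch is an addition, meant as a replacement for the call above rather than a second training run:

# Sketch: fit() returns a history object with a plot() method.
history <- model |> fit(
  train_dataset,
  epochs = epochs,
  validation_data = valid_dataset,
  callbacks = callbacks,
  verbose = 2
)
plot(history)  # training and validation loss per epoch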

Visualize predictions

# Reload the best checkpoint saved during training
model <- load_model("models/oxford_segmentation.keras")
# Generate predictions for all images in the validation set
val_dataset <- get_dataset(
  batch_size, img_size, val_input_img_paths, val_target_img_paths
)
val_preds <- predict(model, val_dataset)

display_mask <- function(i) {
  # Quick utility to display a model's prediction as a grayscale raster.
  mask <- val_preds[i, , , ] |>
    apply(c(1, 2), which.max) |>
    array_reshape(dim = c(img_size, 1))
  mask <- abind::abind(mask, mask, mask, along = 3)
  plot(as.raster(mask, max = 3))
}

# Display results for validation image #10
i <- 10

par(mfrow = c(1, 3))
# Display input image
val_input_img_paths[i] |>
  jpeg::readJPEG() |>
  as.raster() |>
  plot()

# Display ground-truth target mask
val_target_img_paths[i] |>
  png::readPNG() |>
  magrittr::multiply_by(255) |>
  as.raster(max = 3) |>
  plot()

# Display mask predicted by our model
display_mask(i)  # Note that the model only sees inputs at 160x160.
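
To go beyond eyeballing, you can compare the predicted classes against the ground-truth trimap for the same validation image. This is a hedged sketch, not part of the original example; it reuses i, img_size, val_preds, and val_target_img_paths from above:

# Sketch: per-pixel accuracy for validation image i.
pred_classes <- apply(val_preds[i, , , ], c(1, 2), which.max) - 1L

true_mask <- val_target_img_paths[i] |>
  tf$io$read_file() |>
  tf$io$decode_png(channels = 1L) |>
  tf$image$resize(as.integer(img_size), method = "nearest") |>
  as.array() |>
  drop()
true_mask <- true_mask - 1L  # labels 1, 2, 3 -> 0, 1, 2

mean(pred_classes == true_mask)  # fraction of pixels classified correctly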

