# Download and unpack the Oxford-IIIT Pet images and annotations into
# "datasets/".

# Allow each download up to 5000 seconds before R gives up on it.
options(timeout = 5000)

# download.file() does not create missing directories, so make sure the
# destination folder exists before fetching anything.
dir.create("datasets", showWarnings = FALSE, recursive = TRUE)

# The archives are binary files: request binary mode explicitly so the
# transfer never depends on platform- or extension-based defaults.
download.file(
  "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
  "datasets/images.tar.gz",
  mode = "wb"
)
download.file(
  "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
  "datasets/annotations.tar.gz",
  mode = "wb"
)

# Extract both archives next to each other under "datasets/".
untar("datasets/images.tar.gz", exdir = "datasets")
untar("datasets/annotations.tar.gz", exdir = "datasets")
library(keras3)

# Where the photos and their trimap segmentation masks live on disk.
input_dir <- "datasets/images/"
target_dir <- "datasets/annotations/trimaps/"

# Settings shared by the data pipeline and the model.
img_size <- c(160, 160)
num_classes <- 3
batch_size <- 32

# List both folders in sorted order; the rest of the script pairs the i-th
# image with the i-th mask by position.
input_img_paths <- sort(fs::dir_ls(input_dir, glob = "*.jpg"))
target_img_paths <- sort(fs::dir_ls(target_dir, glob = "*.png"))

cat("Number of samples:", length(input_img_paths), "\n")
## Number of samples: 7390
# Print the first ten image/mask path pairs, one pair per line, to check
# that the two sorted listings line up.
cat(
  paste(input_img_paths[1:10], "|", target_img_paths[1:10], "\n"),
  sep = ""
)
## datasets/images/Abyssinian_1.jpg | datasets/annotations/trimaps/Abyssinian_1.png
## datasets/images/Abyssinian_10.jpg | datasets/annotations/trimaps/Abyssinian_10.png
## datasets/images/Abyssinian_100.jpg | datasets/annotations/trimaps/Abyssinian_100.png
## datasets/images/Abyssinian_101.jpg | datasets/annotations/trimaps/Abyssinian_101.png
## datasets/images/Abyssinian_102.jpg | datasets/annotations/trimaps/Abyssinian_102.png
## datasets/images/Abyssinian_103.jpg | datasets/annotations/trimaps/Abyssinian_103.png
## datasets/images/Abyssinian_104.jpg | datasets/annotations/trimaps/Abyssinian_104.png
## datasets/images/Abyssinian_105.jpg | datasets/annotations/trimaps/Abyssinian_105.png
## datasets/images/Abyssinian_106.jpg | datasets/annotations/trimaps/Abyssinian_106.png
## datasets/images/Abyssinian_107.jpg | datasets/annotations/trimaps/Abyssinian_107.png
# Display input image #10: decode the JPEG, turn the pixel array into a
# raster and draw it.
plot(
  as.raster(
    jpeg::readJPEG(input_img_paths[10])
  )
)
# Display the trimap mask that belongs to image #10.
# The decoded PNG values are multiplied by 255 before plotting; with
# `max = 3` the raster then treats 3 as the brightest value.
# NOTE(review): presumably this restores the raw label values 1, 2, 3.
plot(
  as.raster(
    magrittr::multiply_by(png::readPNG(target_img_paths[10]), 255),
    max = 3
  )
)
library(tensorflow, exclude = c("shape", "set_random_seed"))
library(tfdatasets, exclude = "shape")

# Build a batched tf_dataset of (image, mask) pairs read from disk.
#
# Arguments:
#   batch_size        Number of samples per batch.
#   img_size          Length-2 size every image and mask is resized to
#                     (coerced to integer).
#   input_img_paths   Paths of the JPEG input images.
#   target_img_paths  Paths of the PNG masks, in the same order as the inputs.
#   max_dataset_len   Optional cap on the number of pairs used; NULL keeps
#                     all of them.
#
# Returns:
#   A tf_dataset whose elements are list(input_batch, target_batch). Target
#   labels are shifted from 1, 2, 3 to 0, 1, 2.
get_dataset <- function(batch_size,
                        img_size,
                        input_img_paths,
                        target_img_paths,
                        max_dataset_len = NULL) {
  # Images and masks are paired by position, so both vectors must line up.
  stopifnot(length(input_img_paths) == length(target_img_paths))

  img_size <- as.integer(img_size)

  # Load one (image, mask) pair from its two file paths.
  load_img_masks <- function(input_img_path, target_img_path) {
    input_img <- input_img_path |>
      tf$io$read_file() |>
      tf$io$decode_jpeg(channels = 3) |>
      tf$image$resize(img_size) |>
      tf$image$convert_image_dtype("float32")

    # Nearest-neighbour resizing keeps the mask values discrete class labels
    # instead of blending neighbouring labels together.
    target_img <- target_img_path |>
      tf$io$read_file() |>
      tf$io$decode_png(channels = 1) |>
      tf$image$resize(img_size, method = "nearest") |>
      tf$image$convert_image_dtype("uint8")

    # Ground truth labels are 1, 2, 3. Subtract one to make them 0, 1, 2:
    target_img <- target_img - 1L

    list(input_img, target_img)
  }

  if (!is.null(max_dataset_len)) {
    # head() clamps to the available length. `x[1:n]` would instead pad the
    # result with NA paths when n exceeds length(x), and would index
    # backwards (1:0) when n is 0.
    input_img_paths <- head(input_img_paths, max_dataset_len)
    target_img_paths <- head(target_img_paths, max_dataset_len)
  }

  list(input_img_paths, target_img_paths) |>
    tensor_slices_dataset() |>
    dataset_map(load_img_masks, num_parallel_calls = tf$data$AUTOTUNE) |>
    dataset_batch(batch_size)
}
# Build the segmentation network: a strided entry convolution, three
# downsampling residual blocks, four upsampling residual blocks and a
# per-pixel softmax classifier.
#
# Arguments:
#   img_size     Length-2 spatial size of the input images.
#   num_classes  Number of output channels of the final softmax layer.
#
# Returns the assembled Keras model.
get_model <- function(img_size, num_classes) {
  # One downsampling block: two separable convolutions followed by a
  # stride-2 max pooling, added to a stride-2 1x1 projection of the input.
  downsample_block <- function(block_input, filters) {
    main_path <- block_input |>
      layer_activation("relu") |>
      layer_separable_conv_2d(
        filters = filters, kernel_size = 3, padding = "same"
      ) |>
      layer_batch_normalization() |>
      layer_activation("relu") |>
      layer_separable_conv_2d(
        filters = filters, kernel_size = 3, padding = "same"
      ) |>
      layer_batch_normalization() |>
      layer_max_pooling_2d(pool_size = 3, strides = 2, padding = "same")

    shortcut <- block_input |>
      layer_conv_2d(
        filters = filters, kernel_size = 1, strides = 2, padding = "same"
      )

    layer_add(main_path, shortcut)
  }

  # One upsampling block: two transposed convolutions followed by a 2x
  # upsampling, added to an upsampled 1x1 projection of the input.
  upsample_block <- function(block_input, filters) {
    main_path <- block_input |>
      layer_activation("relu") |>
      layer_conv_2d_transpose(
        filters = filters, kernel_size = 3, padding = "same"
      ) |>
      layer_batch_normalization() |>
      layer_activation("relu") |>
      layer_conv_2d_transpose(
        filters = filters, kernel_size = 3, padding = "same"
      ) |>
      layer_batch_normalization() |>
      layer_upsampling_2d(size = 2)

    shortcut <- block_input |>
      layer_upsampling_2d(size = 2) |>
      layer_conv_2d(filters = filters, kernel_size = 1, padding = "same")

    layer_add(main_path, shortcut)
  }

  inputs <- keras_input(shape = c(img_size, 3))

  # Entry block
  features <- inputs |>
    layer_conv_2d(filters = 32, kernel_size = 3, strides = 2, padding = "same") |>
    layer_batch_normalization() |>
    layer_activation("relu")

  # First half of the network: downsampling inputs
  for (filters in c(64, 128, 256)) {
    features <- downsample_block(features, filters)
  }

  # Second half of the network: upsampling inputs
  for (filters in c(256, 128, 64, 32)) {
    features <- upsample_block(features, filters)
  }

  # Per-pixel classification layer
  outputs <- features |>
    layer_conv_2d(num_classes, 3, activation = "softmax", padding = "same")

  keras_model(inputs, outputs)
}

# Build model
model <- get_model(img_size, num_classes)
summary(model)
## [1mModel: "functional_1"[0m ## ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━┓ ## ┃[1m [0m[1mLayer (type) [0m[1m [0m┃[1m [0m[1mOutput Shape [0m[1m [0m┃[1m [0m[1m Param #[0m[1m [0m┃[1m [0m[1mConnected to [0m[1m [0m┃[1m [0m[1mTrai…[0m[1m [0m┃ ## ┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━┩ ## │ input_layer │ ([38;5;45mNone[0m, [38;5;34m160[0m, │ [38;5;34m0[0m │ - │ [1m-[0m │ ## │ ([38;5;33mInputLayer[0m) │ [38;5;34m160[0m, [38;5;34m3[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d ([38;5;33mConv2D[0m) │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m896[0m │ input_layer[[38;5;34m0[0m… │ [1;38;5;34mY[0m │ ## │ │ [38;5;34m32[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m128[0m │ conv2d[[38;5;34m0[0m][[38;5;34m0[0m] │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m32[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m32[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_1 │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m0[0m │ activation[[38;5;34m0[0m]… │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m32[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ separable_conv2d │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m2,400[0m │ activation_1[[38;5;34m…[0m │ [1;38;5;34mY[0m │ ## │ ([38;5;33mSeparableConv2D[0m) │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, 
[38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m256[0m │ separable_con… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_2 │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ separable_conv2d… │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m4,736[0m │ activation_2[[38;5;34m…[0m │ [1;38;5;34mY[0m │ ## │ ([38;5;33mSeparableConv2D[0m) │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m256[0m │ separable_con… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ max_pooling2d │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mMaxPooling2D[0m) │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_1 ([38;5;33mConv2D[0m) │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m2,112[0m │ activation[[38;5;34m0[0m]… │ [1;38;5;34mY[0m │ ## │ │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ add ([38;5;33mAdd[0m) │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m0[0m │ max_pooling2d… │ [1m-[0m │ ## │ │ [38;5;34m64[0m) │ │ conv2d_1[[38;5;34m0[0m][[38;5;34m0[0m] │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_3 │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m0[0m │ add[[38;5;34m0[0m][[38;5;34m0[0m] │ [1m-[0m │ ## │ 
([38;5;33mActivation[0m) │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ separable_conv2d… │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m8,896[0m │ activation_3[[38;5;34m…[0m │ [1;38;5;34mY[0m │ ## │ ([38;5;33mSeparableConv2D[0m) │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m512[0m │ separable_con… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_4 │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ separable_conv2d… │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m17,664[0m │ activation_4[[38;5;34m…[0m │ [1;38;5;34mY[0m │ ## │ ([38;5;33mSeparableConv2D[0m) │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m512[0m │ separable_con… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ max_pooling2d_1 │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mMaxPooling2D[0m) │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_2 ([38;5;33mConv2D[0m) │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m8,320[0m │ add[[38;5;34m0[0m][[38;5;34m0[0m] │ [1;38;5;34mY[0m │ ## │ │ [38;5;34m128[0m) │ │ │ │ ## 
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ add_1 ([38;5;33mAdd[0m) │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m0[0m │ max_pooling2d… │ [1m-[0m │ ## │ │ [38;5;34m128[0m) │ │ conv2d_2[[38;5;34m0[0m][[38;5;34m0[0m] │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_5 │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m0[0m │ add_1[[38;5;34m0[0m][[38;5;34m0[0m] │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ separable_conv2d… │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m34,176[0m │ activation_5[[38;5;34m…[0m │ [1;38;5;34mY[0m │ ## │ ([38;5;33mSeparableConv2D[0m) │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m1,024[0m │ separable_con… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_6 │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ separable_conv2d… │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m68,096[0m │ activation_6[[38;5;34m…[0m │ [1;38;5;34mY[0m │ ## │ ([38;5;33mSeparableConv2D[0m) │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m1,024[0m │ separable_con… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m256[0m) │ │ │ │ ## 
├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ max_pooling2d_2 │ ([38;5;45mNone[0m, [38;5;34m10[0m, [38;5;34m10[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mMaxPooling2D[0m) │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_3 ([38;5;33mConv2D[0m) │ ([38;5;45mNone[0m, [38;5;34m10[0m, [38;5;34m10[0m, │ [38;5;34m33,024[0m │ add_1[[38;5;34m0[0m][[38;5;34m0[0m] │ [1;38;5;34mY[0m │ ## │ │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ add_2 ([38;5;33mAdd[0m) │ ([38;5;45mNone[0m, [38;5;34m10[0m, [38;5;34m10[0m, │ [38;5;34m0[0m │ max_pooling2d… │ [1m-[0m │ ## │ │ [38;5;34m256[0m) │ │ conv2d_3[[38;5;34m0[0m][[38;5;34m0[0m] │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_7 │ ([38;5;45mNone[0m, [38;5;34m10[0m, [38;5;34m10[0m, │ [38;5;34m0[0m │ add_2[[38;5;34m0[0m][[38;5;34m0[0m] │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_transpose │ ([38;5;45mNone[0m, [38;5;34m10[0m, [38;5;34m10[0m, │ [38;5;34m590,080[0m │ activation_7[[38;5;34m…[0m │ [1;38;5;34mY[0m │ ## │ ([38;5;33mConv2DTranspose[0m) │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m10[0m, [38;5;34m10[0m, │ [38;5;34m1,024[0m │ conv2d_transp… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_8 │ ([38;5;45mNone[0m, [38;5;34m10[0m, [38;5;34m10[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ 
conv2d_transpose… │ ([38;5;45mNone[0m, [38;5;34m10[0m, [38;5;34m10[0m, │ [38;5;34m590,080[0m │ activation_8[[38;5;34m…[0m │ [1;38;5;34mY[0m │ ## │ ([38;5;33mConv2DTranspose[0m) │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m10[0m, [38;5;34m10[0m, │ [38;5;34m1,024[0m │ conv2d_transp… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ up_sampling2d_1 │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m0[0m │ add_2[[38;5;34m0[0m][[38;5;34m0[0m] │ [1m-[0m │ ## │ ([38;5;33mUpSampling2D[0m) │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ up_sampling2d │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mUpSampling2D[0m) │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_4 ([38;5;33mConv2D[0m) │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m65,792[0m │ up_sampling2d… │ [1;38;5;34mY[0m │ ## │ │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ add_3 ([38;5;33mAdd[0m) │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m0[0m │ up_sampling2d… │ [1m-[0m │ ## │ │ [38;5;34m256[0m) │ │ conv2d_4[[38;5;34m0[0m][[38;5;34m0[0m] │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_9 │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m0[0m │ add_3[[38;5;34m0[0m][[38;5;34m0[0m] │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_transpose… │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ 
[38;5;34m295,040[0m │ activation_9[[38;5;34m…[0m │ [1;38;5;34mY[0m │ ## │ ([38;5;33mConv2DTranspose[0m) │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m512[0m │ conv2d_transp… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_10 │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_transpose… │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m147,584[0m │ activation_10… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mConv2DTranspose[0m) │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m20[0m, [38;5;34m20[0m, │ [38;5;34m512[0m │ conv2d_transp… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ up_sampling2d_3 │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m0[0m │ add_3[[38;5;34m0[0m][[38;5;34m0[0m] │ [1m-[0m │ ## │ ([38;5;33mUpSampling2D[0m) │ [38;5;34m256[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ up_sampling2d_2 │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mUpSampling2D[0m) │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_5 ([38;5;33mConv2D[0m) │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m32,896[0m │ up_sampling2d… │ [1;38;5;34mY[0m │ ## │ │ [38;5;34m128[0m) │ │ 
│ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ add_4 ([38;5;33mAdd[0m) │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m0[0m │ up_sampling2d… │ [1m-[0m │ ## │ │ [38;5;34m128[0m) │ │ conv2d_5[[38;5;34m0[0m][[38;5;34m0[0m] │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_11 │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m0[0m │ add_4[[38;5;34m0[0m][[38;5;34m0[0m] │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_transpose… │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m73,792[0m │ activation_11… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mConv2DTranspose[0m) │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m256[0m │ conv2d_transp… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_12 │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_transpose… │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m36,928[0m │ activation_12… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mConv2DTranspose[0m) │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m40[0m, [38;5;34m40[0m, │ [38;5;34m256[0m │ conv2d_transp… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ 
up_sampling2d_5 │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m0[0m │ add_4[[38;5;34m0[0m][[38;5;34m0[0m] │ [1m-[0m │ ## │ ([38;5;33mUpSampling2D[0m) │ [38;5;34m128[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ up_sampling2d_4 │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mUpSampling2D[0m) │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_6 ([38;5;33mConv2D[0m) │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m8,256[0m │ up_sampling2d… │ [1;38;5;34mY[0m │ ## │ │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ add_5 ([38;5;33mAdd[0m) │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m0[0m │ up_sampling2d… │ [1m-[0m │ ## │ │ [38;5;34m64[0m) │ │ conv2d_6[[38;5;34m0[0m][[38;5;34m0[0m] │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_13 │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m0[0m │ add_5[[38;5;34m0[0m][[38;5;34m0[0m] │ [1m-[0m │ ## │ ([38;5;33mActivation[0m) │ [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_transpose… │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m18,464[0m │ activation_13… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mConv2DTranspose[0m) │ [38;5;34m32[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m128[0m │ conv2d_transp… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m32[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ activation_14 │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m 
│ ## │ ([38;5;33mActivation[0m) │ [38;5;34m32[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_transpose… │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m9,248[0m │ activation_14… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mConv2DTranspose[0m) │ [38;5;34m32[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ batch_normalizat… │ ([38;5;45mNone[0m, [38;5;34m80[0m, [38;5;34m80[0m, │ [38;5;34m128[0m │ conv2d_transp… │ [1;38;5;34mY[0m │ ## │ ([38;5;33mBatchNormalizat…[0m │ [38;5;34m32[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ up_sampling2d_7 │ ([38;5;45mNone[0m, [38;5;34m160[0m, │ [38;5;34m0[0m │ add_5[[38;5;34m0[0m][[38;5;34m0[0m] │ [1m-[0m │ ## │ ([38;5;33mUpSampling2D[0m) │ [38;5;34m160[0m, [38;5;34m64[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ up_sampling2d_6 │ ([38;5;45mNone[0m, [38;5;34m160[0m, │ [38;5;34m0[0m │ batch_normali… │ [1m-[0m │ ## │ ([38;5;33mUpSampling2D[0m) │ [38;5;34m160[0m, [38;5;34m32[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_7 ([38;5;33mConv2D[0m) │ ([38;5;45mNone[0m, [38;5;34m160[0m, │ [38;5;34m2,080[0m │ up_sampling2d… │ [1;38;5;34mY[0m │ ## │ │ [38;5;34m160[0m, [38;5;34m32[0m) │ │ │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ add_6 ([38;5;33mAdd[0m) │ ([38;5;45mNone[0m, [38;5;34m160[0m, │ [38;5;34m0[0m │ up_sampling2d… │ [1m-[0m │ ## │ │ [38;5;34m160[0m, [38;5;34m32[0m) │ │ conv2d_7[[38;5;34m0[0m][[38;5;34m0[0m] │ │ ## ├───────────────────┼─────────────────┼───────────┼────────────────┼───────┤ ## │ conv2d_8 ([38;5;33mConv2D[0m) │ ([38;5;45mNone[0m, [38;5;34m160[0m, │ [38;5;34m867[0m │ add_6[[38;5;34m0[0m][[38;5;34m0[0m] │ [1;38;5;34mY[0m │ ## │ │ [38;5;34m160[0m, [38;5;34m3[0m) │ │ │ │ ## 
└───────────────────┴─────────────────┴───────────┴────────────────┴───────┘
##  Total params: 2,058,979 (7.85 MB)
##  Trainable params: 2,055,203 (7.84 MB)
##  Non-trainable params: 3,776 (14.75 KB)
# Hold out 1000 randomly chosen image/mask pairs for validation; everything
# else is available for training.
val_samples <- sample.int(length(input_img_paths), size = 1000)

val_input_img_paths <- input_img_paths[val_samples]
val_target_img_paths <- target_img_paths[val_samples]
train_input_img_paths <- input_img_paths[-val_samples]
train_target_img_paths <- target_img_paths[-val_samples]

# Build one dataset per split.
# `max_dataset_len` caps the number of training files so epochs run faster;
# drop that argument to train on the full dataset.
train_dataset <- get_dataset(
  batch_size = batch_size,
  img_size = img_size,
  input_img_paths = train_input_img_paths,
  target_img_paths = train_target_img_paths,
  max_dataset_len = 1000
)
valid_dataset <- get_dataset(
  batch_size = batch_size,
  img_size = img_size,
  input_img_paths = val_input_img_paths,
  target_img_paths = val_target_img_paths
)
# Set up training. The masks hold integer class ids rather than one-hot
# vectors, hence the "sparse" variant of categorical cross-entropy.
compile(
  model,
  optimizer = optimizer_adam(learning_rate = 1e-4),
  loss = "sparse_categorical_crossentropy"
)

# Checkpoint the model to disk during training (best model only).
callbacks <- list(
  callback_model_checkpoint(
    filepath = "models/oxford_segmentation.keras",
    save_best_only = TRUE
  )
)

# Run training; the validation set is evaluated at the end of every epoch.
epochs <- 50
fit(
  model,
  train_dataset,
  epochs = epochs,
  validation_data = valid_dataset,
  callbacks = callbacks,
  verbose = 2
)
## Epoch 1/50 ## 32/32 - 27s - 844ms/step - loss: 1.4284 - val_loss: 1.5502 ## Epoch 2/50 ## 32/32 - 2s - 60ms/step - loss: 0.9221 - val_loss: 1.9881 ## Epoch 3/50 ## 32/32 - 2s - 60ms/step - loss: 0.7764 - val_loss: 2.5123 ## Epoch 4/50 ## 32/32 - 2s - 60ms/step - loss: 0.7200 - val_loss: 3.0148 ## Epoch 5/50 ## 32/32 - 2s - 60ms/step - loss: 0.6848 - val_loss: 3.2898 ## Epoch 6/50 ## 32/32 - 2s - 60ms/step - loss: 0.6556 - val_loss: 3.4525 ## Epoch 7/50 ## 32/32 - 2s - 60ms/step - loss: 0.6302 - val_loss: 3.5625 ## Epoch 8/50 ## 32/32 - 2s - 60ms/step - loss: 0.6082 - val_loss: 3.6537 ## Epoch 9/50 ## 32/32 - 2s - 60ms/step - loss: 0.5894 - val_loss: 3.7334 ## Epoch 10/50 ## 32/32 - 2s - 60ms/step - loss: 0.5726 - val_loss: 3.7954 ## Epoch 11/50 ## 32/32 - 2s - 60ms/step - loss: 0.5566 - val_loss: 3.8266 ## Epoch 12/50 ## 32/32 - 2s - 60ms/step - loss: 0.5407 - val_loss: 3.8028 ## Epoch 13/50 ## 32/32 - 2s - 60ms/step - loss: 0.5241 - val_loss: 3.7225 ## Epoch 14/50 ## 32/32 - 2s - 60ms/step - loss: 0.5063 - val_loss: 3.5891 ## Epoch 15/50 ## 32/32 - 2s - 60ms/step - loss: 0.4862 - val_loss: 3.4399 ## Epoch 16/50 ## 32/32 - 2s - 60ms/step - loss: 0.4640 - val_loss: 3.2488 ## Epoch 17/50 ## 32/32 - 2s - 60ms/step - loss: 0.4398 - val_loss: 2.9895 ## Epoch 18/50 ## 32/32 - 2s - 60ms/step - loss: 0.4141 - val_loss: 2.7040 ## Epoch 19/50 ## 32/32 - 2s - 60ms/step - loss: 0.3877 - val_loss: 2.3552 ## Epoch 20/50 ## 32/32 - 2s - 60ms/step - loss: 0.3622 - val_loss: 1.9751 ## Epoch 21/50 ## 32/32 - 2s - 60ms/step - loss: 0.3391 - val_loss: 1.6415 ## Epoch 22/50 ## 32/32 - 2s - 65ms/step - loss: 0.3201 - val_loss: 1.3455 ## Epoch 23/50 ## 32/32 - 2s - 65ms/step - loss: 0.3076 - val_loss: 1.1034 ## Epoch 24/50 ## 32/32 - 2s - 65ms/step - loss: 0.3076 - val_loss: 1.0204 ## Epoch 25/50 ## 32/32 - 2s - 65ms/step - loss: 0.3429 - val_loss: 0.9294 ## Epoch 26/50 ## 32/32 - 2s - 60ms/step - loss: 0.3633 - val_loss: 1.0385 ## Epoch 27/50 ## 32/32 - 2s - 65ms/step - loss: 0.3294 
- val_loss: 0.8552 ## Epoch 28/50 ## 32/32 - 2s - 60ms/step - loss: 0.2888 - val_loss: 0.9904 ## Epoch 29/50 ## 32/32 - 2s - 60ms/step - loss: 0.2722 - val_loss: 1.1511 ## Epoch 30/50 ## 32/32 - 2s - 60ms/step - loss: 0.2675 - val_loss: 1.1777 ## Epoch 31/50 ## 32/32 - 2s - 60ms/step - loss: 0.2710 - val_loss: 1.1230 ## Epoch 32/50 ## 32/32 - 2s - 60ms/step - loss: 0.2818 - val_loss: 1.1107 ## Epoch 33/50 ## 32/32 - 2s - 60ms/step - loss: 0.3111 - val_loss: 1.2046 ## Epoch 34/50 ## 32/32 - 2s - 60ms/step - loss: 0.3025 - val_loss: 1.2817 ## Epoch 35/50 ## 32/32 - 2s - 60ms/step - loss: 0.2902 - val_loss: 1.0870 ## Epoch 36/50 ## 32/32 - 2s - 60ms/step - loss: 0.2812 - val_loss: 0.9639 ## Epoch 37/50 ## 32/32 - 2s - 60ms/step - loss: 0.2769 - val_loss: 1.3830 ## Epoch 38/50 ## 32/32 - 2s - 60ms/step - loss: 0.2629 - val_loss: 1.0656 ## Epoch 39/50 ## 32/32 - 2s - 60ms/step - loss: 0.2439 - val_loss: 1.2142 ## Epoch 40/50 ## 32/32 - 2s - 60ms/step - loss: 0.2417 - val_loss: 1.1791 ## Epoch 41/50 ## 32/32 - 2s - 60ms/step - loss: 0.2437 - val_loss: 1.3024 ## Epoch 42/50 ## 32/32 - 2s - 60ms/step - loss: 0.2340 - val_loss: 1.5025 ## Epoch 43/50 ## 32/32 - 2s - 60ms/step - loss: 0.2330 - val_loss: 1.2029 ## Epoch 44/50 ## 32/32 - 2s - 60ms/step - loss: 0.2228 - val_loss: 1.1434 ## Epoch 45/50 ## 32/32 - 2s - 60ms/step - loss: 0.2169 - val_loss: 1.1099 ## Epoch 46/50 ## 32/32 - 2s - 60ms/step - loss: 0.2116 - val_loss: 1.1009 ## Epoch 47/50 ## 32/32 - 2s - 60ms/step - loss: 0.2059 - val_loss: 1.0853 ## Epoch 48/50 ## 32/32 - 2s - 60ms/step - loss: 0.2004 - val_loss: 1.1272 ## Epoch 49/50 ## 32/32 - 2s - 60ms/step - loss: 0.1924 - val_loss: 1.0651 ## Epoch 50/50 ## 32/32 - 2s - 60ms/step - loss: 0.1844 - val_loss: 1.1039
# Reload the checkpoint file written during training.
model <- load_model("models/oxford_segmentation.keras")

# Run the model over every image in the validation split.
val_dataset <- get_dataset(
  batch_size = batch_size,
  img_size = img_size,
  input_img_paths = val_input_img_paths,
  target_img_paths = val_target_img_paths
)
val_preds <- model |> predict(val_dataset)
## 32/32 - 3s - 83ms/step
# Plot the mask the model predicted for validation sample `i`.
#
# Reads the script-level objects `val_preds` (model output for the
# validation set, indexed as [sample, row, col, class]) and `img_size`.
# Each pixel gets the index of its highest-scoring class (1..3), which is
# then drawn as a grey level.
display_mask <- function(i) {
  mask <- val_preds[i, , , ] |>
    apply(c(1, 2), which.max) |>
    array_reshape(dim = c(img_size, 1))
  # Repeat the single channel three times so the raster is a grey RGB image.
  mask <- abind::abind(mask, mask, mask, along = 3)
  plot(as.raster(mask, max = 3))
}

# Display results for validation image #10
i <- 10

# Three panels side by side; remember the previous layout so it can be
# restored once the figure is drawn.
old_par <- par(mfrow = c(1, 3))

# `val_preds[i, , , ]` was predicted from the i-th *validation* path, so the
# input and ground truth must be read from the validation lists as well.
# Indexing the full `input_img_paths` would show an unrelated image.

# Display input image
val_input_img_paths[i] |>
  jpeg::readJPEG() |>
  as.raster() |>
  plot()

# Display ground-truth target mask
val_target_img_paths[i] |>
  png::readPNG() |>
  magrittr::multiply_by(255) |>
  as.raster(max = 3) |>
  plot()

# Display mask predicted by our model
display_mask(i)

par(old_par)

# Note that the model only sees inputs at 160x160 (`img_size`).
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.