Implementation of a Deep Residual Neural Network
Imports
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
models = tf.contrib.keras.models
layers = tf.contrib.keras.layers
initializers = tf.contrib.keras.initializers
regularizers = tf.contrib.keras.regularizers
Pre-activation Bottleneck Residual Block
def residual_block(input_tensor, filters, stage, reg=0.0, use_shortcuts=True):
bn_name = 'bn' + str(stage)
conv_name = 'conv' + str(stage)
relu_name = 'relu' + str(stage)
merge_name = 'merge' + str(stage)
# 1x1 conv
# batchnorm-relu-conv
# from input_filters to bottleneck_filters
if stage>1: # first activation is just after conv1
x = layers.BatchNormalization(name=bn_name+'a')(input_tensor)
x = layers.Activation('relu', name=relu_name+'a')(x)
else:
x = input_tensor
x = layers.Convolution2D(
filters[0], (1,1),
kernel_regularizer=regularizers.l2(reg),
use_bias=False,
name=conv_name+'a'
)(x)
# 3x3 conv
# batchnorm-relu-conv
# from bottleneck_filters to bottleneck_filters
x = layers.BatchNormalization(name=bn_name+'b')(x)
x = layers.Activation('relu', name=relu_name+'b')(x)
x = layers.Convolution2D(
filters[1], (3,3),
padding='same',
kernel_regularizer=regularizers.l2(reg),
use_bias = False,
name=conv_name+'b'
)(x)
# 1x1 conv
# batchnorm-relu-conv
# from bottleneck_filters to input_filters
x = layers.BatchNormalization(name=bn_name+'c')(x)
x = layers.Activation('relu', name=relu_name+'c')(x)
x = layers.Convolution2D(
filters[2], (1,1),
kernel_regularizer=regularizers.l2(reg),
name=conv_name+'c'
)(x)
# merge output with input layer (residual connection)
if use_shortcuts:
x = layers.add([x, input_tensor], name=merge_name)
return x
Full Residual Network
def ResNetPreAct(input_shape=(32,32,3), nb_classes=5, num_stages=5,
use_final_conv=False, reg=0.0):
# Input
img_input = layers.Input(shape=input_shape)
#### Input stream ####
# conv-BN-relu-(pool)
x = layers.Convolution2D(
128, (3,3), strides=(2, 2),
padding='same',
kernel_regularizer=regularizers.l2(reg),
use_bias=False,
name='conv0'
)(img_input)
x = layers.BatchNormalization(name='bn0')(x)
x = layers.Activation('relu', name='relu0')(x)
# x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='pool0')(x)
#### Residual Blocks ####
# 1x1 conv: batchnorm-relu-conv
# 3x3 conv: batchnorm-relu-conv
# 1x1 conv: batchnorm-relu-conv
for stage in range(1,num_stages+1):
x = residual_block(x, [32,32,128], stage=stage, reg=reg)
#### Output stream ####
# BN-relu-(conv)-avgPool-softmax
x = layers.BatchNormalization(name='bnF')(x)
x = layers.Activation('relu', name='reluF')(x)
# Optional final conv layer
if use_final_conv:
x = layers.Convolution2D(
64, (3,3),
padding='same',
kernel_regularizer=regularizers.l2(reg),
name='convF'
)(x)
pool_size = input_shape[0] / 2
x = layers.AveragePooling2D((pool_size,pool_size),name='avg_pool')(x)
x = layers.Flatten(name='flat')(x)
x = layers.Dense(nb_classes, activation='softmax', name='fc10')(x)
return models.Model(img_input, x, name='rnpa')
Inspect Model Architecture
model = ResNetPreAct()
model.summary()
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_3 (InputLayer) (None, 32, 32, 3) 0
____________________________________________________________________________________________________
conv0 (Conv2D) (None, 16, 16, 128) 3456 input_3[0][0]
____________________________________________________________________________________________________
bn0 (BatchNormalization) (None, 16, 16, 128) 512 conv0[0][0]
____________________________________________________________________________________________________
relu0 (Activation) (None, 16, 16, 128) 0 bn0[0][0]
____________________________________________________________________________________________________
conv1a (Conv2D) (None, 16, 16, 32) 4096 relu0[0][0]
____________________________________________________________________________________________________
bn1b (BatchNormalization) (None, 16, 16, 32) 128 conv1a[0][0]
____________________________________________________________________________________________________
relu1b (Activation) (None, 16, 16, 32) 0 bn1b[0][0]
____________________________________________________________________________________________________
conv1b (Conv2D) (None, 16, 16, 32) 9216 relu1b[0][0]
____________________________________________________________________________________________________
bn1c (BatchNormalization) (None, 16, 16, 32) 128 conv1b[0][0]
____________________________________________________________________________________________________
relu1c (Activation) (None, 16, 16, 32) 0 bn1c[0][0]
____________________________________________________________________________________________________
conv1c (Conv2D) (None, 16, 16, 128) 4224 relu1c[0][0]
____________________________________________________________________________________________________
merge1 (Add) (None, 16, 16, 128) 0 conv1c[0][0]
relu0[0][0]
____________________________________________________________________________________________________
bn2a (BatchNormalization) (None, 16, 16, 128) 512 merge1[0][0]
____________________________________________________________________________________________________
relu2a (Activation) (None, 16, 16, 128) 0 bn2a[0][0]
____________________________________________________________________________________________________
conv2a (Conv2D) (None, 16, 16, 32) 4096 relu2a[0][0]
____________________________________________________________________________________________________
bn2b (BatchNormalization) (None, 16, 16, 32) 128 conv2a[0][0]
____________________________________________________________________________________________________
relu2b (Activation) (None, 16, 16, 32) 0 bn2b[0][0]
____________________________________________________________________________________________________
conv2b (Conv2D) (None, 16, 16, 32) 9216 relu2b[0][0]
____________________________________________________________________________________________________
bn2c (BatchNormalization) (None, 16, 16, 32) 128 conv2b[0][0]
____________________________________________________________________________________________________
relu2c (Activation) (None, 16, 16, 32) 0 bn2c[0][0]
____________________________________________________________________________________________________
conv2c (Conv2D) (None, 16, 16, 128) 4224 relu2c[0][0]
____________________________________________________________________________________________________
merge2 (Add) (None, 16, 16, 128) 0 conv2c[0][0]
merge1[0][0]
____________________________________________________________________________________________________
bn3a (BatchNormalization) (None, 16, 16, 128) 512 merge2[0][0]
____________________________________________________________________________________________________
relu3a (Activation) (None, 16, 16, 128) 0 bn3a[0][0]
____________________________________________________________________________________________________
conv3a (Conv2D) (None, 16, 16, 32) 4096 relu3a[0][0]
____________________________________________________________________________________________________
bn3b (BatchNormalization) (None, 16, 16, 32) 128 conv3a[0][0]
____________________________________________________________________________________________________
relu3b (Activation) (None, 16, 16, 32) 0 bn3b[0][0]
____________________________________________________________________________________________________
conv3b (Conv2D) (None, 16, 16, 32) 9216 relu3b[0][0]
____________________________________________________________________________________________________
bn3c (BatchNormalization) (None, 16, 16, 32) 128 conv3b[0][0]
____________________________________________________________________________________________________
relu3c (Activation) (None, 16, 16, 32) 0 bn3c[0][0]
____________________________________________________________________________________________________
conv3c (Conv2D) (None, 16, 16, 128) 4224 relu3c[0][0]
____________________________________________________________________________________________________
merge3 (Add) (None, 16, 16, 128) 0 conv3c[0][0]
merge2[0][0]
____________________________________________________________________________________________________
bn4a (BatchNormalization) (None, 16, 16, 128) 512 merge3[0][0]
____________________________________________________________________________________________________
relu4a (Activation) (None, 16, 16, 128) 0 bn4a[0][0]
____________________________________________________________________________________________________
conv4a (Conv2D) (None, 16, 16, 32) 4096 relu4a[0][0]
____________________________________________________________________________________________________
bn4b (BatchNormalization) (None, 16, 16, 32) 128 conv4a[0][0]
____________________________________________________________________________________________________
relu4b (Activation) (None, 16, 16, 32) 0 bn4b[0][0]
____________________________________________________________________________________________________
conv4b (Conv2D) (None, 16, 16, 32) 9216 relu4b[0][0]
____________________________________________________________________________________________________
bn4c (BatchNormalization) (None, 16, 16, 32) 128 conv4b[0][0]
____________________________________________________________________________________________________
relu4c (Activation) (None, 16, 16, 32) 0 bn4c[0][0]
____________________________________________________________________________________________________
conv4c (Conv2D) (None, 16, 16, 128) 4224 relu4c[0][0]
____________________________________________________________________________________________________
merge4 (Add) (None, 16, 16, 128) 0 conv4c[0][0]
merge3[0][0]
____________________________________________________________________________________________________
bn5a (BatchNormalization) (None, 16, 16, 128) 512 merge4[0][0]
____________________________________________________________________________________________________
relu5a (Activation) (None, 16, 16, 128) 0 bn5a[0][0]
____________________________________________________________________________________________________
conv5a (Conv2D) (None, 16, 16, 32) 4096 relu5a[0][0]
____________________________________________________________________________________________________
bn5b (BatchNormalization) (None, 16, 16, 32) 128 conv5a[0][0]
____________________________________________________________________________________________________
relu5b (Activation) (None, 16, 16, 32) 0 bn5b[0][0]
____________________________________________________________________________________________________
conv5b (Conv2D) (None, 16, 16, 32) 9216 relu5b[0][0]
____________________________________________________________________________________________________
bn5c (BatchNormalization) (None, 16, 16, 32) 128 conv5b[0][0]
____________________________________________________________________________________________________
relu5c (Activation) (None, 16, 16, 32) 0 bn5c[0][0]
____________________________________________________________________________________________________
conv5c (Conv2D) (None, 16, 16, 128) 4224 relu5c[0][0]
____________________________________________________________________________________________________
merge5 (Add) (None, 16, 16, 128) 0 conv5c[0][0]
merge4[0][0]
____________________________________________________________________________________________________
bnF (BatchNormalization) (None, 16, 16, 128) 512 merge5[0][0]
____________________________________________________________________________________________________
reluF (Activation) (None, 16, 16, 128) 0 bnF[0][0]
____________________________________________________________________________________________________
avg_pool (AveragePooling2D) (None, 1, 1, 128) 0 reluF[0][0]
____________________________________________________________________________________________________
flat (Flatten) (None, 128) 0 avg_pool[0][0]
____________________________________________________________________________________________________
fc10 (Dense) (None, 5) 645 flat[0][0]
====================================================================================================
Total params: 96,133
Trainable params: 93,957
Non-trainable params: 2,176
____________________________________________________________________________________________________
Next Lesson
Train and Evaluate ResNet
Image classification task with Flower dataset
Last updated