Implementation of the Fire module, the CNN building block of SqueezeNet
Imports
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
# Short aliases for the Keras API bundled under tf.contrib (TensorFlow 1.x).
layers, models = tf.contrib.keras.layers, tf.contrib.keras.models
Implementation of the Fire module
def fire_module(x, fire_id, squeeze=16, expand=64):
    """Build a SqueezeNet Fire module on top of the Keras tensor ``x``.

    The module applies a 1x1 "squeeze" convolution + ReLU, then feeds that
    result into two parallel "expand" branches (1x1 conv + ReLU and 3x3 conv
    + ReLU), and concatenates the two branch outputs.

    Args:
        x: Input Keras tensor. Assumed channels-last (NHWC), because the
            branches are concatenated on axis 3 -- TODO confirm if other data
            formats are ever used.
        fire_id: Identifier turned into the layer-name prefix
            ``'fire<fire_id>/'``. Keras requires unique layer names within a
            model, so use a distinct value for every call.
        squeeze: Number of filters in the 1x1 squeeze convolution.
        expand: Number of filters in EACH expand branch; the concatenated
            output therefore has ``2 * expand`` channels.

    Returns:
        Keras tensor holding the concatenated expand-branch outputs.
    """
    sq1x1 = "squeeze1x1"
    exp1x1 = "expand1x1"
    exp3x3 = "expand3x3"
    relu = "relu_"
    s_id = 'fire' + str(fire_id) + '/'

    # Squeeze layer
    x = layers.Convolution2D(squeeze, (1, 1), padding='valid', name=s_id + sq1x1)(x)
    # The Activation layer must be *called* on x; without the trailing (x)
    # the layer object itself, not a tensor, would flow into the expand convs.
    x = layers.Activation('relu', name=s_id + relu + sq1x1)(x)

    # Expand layer: 1x1 filters
    left = layers.Convolution2D(expand, (1, 1), padding='valid', name=s_id + exp1x1)(x)
    left = layers.Activation('relu', name=s_id + relu + exp1x1)(left)

    # Expand layer: 3x3 filters. 'same' padding keeps the spatial size equal
    # to the 1x1 branch so the two outputs can be concatenated.
    right = layers.Convolution2D(expand, (3, 3), padding='same', name=s_id + exp3x3)(x)
    right = layers.Activation('relu', name=s_id + relu + exp3x3)(right)

    # Concatenate the branches along the channel axis (axis 3 for NHWC).
    x = layers.concatenate([left, right], axis=3, name=s_id + 'concat')
    return x
Implementation of SqueezeNet
def SqueezeNet(input_shape=(32, 32, 3), classes=10):
    """Build a compact SqueezeNet-style classifier as a Keras functional model.

    Layer sequence: conv1 (3x3, stride 2) + ReLU -> fire2 -> fire3 -> pool3
    -> fire4 -> fire5 -> dropout -> conv10 (1x1, one filter per class) + ReLU
    -> global average pooling -> softmax.

    Args:
        input_shape: Shape of one input image, excluding the batch dimension.
        classes: Number of output classes (filters in conv10).

    Returns:
        A ``models.Model`` named 'squeezenet' whose output is a softmax over
        ``classes`` values per example.
    """
    img_input = layers.Input(shape=input_shape)

    # Stem convolution.
    net = layers.Convolution2D(64, (3, 3), strides=(2, 2), padding='valid',
                               name='conv1')(img_input)
    net = layers.Activation('relu', name='relu_conv1')(net)
    # NOTE(review): a 'pool1' max-pooling layer (3x3, stride 2) was
    # deliberately left out here; presumably to keep spatial resolution for
    # small inputs such as 32x32 -- confirm before re-enabling.

    # First Fire stage: (fire_id, squeeze filters, expand filters).
    for fire_id, squeeze, expand in ((2, 16, 64), (3, 16, 64)):
        net = fire_module(net, fire_id=fire_id, squeeze=squeeze, expand=expand)

    net = layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool3')(net)

    # Second Fire stage.
    for fire_id, squeeze, expand in ((4, 32, 128), (5, 32, 128)):
        net = fire_module(net, fire_id=fire_id, squeeze=squeeze, expand=expand)

    # Classifier head: per-class 1x1 conv, then spatial averaging.
    net = layers.Dropout(0.5, name='drop9')(net)
    net = layers.Convolution2D(classes, (1, 1), padding='valid', name='conv10')(net)
    net = layers.Activation('relu', name='relu_conv10')(net)
    net = layers.GlobalAveragePooling2D()(net)
    probabilities = layers.Activation('softmax', name='loss')(net)

    return models.Model(img_input, probabilities, name='squeezenet')
Inspect the SqueezeNet architecture
# Build the model with its default configuration (32x32x3 inputs, 10 classes)
# and print a layer-by-layer summary.
squeeze_net = SqueezeNet(input_shape=(32, 32, 3), classes=10)
squeeze_net.summary()
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_3 (InputLayer) (None, 32, 32, 3) 0
____________________________________________________________________________________________________
conv1 (Conv2D) (None, 15, 15, 64) 1792 input_3[0][0]
____________________________________________________________________________________________________
relu_conv1 (Activation) (None, 15, 15, 64) 0 conv1[0][0]
____________________________________________________________________________________________________
fire2/squeeze1x1 (Conv2D) (None, 15, 15, 16) 1040 relu_conv1[0][0]
____________________________________________________________________________________________________
fire2/relu_squeeze1x1 (Activatio (None, 15, 15, 16) 0 fire2/squeeze1x1[0][0]
____________________________________________________________________________________________________
fire2/expand1x1 (Conv2D) (None, 15, 15, 64) 1088 fire2/relu_squeeze1x1[0][0]
____________________________________________________________________________________________________
fire2/expand3x3 (Conv2D) (None, 15, 15, 64) 9280 fire2/relu_squeeze1x1[0][0]
____________________________________________________________________________________________________
fire2/relu_expand1x1 (Activation (None, 15, 15, 64) 0 fire2/expand1x1[0][0]
____________________________________________________________________________________________________
fire2/relu_expand3x3 (Activation (None, 15, 15, 64) 0 fire2/expand3x3[0][0]
____________________________________________________________________________________________________
fire2/concat (Concatenate) (None, 15, 15, 128) 0 fire2/relu_expand1x1[0][0]
fire2/relu_expand3x3[0][0]
____________________________________________________________________________________________________
fire3/squeeze1x1 (Conv2D) (None, 15, 15, 16) 2064 fire2/concat[0][0]
____________________________________________________________________________________________________
fire3/relu_squeeze1x1 (Activatio (None, 15, 15, 16) 0 fire3/squeeze1x1[0][0]
____________________________________________________________________________________________________
fire3/expand1x1 (Conv2D) (None, 15, 15, 64) 1088 fire3/relu_squeeze1x1[0][0]
____________________________________________________________________________________________________
fire3/expand3x3 (Conv2D) (None, 15, 15, 64) 9280 fire3/relu_squeeze1x1[0][0]
____________________________________________________________________________________________________
fire3/relu_expand1x1 (Activation (None, 15, 15, 64) 0 fire3/expand1x1[0][0]
____________________________________________________________________________________________________
fire3/relu_expand3x3 (Activation (None, 15, 15, 64) 0 fire3/expand3x3[0][0]
____________________________________________________________________________________________________
fire3/concat (Concatenate) (None, 15, 15, 128) 0 fire3/relu_expand1x1[0][0]
fire3/relu_expand3x3[0][0]
____________________________________________________________________________________________________
pool3 (MaxPooling2D) (None, 7, 7, 128) 0 fire3/concat[0][0]
____________________________________________________________________________________________________
fire4/squeeze1x1 (Conv2D) (None, 7, 7, 32) 4128 pool3[0][0]
____________________________________________________________________________________________________
fire4/relu_squeeze1x1 (Activatio (None, 7, 7, 32) 0 fire4/squeeze1x1[0][0]
____________________________________________________________________________________________________
fire4/expand1x1 (Conv2D) (None, 7, 7, 128) 4224 fire4/relu_squeeze1x1[0][0]
____________________________________________________________________________________________________
fire4/expand3x3 (Conv2D) (None, 7, 7, 128) 36992 fire4/relu_squeeze1x1[0][0]
____________________________________________________________________________________________________
fire4/relu_expand1x1 (Activation (None, 7, 7, 128) 0 fire4/expand1x1[0][0]
____________________________________________________________________________________________________
fire4/relu_expand3x3 (Activation (None, 7, 7, 128) 0 fire4/expand3x3[0][0]
____________________________________________________________________________________________________
fire4/concat (Concatenate) (None, 7, 7, 256) 0 fire4/relu_expand1x1[0][0]
fire4/relu_expand3x3[0][0]
____________________________________________________________________________________________________
fire5/squeeze1x1 (Conv2D) (None, 7, 7, 32) 8224 fire4/concat[0][0]
____________________________________________________________________________________________________
fire5/relu_squeeze1x1 (Activatio (None, 7, 7, 32) 0 fire5/squeeze1x1[0][0]
____________________________________________________________________________________________________
fire5/expand1x1 (Conv2D) (None, 7, 7, 128) 4224 fire5/relu_squeeze1x1[0][0]
____________________________________________________________________________________________________
fire5/expand3x3 (Conv2D) (None, 7, 7, 128) 36992 fire5/relu_squeeze1x1[0][0]
____________________________________________________________________________________________________
fire5/relu_expand1x1 (Activation (None, 7, 7, 128) 0 fire5/expand1x1[0][0]
____________________________________________________________________________________________________
fire5/relu_expand3x3 (Activation (None, 7, 7, 128) 0 fire5/expand3x3[0][0]
____________________________________________________________________________________________________
fire5/concat (Concatenate) (None, 7, 7, 256) 0 fire5/relu_expand1x1[0][0]
fire5/relu_expand3x3[0][0]
____________________________________________________________________________________________________
drop9 (Dropout) (None, 7, 7, 256) 0 fire5/concat[0][0]
____________________________________________________________________________________________________
conv10 (Conv2D) (None, 7, 7, 10) 2570 drop9[0][0]
____________________________________________________________________________________________________
relu_conv10 (Activation) (None, 7, 7, 10) 0 conv10[0][0]
____________________________________________________________________________________________________
global_average_pooling2d_3 (Glob (None, 10) 0 relu_conv10[0][0]
____________________________________________________________________________________________________
loss (Activation) (None, 10) 0 global_average_pooling2d_3[0][0]
====================================================================================================
Total params: 122,986
Trainable params: 122,986
Non-trainable params: 0
____________________________________________________________________________________________________
Next Lesson
Train and Evaluate SqueezeNet
Image classification task with CIFAR-10
Last updated