Load and Explore the Flower Dataset

Image Classification Task

Cifar10 is a famous computer-vision dataset used for object recognition.

The dataset consists of:

  • colored images of various sizes

  • 5 classes

  • 700 images per classes

Imports

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import tarfile
from six.moves import urllib
from glob import glob
import random
import shutil

import tensorflow as tf
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

Download Helper Function

def download_and_uncompress_tarball(tarball_url, dataset_dir):
    """Downloads the `tarball_url` and uncompresses it locally.
    Args:
    tarball_url: The URL of a tarball file.
    dataset_dir: The directory where the temporary files are stored.
    """
    filename = tarball_url.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)

    def _progress(count, block_size, total_size):
        sys.stdout.write('\r>> Downloading %s %.1f%%' % (
            filename, float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    tarfile.open(filepath, 'r:gz').extractall(dataset_dir)

Download Flower Dataset

# The URL where the Flowers data can be downloaded.
DATA_URL = "http://download.tensorflow.org/example_images/flower_photos.tgz"
BASE_DIR = "/Users/marvinbertin/Desktop/tmp"
download_and_uncompress_tarball(DATA_URL, BASE_DIR)

Downloading flower_photos.tgz 98.5%

Load Data Files

def load_data_files(base_dir):
    RAW_DATASET = os.path.join(base_dir,"flower_photos")

    sub_dir = map(lambda d: os.path.basename(d.rstrip("/")), glob(os.path.join(RAW_DATASET,'*/')))

    data_dic = {}
    for class_name  in sub_dir:
        imgs = glob(os.path.join(RAW_DATASET,class_name,"*.jpg"))

        data_dic[class_name] = imgs
        print("Class: {}".format(class_name))
        print("Number of images: {} \n".format(len(imgs)))

    return data_dic
data_dic = load_data_files(BASE_DIR)
Class: daisy
Number of images: 633 

Class: dandelion
Number of images: 898 

Class: roses
Number of images: 641 

Class: sunflowers
Number of images: 699 

Class: tulips
Number of images: 799 

Plotting Helper Function

def plot_image_grid(images_files):
    # figure size
    fig = plt.figure(figsize=(8, 8))

    # load images
    images = [tf.contrib.keras.preprocessing.image.load_img(img) for img in images_files]

    # plot image grid
    for x in range(4):
        for y in range(4):
            ax = fig.add_subplot(4, 4, 4*y+x+1)
            plt.imshow(images[4*y+x])
            plt.xticks(np.array([]))
            plt.yticks(np.array([]))
    plt.show()

Explore Flower Dataset

for class_name, imgs in data_dic.iteritems():
    print("Flower type: {}".format(class_name))
    plot_image_grid(imgs[:16])
Flower type: tulips
Flower type: roses
Flower type: dandelion
Flower type: sunflowers
Flower type: daisy

Split Into Train and Validation Sets

# Create new directory and copy files to it
def copy_files_to_directory(files, directory):
    if not os.path.exists(directory):
        os.makedirs(directory)
        print("Created directory: {}".format(directory))

    for f in files:
        shutil.copy(f, directory)
    print("Copied {} files.\n".format(len(files)))
def train_validation_split(base_dir, data_dic, split_ratio=0.2):
    FLOWER_DATASET = os.path.join(base_dir,"flower_dataset")

    if not os.path.exists(FLOWER_DATASET):
        os.makedirs(FLOWER_DATASET)

    for class_name, imgs in data_dic.iteritems():
        idx_split = int(len(imgs) * split_ratio)
        random.shuffle(imgs)
        validation = imgs[:idx_split]
        train = imgs[idx_split:]

        copy_files_to_directory(train, os.path.join(FLOWER_DATASET,"train",class_name))
        copy_files_to_directory(validation, os.path.join(FLOWER_DATASET,"validation",class_name))
train_validation_split(BASE_DIR, data_dic, split_ratio=0.2)
Created directory: /Users/marvinbertin/Desktop/tmp/flower_dataset/train/tulips
Copied 640 files.

Created directory: /Users/marvinbertin/Desktop/tmp/flower_dataset/validation/tulips
Copied 159 files.

Created directory: /Users/marvinbertin/Desktop/tmp/flower_dataset/train/roses
Copied 513 files.

Created directory: /Users/marvinbertin/Desktop/tmp/flower_dataset/validation/roses
Copied 128 files.

Created directory: /Users/marvinbertin/Desktop/tmp/flower_dataset/train/dandelion
Copied 719 files.

Created directory: /Users/marvinbertin/Desktop/tmp/flower_dataset/validation/dandelion
Copied 179 files.

Created directory: /Users/marvinbertin/Desktop/tmp/flower_dataset/train/sunflowers
Copied 560 files.

Created directory: /Users/marvinbertin/Desktop/tmp/flower_dataset/validation/sunflowers
Copied 139 files.

Created directory: /Users/marvinbertin/Desktop/tmp/flower_dataset/train/daisy
Copied 507 files.

Created directory: /Users/marvinbertin/Desktop/tmp/flower_dataset/validation/daisy
Copied 126 files.

Next Lesson

RestNet Architecture

  • Deep Residual Learning for Image Recognition

  • Skip connections and deep residual blocks

Last updated