Train a custom image classification model with TensorFlow 2

Here we'll learn how to train a custom image classification model from a pre-trained MobileNetV2 classifier.



For this tutorial you will need:

  • TensorFlow 2.x installed

  • picsellia installed

pip install tensorflow
pip install picsellia

Setting up your Picsell client

First, let's import TensorFlow and the Picsellia SDK.

import tensorflow as tf
from picsellia import Client
import os

Let's set the name of your future classification model and put your tokens here:

api_token = "your_token" # API Token from the picsell-IA platform
project_token = "your_project_token" # project token found in project -> settings
model_name = "your_model_name" # Name of your future model

Don't know how to generate a token?

Now we need to initialize our client so we can communicate with the platform.

clt = Client(api_token=api_token)

We need the annotations and images on our machine. We also need a label map, mapping the label names to label IDs that the TensorFlow object-detection API can understand. When we checked out the network, the annotations were downloaded and saved and the label map was generated. We simply need to run dl_pictures() to download the images from the platform if you didn't specify png_dir when checking out the project.

The train_test_split() method smartly splits our data into two sets.
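Conceptually, a random split of this kind can be sketched in plain Python. This is a simplified stand-in, not the SDK's actual implementation (the real method also handles bookkeeping on the platform side):

```python
import random

def naive_train_test_split(items, eval_ratio=0.2, seed=42):
    # Shuffle a copy deterministically, then carve off the eval fraction.
    rng = random.Random(seed)
    shuffled = list(items)
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * (1 - eval_ratio))
    return shuffled[:cut], shuffled[cut:]

train, eval_ = naive_train_test_split(range(100))
# 80 training items, 20 evaluation items
```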


Data pre-processing

Converting data into serialized TFRecord files

We want to serialize those images and labels inside a TFRecord format file. By doing so, the data will be read much more efficiently by TensorFlow. To do this, we need to generate a tf.Example for each image, which stores the image and its label as a protobuf; we then serialize and write those tf.Example objects inside the TFRecord file. First, we create some shortcut functions to wrap the feature messages. These functions convert standard TensorFlow types to a tf.Example-compatible tf.train.Feature object. In our case, we just want to store the encoded image and the label ID.

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
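As a quick sanity check, we can build a tiny tf.Example with these wrappers, serialize it, and parse it back. The wrapper functions are repeated here so the snippet runs standalone:

```python
import tensorflow as tf

# (Wrappers from above, repeated so this snippet is self-contained.)
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# Build an Example with dummy "image" bytes and a label, then serialize it.
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': _bytes_feature(b'\xff\xd8fake-jpeg-bytes'),
    'image/object/class/label': _int64_feature(1),
}))
serialized = example.SerializeToString()

# Parsing the bytes back recovers the original fields.
parsed = tf.train.Example.FromString(serialized)
label = parsed.features.feature['image/object/class/label'].int64_list.value[0]
# label == 1
```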

We can create our .record files from there. To do so, we define a new function which will, for each set, iterate through each image and generate a tf.Example message that we'll write inside our file.

We use the clt.tf_vars_generator method from the SDK to retrieve the data before converting it into tf.Example messages.

def create_record_files(label_map, record_dir, tfExample_generator):
    datasets = ["train", "eval"]
    for dataset in datasets:
        output_path = os.path.join(record_dir, dataset + ".record")
        writer = tf.io.TFRecordWriter(output_path)
        for variables in tfExample_generator(label_map, ensemble=dataset,
                                             annotation_type="classification"):
            (width, height, filename, encoded_jpg, image_format,
             classes_text, classes) = variables
            tf_example = tf.train.Example(features=tf.train.Features(feature={
                'image/encoded': _bytes_feature(encoded_jpg),
                'image/object/class/label': _int64_feature(classes[0] - 1)}))
            writer.write(tf_example.SerializeToString())
        writer.close()
        print('Successfully created the TFRecords: {}'.format(output_path))

label_map = {v: int(k) for k, v in clt.label_map.items()}
create_record_files(label_map=label_map, record_dir=clt.record_dir,
                    tfExample_generator=clt.tf_vars_generator)

Building our input pipeline

Now that our data is saved in an efficient format, we want to load it as a tf.data.Dataset object. We have to define a feature_description dictionary that follows the same structure as the one used to generate the tf.Example. With this dictionary we can define a parser for the tf.Example.

feature_description = {
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/object/class/label': tf.io.FixedLenFeature([], tf.int64, default_value=0)}

def _parse_function(example_proto):
    # Parse the input `tf.Example` proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, feature_description)

Let's create the tf.data.Dataset objects now by mapping the parser to the raw datasets!

raw_dataset = tf.data.TFRecordDataset(os.path.join(clt.record_dir, "train.record"))
train_dataset = raw_dataset.map(_parse_function)
raw_dataset = tf.data.TFRecordDataset(os.path.join(clt.record_dir, "eval.record"))
eval_dataset = raw_dataset.map(_parse_function)

Now that we have our dataset objects, we want to do some pre-processing on them. For the labels we will simply one-hot encode them. The images require a bit more attention: we will decode them, then resize them according to the input size of the MobileNetV2 model base. Then we'll use the quite convenient mobilenet_v2.preprocess_input() function, which casts the type to tf.float32 and scales the pixels between -1 and 1.

from tensorflow.keras.applications.mobilenet_v2 import preprocess_input

def map_img_label(example_proto):
    img = tf.io.decode_jpeg(example_proto["image/encoded"], channels=3)
    img = tf.image.resize(img, (224, 224))
    img = preprocess_input(img)
    label = example_proto["image/object/class/label"]
    label = tf.one_hot(label, depth=2)
    return (img, label)

train_set =
eval_set =

Now we want to shuffle and batch our datasets. With a tf.data.Dataset it's fairly simple: we just need to apply the corresponding methods with some arguments, namely the batch size and the buffer size for the shuffling.

We define some arbitrary values, then apply the methods to our datasets. We do not use the repeat() method of the dataset object because we want our epoch to end when the whole dataset is exhausted. If we added this method to both datasets, we would need to pass steps_per_epoch and validation_steps to the fit method of our model when starting the training. Indeed, TensorFlow would not be able to know when to stop an epoch since the dataset would repeat itself infinitely.
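To make the epoch-boundary point concrete, here is a toy dataset run through the same shuffle-and-batch pattern; without repeat(), iteration simply stops once the data is exhausted:

```python
import tensorflow as tf

# 10 elements, shuffled then batched by 4: we get batches of 4, 4 and 2,
# and iteration ends there because the dataset is not repeated.
toy = tf.data.Dataset.range(10).shuffle(buffer_size=10, seed=1).batch(4)
batch_sizes = [len(b.numpy()) for b in toy]
# batch_sizes == [4, 4, 2]
```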

At this stage we could add some data augmentation by mapping functions to the dataset. However we will not do it in this guide.

BATCH_SIZE = 32              # arbitrary
SHUFFLE_BUFFER_SIZE = 1000   # arbitrary
train_set = train_set.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
eval_set = eval_set.batch(BATCH_SIZE)

Model creation and training

Model definition

Now that our input pipeline is built, it's time to define our model. As said earlier, we are going to do some transfer learning on the MobileNetV2 model. First, let's import some Keras layers and the MobileNetV2 model.

from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import AveragePooling2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

Our model will be made up of two sub-models: the first part is the MobileNetV2 model with all of its layers frozen, and we will plug the little headModel defined below on top of it.

baseModel = MobileNetV2(weights="imagenet", include_top=False,
                        input_tensor=Input(shape=(224, 224, 3)))

headModel = baseModel.output
headModel = AveragePooling2D(pool_size=(7, 7))(headModel)
headModel = Flatten(name="flatten")(headModel)
headModel = Dense(128, activation="relu")(headModel)
headModel = Dropout(0.5)(headModel)
headModel = Dense(2, activation="softmax")(headModel)

model = Model(inputs=baseModel.input, outputs=headModel)

for layer in baseModel.layers:
    layer.trainable = False

We can print the summary of our model and see all the different layers, as well as the number of trainable and non-trainable parameters.

model.summary()

Compiling the model

We first define some arbitrary hyperparameters and a specific optimizer. The next step is to compile our model. This is where we set the chosen loss, metrics, and optimizer.

from tensorflow.keras.optimizers import Adam

INIT_LR = 1e-4
EPOCHS = 100
opt = Adam(learning_rate=INIT_LR, decay=INIT_LR / EPOCHS)
model.compile(loss="binary_crossentropy", optimizer=opt,
              metrics=["accuracy"])

Training the model

Let's start the training by using the fit method of our model. As arguments, we simply pass the tf.data train and validation sets and the number of epochs.

History =, validation_data=eval_set,
                    epochs=EPOCHS)



By default, the fit method of a model returns a tf.keras.callbacks.History object which holds some basic logs from the training. We want to send those logs to the platform to see them on the dashboard.

logs = {k: {"step": [str(e) for e in History.epoch],
            "value": [str(round(val, 3)) for val in v]}
        for k, v in History.history.items()}

This will create and send a dictionary containing the logs in the right format for the platform to display them.
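To see the shape of the result, here is the same comprehension applied to a stand-in History object with fake values:

```python
# Minimal stand-in mimicking the attributes of tf.keras.callbacks.History.
class FakeHistory:
    epoch = [0, 1]
    history = {"loss": [0.6931, 0.5123]}

History = FakeHistory()
logs = {k: {"step": [str(e) for e in History.epoch],
            "value": [str(round(val, 3)) for val in v]}
        for k, v in History.history.items()}
# logs == {'loss': {'step': ['0', '1'], 'value': ['0.693', '0.512']}}
```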


We want to save a checkpoint of our model in order to, for example, continue the training later. To do this, we create a Checkpoint object with our model and optimizer, then save it.

checkpoint = tf.train.Checkpoint(optimizer=opt, model=model), "model.ckpt"))  # checkpoint_dir: assumed SDK attribute, like record_dir

Saving a checkpoint is nice, but it's mostly useful for future trainings. To save the model directly, we simply apply the save method to our model and specify the directory.