A Concrete Introduction to Using the TensorFlow Object Detection API to Train a Neural Network on Your Own Data
This article walks through, step by step, how to use the TensorFlow Object Detection API to train a network on your own data, from image labeling all the way to the final prediction procedure.
1. Image Labeling and Data Organization
We first introduce the PASCAL_VOC data organization style. As shown below, image information is divided into three parts: image files in .jpg format, image annotation files in .xml format, and train/validation split results in .txt format. It's worth mentioning that in the train.txt and val.txt files, each line is the name of an image designated for training or validation (without extension), e.g., 1 or 2 in this scenario.
datasetName
|-- JPEGImages
|   |-- 1.jpg
|   |-- 2.jpg
|-- Annotations
|   |-- 1.xml
|   |-- 2.xml
|-- ImageSets
|   |-- Main
|       |-- train.txt
|       |-- val.txt
It is convenient to conform to this style at the beginning of data preprocessing; otherwise, you would spend some time converting your own format into this one before using the TensorFlow Object Detection API. Additionally, labelImg is a handy tool for image labeling, and the annotation files it generates are exactly PASCAL_VOC formatted; a minimal annotation file is sketched below.
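For illustration, a 1.xml file for an image containing a single labeled object would look roughly like this (the class name dog and all coordinate values are placeholders):
<annotation>
  <folder>datasetName</folder>
  <filename>1.jpg</filename>
  <size>
    <width>640</width>
    <height>480</height>
    <depth>3</depth>
  </size>
  <object>
    <name>dog</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox>
      <xmin>48</xmin>
      <ymin>240</ymin>
      <xmax>195</xmax>
      <ymax>371</ymax>
    </bndbox>
  </object>
</annotation>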
2. TensorFlow Object Detection API Installation
Because the TensorFlow Object Detection API is not integrated into the pip or conda TensorFlow installation, we need further steps for downloading and configuration. You can refer to the official guide on GitHub, but there are some things worth mentioning during installation; the core commands are sketched after this list.
- First, download this TensorFlow Models repo.
- Install the additional dependencies shown here.
- Install protobuf 3 for your own system (problems may occur when you are using protobuf 2).
- Configure PYTHONPATH for the Object Detection API. Refer to this.
- Use this command to test your installation.
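For reference, the key commands from the official guide look like this, run from the models/research/ directory of the repo you just downloaded:
# In bash, from models/research/
protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
python object_detection/builders/model_builder_test.py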
3. File Preparation (Again)
It would be better to create a new directory and copy or download the required files into it. Otherwise, you could just use the original object_detection dir in models/research/ that you downloaded just now.
The preferred directory structure is shown below:
myObjectDetection
|-- data                            # Manually created
|   |-- datasetName                 # Your own dataset folder
|   |-- pascal_label_map.pbtxt      # See 3.1
|   |-- create_pascal_tf_record.py  # See 3.1
|-- record                          # Manually created
|   |-- pascal_train.record         # See 3.1
|   |-- pascal_val.record           # See 3.1
|-- models                          # Manually created
|   |-- model_name_path             # See 3.2
|-- model_config_file.config        # See 3.2
|-- train.py                        # See 3.3
|-- eval.py                         # See 3.3
|-- export_inference_graph.py       # See 3.3
|-- train_dir                       # See 3.3
|-- exported_inference_graph        # See 5
3.1 Generate train and validation records
In order to use the TensorFlow Object Detection API, you have to create TF record files for training and validation. To do this, you need the create_pascal_tf_record.py file from here and the pascal_label_map.pbtxt file from here. Next, modify them according to your situation.
For pascal_label_map.pbtxt, you could change the item ids and names according to your own classes; note that id 0 is reserved for the background.
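As a sketch, a label map for two hypothetical classes named dog and cat would look like this:
item {
  id: 1
  name: 'dog'
}
item {
  id: 2
  name: 'cat'
}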
For create_pascal_tf_record.py, the required changes are shown below:
- Change line 56 to YEARS = ['VOC2007', 'VOC2012', 'merged', 'datasetName'] (the new entry must match your dataset folder name, which is also what you pass to --year below).
- Change the data['folder'] in line 85 to your own dataset path.
- In line 154, similarly add your own dataset name to the years list.
- In line 164, delete the 'aeroplane_' + when constructing the path name.
Next, run the following commands to generate the TF record files in the record folder:
# In bash in myObjectDetection dir
# For train
python3 data/create_pascal_tf_record.py \
--data_dir=data \
--year=datasetName \
--set=train \
--output_path=record/pascal_train.record \
--label_map_path=data/pascal_label_map.pbtxt
# For val
python3 data/create_pascal_tf_record.py \
--data_dir=data \
--year=datasetName \
--set=val \
--output_path=record/pascal_val.record \
--label_map_path=data/pascal_label_map.pbtxt
3.2 Download required model file
Generally, transfer learning is preferred when training on your own data. Hence, you could use the weights of a base model to speed up your training. You can check the supported model zoo shown here and download a model file according to your preference. Subsequently, move the unzipped folder into the models dir shown above, and the corresponding config file shown here into the myObjectDetection dir.
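For example, to grab and unpack one of the SSD checkpoints (the archive name here is only an example; substitute the one you picked from the model zoo page):
# In bash in myObjectDetection dir
wget http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz
tar -xzf ssd_inception_v2_coco_2017_11_17.tar.gz -C models/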
You may also need to change some configurations in the model config file. There are generally two types of configuration: model hyperparameters and file-related settings. You must change the file-related settings, otherwise errors will be raised. Take the ssd_inception_v2_coco.config file as an example. You have to change:
- num_classes in line 9.
- fine_tune_checkpoint in line 151, to something like models/model_name/model.ckpt.
- input_path and label_map_path for both train and val.

And you can tune some important hyperparameters like:

- batch_size in line 136.
- num_steps in line 157.
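Assuming two classes and the directory layout above, the relevant fragments of the edited config would look roughly like this (all values are illustrative):
model {
  ssd {
    num_classes: 2
    ...
  }
}
train_config {
  batch_size: 24
  num_steps: 200000
  fine_tune_checkpoint: "models/model_name/model.ckpt"
  ...
}
train_input_reader {
  tf_record_input_reader {
    input_path: "record/pascal_train.record"
  }
  label_map_path: "data/pascal_label_map.pbtxt"
}
eval_input_reader {
  tf_record_input_reader {
    input_path: "record/pascal_val.record"
  }
  label_map_path: "data/pascal_label_map.pbtxt"
}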
3.3 Other Required Code
Of course, you will need train.py and eval.py for training and evaluation, and export_inference_graph.py will be helpful when exporting your trained weights as an inference graph. They are all located in models/research/object_detection/; you can copy them over as sketched below. You can also create a train_dir folder for storing the files generated during training.
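A minimal way to pull these in (assuming you cloned the TensorFlow models repo to ~/tensorflow/models; adjust the path to wherever yours lives):
# In bash in myObjectDetection dir
cp ~/tensorflow/models/research/object_detection/train.py .
cp ~/tensorflow/models/research/object_detection/eval.py .
cp ~/tensorflow/models/research/object_detection/export_inference_graph.py .
mkdir train_dir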
4. Run training
Finally, after this laundry list of preprocessing steps, you can run training. Run the following command in bash:
# In myObjectDetection dir
python train.py \
--logtostderr \
--train_dir=train_dir \
--pipeline_config_path=model_config_file.config
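Optionally, you can run eval.py in a separate process to evaluate checkpoints as they are written; eval_dir here is an assumed output folder for the evaluation events:
# In myObjectDetection dir
python eval.py \
--logtostderr \
--checkpoint_dir=train_dir \
--eval_dir=eval_dir \
--pipeline_config_path=model_config_file.config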
You can also use tensorboard to view your learning curves during training:
# In myObjectDetection dir
tensorboard --logdir=train_dir
5. Export a trained model for inference
This is an easy step. Just follow the tutorial:
# In myObjectDetection dir
python export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path model_config_file.config \
--trained_checkpoint_prefix train_dir/model.ckpt-xxxx \
--output_directory exported_inference_graph
6. Predict on new data
There is no ready-made prediction code in the official repo. Hence, you need to implement your own prediction code, which is rather simple compared to the training ops. You can refer to this repo, the code shown at the bottom of this post, or the minimal sketch below.
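As a starting point, here is a minimal TF1-style sketch that loads the exported frozen graph and runs it on a single image (frozen_inference_graph.pb is the file produced by the export step in section 5; test.jpg and the 0.5 score threshold are placeholders):
# predict.py -- minimal inference sketch for an exported detection graph
import numpy as np
import tensorflow as tf
from PIL import Image

# Load the frozen graph exported in section 5.
graph = tf.Graph()
with graph.as_default():
    graph_def = tf.GraphDef()
    with tf.gfile.GFile('exported_inference_graph/frozen_inference_graph.pb', 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    # The exported graph expects a uint8 batch of shape (1, H, W, 3).
    image = np.expand_dims(np.array(Image.open('test.jpg')), axis=0)
    boxes, scores, classes, num = sess.run(
        ['detection_boxes:0', 'detection_scores:0',
         'detection_classes:0', 'num_detections:0'],
        feed_dict={'image_tensor:0': image})
    # Boxes are normalized [ymin, xmin, ymax, xmax]; keep confident detections.
    for box, score, cls in zip(boxes[0], scores[0], classes[0]):
        if score > 0.5:
            print('class %d, score %.2f, box %s' % (int(cls), score, box))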
This work is licensed under a Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International License.