By now most of us know that Capsule Networks (CapsNet, by Geoffrey Hinton) are shaking up the AI space, and the literature suggests they could push past the limits of Convolutional Neural Networks (CNNs). There are plenty of Medium posts, articles and research papers that discuss the theory and how it improves on traditional CNNs, so I am not going to cover that part. Instead, I will try to implement CapsNet on TensorFlow using Google's notebook tool, Colaboratory.



A few links you can follow to understand the theory behind CapsNet:

Now let's start writing code.
Step 1. Clone the repository locally

!git clone https://github.com/XifengGuo/CapsNet-Keras.git capsnet-keras
%cd capsnet-keras
!git checkout tf2.2 # Only if using TensorFlow>=2.0

Step 2. Install TensorFlow>=2.0

!pip install tensorflow==2.2

Step 3. Train a CapsNet on MNIST

Training with default settings:

!python capsulenet.py



Step 4. Test a pre-trained CapsNet model

If you trained a model with the command above, it will be saved to result/trained_model.h5. Run the following command to get the test results.

!python capsulenet.py -t -w result/trained_model.h5

It will output the test accuracy and show the reconstructed images. The test data is the same as the validation data. Testing on new data is straightforward: change the data-loading code to read your own images.
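
As one way to prepare new data for the test step, the sketch below converts raw images into the layout used for MNIST. It assumes the network expects 28x28 grayscale inputs shaped (N, 28, 28, 1), scaled to [0, 1], with one-hot labels over 10 classes. The helper name prepare_images is hypothetical and not part of the repository.

```python
import numpy as np

def prepare_images(images, labels, num_classes=10):
    """Convert raw uint8 images and integer labels into MNIST-style arrays.

    Hypothetical helper: assumes 28x28 grayscale inputs with pixel values 0-255.
    """
    x = np.asarray(images)
    if x.ndim != 3 or x.shape[1:] != (28, 28):
        raise ValueError(f"expected images of shape (N, 28, 28), got {x.shape}")
    # Add a channel axis and scale pixel values to [0, 1].
    x = x.reshape(-1, 28, 28, 1).astype("float32") / 255.0
    # One-hot encode the integer labels.
    y = np.eye(num_classes, dtype="float32")[np.asarray(labels)]
    return x, y
```

The returned arrays could then replace the MNIST test arrays in your own copy of the script; where exactly to plug them in depends on how you edit the data-loading code.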

You can also just download a trained model from https://pan.baidu.com/s/1sldqQo1 or https://drive.google.com/open?id=1A7pRxH7iWzYZekzr-O0nrwqdUUpUpkik

Step 5. Train on multiple GPUs

This requires Keras>=2.0.9. After updating Keras:

!python capsulenet-multi-gpu.py --gpus 2

It will automatically train on multiple GPUs for 50 epochs and then report the performance on the test dataset. Note that no validation accuracy is reported during training.

Results

Test Errors

CapsNet classification test error on MNIST. The average and standard deviation are reported over 3 trials. The results can be reproduced by running the following commands.

!python capsulenet.py --routings 1 --lam_recon 0.0    # CapsNet-v1
!python capsulenet.py --routings 1 --lam_recon 0.392  # CapsNet-v2
!python capsulenet.py --routings 3 --lam_recon 0.0    # CapsNet-v3
!python capsulenet.py --routings 3 --lam_recon 0.392  # CapsNet-v4

Method       Routing   Reconstruction   MNIST (%)       Paper
Baseline     --        --               --              0.39
CapsNet-v1   1         no               0.39 (0.024)    0.34 (0.032)
CapsNet-v2   1         yes              0.36 (0.009)    0.29 (0.011)
CapsNet-v3   3         no               0.40 (0.016)    0.35 (0.036)
CapsNet-v4   3         yes              0.34 (0.016)    0.25 (0.005)
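
Each entry in the table is a mean test error with the standard deviation in parentheses. Here is a minimal sketch of how such a summary could be computed from three runs. The error values are made up for illustration and are not the actual trial results, and the population standard deviation is an assumption, since the source does not say which form was used.

```python
from statistics import mean, pstdev

def summarize_trials(errors):
    """Return (mean, std) of test errors in percent, rounded for reporting."""
    return round(mean(errors), 2), round(pstdev(errors), 3)

# Hypothetical test errors (%) from three runs; not the real trial results.
example_errors = [0.32, 0.34, 0.36]
avg, std = summarize_trials(example_errors)
print(f"{avg:.2f} ({std:.3f})")
```

Swapping pstdev for stdev from the same module gives the sample standard deviation instead, which is slightly larger for only three trials.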