TensorFlow Gets Weird Prediction Result in Test Mode

I have recently encountered a weird bug with TensorFlow. Here I record what the problem is and the way to solve it.

Problem Description

I use a TensorFlow implementation of MobileNet V1 as an example to elaborate the problem. It is natural to set is_training = True in train mode but is_training = False in test mode.

However, if you follow the test mode to test your trained model, you will sometimes get a weird prediction result: all predictions fall into the same class, like this:

Pretty weird! Similar results have been encountered by multiple TensorFlowers. When I dug into the source code, I gradually found that the root cause lies in the BatchNorm layer. In my own experience, BatchNorm is often the culprit behind many deep learning bugs.

Bug Traceback

In a nutshell, for an input and output pair $(x, y)$, the BatchNorm layer normalizes the input $x$ with a mean value $\mu$ and variance value $\sigma$, plus a learnable scale parameter $\gamma$ and a shift parameter $\beta$:
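In standard notation (with $\epsilon$ a small constant added for numerical stability, and $\sigma$ denoting the variance as above), the BatchNorm transformation is:

$$ y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma + \epsilon}} + \beta $$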

During training mode, $\mu$ and $\sigma$ are calculated within each minibatch, while in test mode the accumulated $\mu$ and $\sigma$ are utilized to normalize the input. That is, train mode and test mode use different $\mu$ and $\sigma$ values. Actually, TensorFlow exploits an exponential moving average to calculate $\mu$ and $\sigma$ for test mode:
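The exponential moving average update takes the standard form (shown here for $\mu$; $\sigma$ is updated the same way, with $decay$ being the decay rate):

$$ \mu_{test} \leftarrow decay \cdot \mu_{test} + (1 - decay) \cdot \mu_{batch} $$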

The decay rate $decay$ is recommended to be near 1, such as 0.999, 0.997 and 0.90. This means $\mu$ and $\sigma$ for test mode are quite different from the $\mu$ and $\sigma$ values in train mode early in training, but the difference gradually fades as training progresses. Please note that $\mu$ and $\sigma$ are initialized as $0$ and $1$ respectively, which easily leads to nearly fixed $\mu$ and $\sigma$ values during the early training stage. This is why early-stage trained models predict all images as the same class.
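The effect of the decay rate is easy to see numerically. The following sketch (my own illustration, not TensorFlow code, with a hypothetical stable batch mean of 5.0) shows how slowly the moving statistics leave their initialization under a high decay rate:

```python
# With a high decay rate, the moving mean barely moves away from its
# initialization during early training, so test-mode statistics stay
# near the initial mu = 0 even when the true batch mean is far from it.

def ema(decay, true_value, steps, init=0.0):
    """Apply the moving-average update `steps` times with a constant batch statistic."""
    moving = init
    for _ in range(steps):
        moving = decay * moving + (1 - decay) * true_value
    return moving

true_mean = 5.0  # hypothetical stable batch mean

# After 100 steps with decay = 0.999, the moving mean has covered only
# about 1 - 0.999**100, roughly 9.5%, of the distance to the true mean.
print(ema(0.999, true_mean, 100))  # ≈ 0.48
# With decay = 0.9 it converges almost completely within the same budget.
print(ema(0.9, true_mean, 100))    # ≈ 5.0
```

This is exactly the mismatch described above: in early training, test mode normalizes with statistics that are still close to $(0, 1)$ rather than the true batch statistics.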

Solution

Two solutions are recommended here:

1. Set the decay rate to zero. This means $\mu$ and $\sigma$ in test mode use the values temporarily calculated within the current minibatch. Although it guarantees avoiding the aforementioned bug, it is not recommended here.
2. Set the decay rate to a relatively smaller value and train the whole model for more steps, as more training steps reduce the gap between the statistics used in train and test mode.
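As a sketch of the second solution, assuming the model is built with TF-slim (whose `batch_norm` layer exposes the decay rate as the `decay` argument), lowering the decay could look like the following configuration fragment; the surrounding wiring in the comments is hypothetical and depends on how your model is assembled:

```python
# Hypothetical TF-slim configuration sketch: lower the BatchNorm decay
# so the moving statistics catch up to the batch statistics faster.
is_training = True  # True in train mode, False in test mode

batch_norm_params = {
    'is_training': is_training,
    'decay': 0.9,        # smaller than the common default of 0.999
    'epsilon': 1e-5,
}

# Hypothetical wiring, assuming a slim-style model definition:
# with slim.arg_scope([slim.conv2d],
#                     normalizer_fn=slim.batch_norm,
#                     normalizer_params=batch_norm_params):
#     logits, end_points = mobilenet_v1(images, num_classes=num_classes)
```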