I have recently encountered a weird bug with TensorFlow. Here I record what the problem is and how to solve it.
Problem Description
I use a TensorFlow implementation of MobileNet V1 as an example to illustrate the problem. It is natural to set is_training = True in train mode and is_training = False in test mode.
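The snippet below is a minimal sketch of what such a setup typically looks like. It assumes the mobilenet_v1 module from the TF-Slim models repository, so the import path and exact function names are illustrative rather than the precise code I used.

```python
# Illustrative sketch only: assumes the TF-Slim models repository layout,
# where nets/mobilenet_v1.py provides mobilenet_v1() and mobilenet_v1_arg_scope().
import tensorflow as tf
from nets import mobilenet_v1

slim = tf.contrib.slim
images = tf.placeholder(tf.float32, [None, 224, 224, 3])

# Train mode: BatchNorm uses per-minibatch statistics and updates the
# moving mean/variance.
with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=True)):
    logits, end_points = mobilenet_v1.mobilenet_v1(
        images, num_classes=1001, is_training=True)

# Test mode: BatchNorm uses the accumulated moving mean/variance instead.
# with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False)):
#     logits, end_points = mobilenet_v1.mobilenet_v1(
#         images, num_classes=1001, is_training=False)
```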
However, if you follow the test mode (is_training = False) to test your trained model, sometimes you will get a weird prediction result: all predictions fall into the same class.
Pretty weird! Similar results have been encountered by multiple TensorFlowers. When I dug into the source code, I gradually found that the root cause lies in the BatchNorm layer. In my own experience, BatchNorm is the culprit behind many deep learning bugs.
Bug Traceback
In a nutshell, for an input and output pair $(x, y)$, the BatchNorm layer normalizes the input with a mean value $\mu$ and a variance value $\sigma^2$, plus a learnable scale parameter $\gamma$ and a shift parameter $\beta$:

$$y = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $\epsilon$ is a small constant for numerical stability.
During training, $\mu$ and $\sigma$ are calculated within each minibatch, while in test mode the accumulated $\mu$ and $\sigma$ are used to normalize the input. That is, train mode and test mode use different $\mu$ and $\sigma$ values. Specifically, TensorFlow uses an exponential moving average to calculate $\mu$ and $\sigma$ for test mode:

$$\mu_{test} \leftarrow decay \cdot \mu_{test} + (1 - decay) \cdot \mu_{batch}, \qquad \sigma_{test} \leftarrow decay \cdot \sigma_{test} + (1 - decay) \cdot \sigma_{batch}$$
The decay rate $decay$ is recommended to be near 1, such as 0.999, 0.997, or 0.90. This means that the $\mu$ and $\sigma$ used in test mode are quite different from the $\mu$ and $\sigma$ values in train mode during training, but the difference gradually fades as training proceeds. Please note that $\mu$ and $\sigma$ are initialized as $0$ and $1$ respectively, which easily leads to relatively fixed $\mu$ and $\sigma$ values during the early training stage. This is why models tested early in training predict all images as the same class.
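To see how slowly the moving statistics move away from their initial values, here is a tiny numerical sketch in plain Python; the per-minibatch mean of 5.0 is a made-up value purely for illustration.

```python
# Hypothetical numbers, only to illustrate the effect of a decay close to 1.
decay = 0.997
moving_mean = 0.0    # TensorFlow initializes the moving mean to 0
batch_mean = 5.0     # made-up per-minibatch mean

for step in range(100):
    # The exponential-moving-average update applied after every training step.
    moving_mean = decay * moving_mean + (1 - decay) * batch_mean

print(moving_mean)   # ~1.3 after 100 steps, still far from the batch mean of 5.0
```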
Solution
Two solutions are recommended here:
- Set the decay rate to zero. This means $\mu$ and $\sigma$ in test mode use the values calculated on the current minibatch. Although this guarantees avoiding the aforementioned bug, it is not recommended here.
- Set the decay rate to a relatively smaller value (e.g. 0.9) and train the whole model for more steps, since more training steps reduce the gap between the train-mode and test-mode statistics; see the sketch after this list.
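As a rough illustration of the second option, the snippet below shows one way to pass a smaller decay to slim.batch_norm through an arg_scope. The scope layout and the function name model_arg_scope are assumptions about how your model is built, not the exact code from this post.

```python
# Sketch under the assumption that the model is built with TF-Slim layers;
# model_arg_scope and the chosen values are illustrative.
import tensorflow as tf

slim = tf.contrib.slim

def model_arg_scope(is_training, batch_norm_decay=0.9):
    """Use a smaller BatchNorm decay so the moving statistics catch up faster."""
    batch_norm_params = {
        'is_training': is_training,
        'decay': batch_norm_decay,   # instead of the usual 0.997 / 0.999
        'epsilon': 1e-3,
    }
    with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                        normalizer_fn=slim.batch_norm,
                        normalizer_params=batch_norm_params) as sc:
        return sc
```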