Yuhang He's Blog

Some birds are not meant to be caged, their feathers are just too bright.

TensorFlow Function Currying in Slim NetFactory

Slim is widely used in TensorFlow. All networks, including ResNet, Inception and MobileNet, are wrapped by a net factory: nets_factory.py. A typical network call looks like:

import tensorflow as tf
import functools

slim = tf.contrib.slim

# networks_map and arg_scopes_map are module-level dicts defined in
# nets_factory.py, mapping each network name to its entry function and
# to its arg_scope function, respectively.

def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
  """Returns a network_fn such as `logits, end_points = network_fn(images)`.
  Args:
    name: The name of the network.
    num_classes: The number of classes to use for classification. If 0 or None,
      the logits layer is omitted and its input features are returned instead.
    weight_decay: The l2 coefficient for the model weights.
    is_training: `True` if the model is being used for training and `False`
      otherwise.
  Returns:
    network_fn: A function that applies the model to a batch of images. It has
      the following signature:
          net, end_points = network_fn(images)
      The `images` input is a tensor of shape [batch_size, height, width, 3]
      with height = width = network_fn.default_image_size. (The permissibility
      and treatment of other sizes depends on the network_fn.)
      The returned `end_points` are a dictionary of intermediate activations.
      The returned `net` is the topmost layer, depending on `num_classes`:
      If `num_classes` was a non-zero integer, `net` is a logits tensor
      of shape [batch_size, num_classes].
      If `num_classes` was 0 or `None`, `net` is a tensor with the input
      to the logits layer of shape [batch_size, 1, 1, num_features] or
      [batch_size, num_features]. Dropout has not been applied to this
      (even if the network's original classification does); it remains for
      the caller to do this or not.
  Raises:
    ValueError: If network `name` is not recognized.
  """
  if name not in networks_map:
    raise ValueError('Name of network unknown %s' % name)
  func = networks_map[name]
  @functools.wraps(func)
  def network_fn(images, **kwargs):
    arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
    with slim.arg_scope(arg_scope):
      return func(images, num_classes, is_training=is_training, **kwargs)
  if hasattr(func, 'default_image_size'):
    network_fn.default_image_size = func.default_image_size

  return network_fn

The entry function get_network_fn(name, num_classes, weight_decay=0.0, is_training=False) accepts only four arguments, and since name merely selects the network, only three of them are actually forwarded to it. Many networks, however, take far more. For example, resnet_v2 accepts ten:

def resnet_v2(inputs,
              blocks,
              num_classes=None,
              is_training=True,
              global_pool=True,
              output_stride=None,
              include_root_block=True,
              spatial_squeeze=True,
              reuse=None,
              scope=None):

A question naturally arises: how does a four-argument factory drive networks with so many parameters, and how does each network end up being called with the right arguments?

The mystery lies in function currying. Currying is the technique of breaking down a function that takes multiple arguments into a sequence of functions that each take fewer arguments. (Strictly speaking, what nets_factory does is partial application: a closure that pre-binds some arguments.) The Slim network factory exploits this to split the many arguments into two groups: prerequisite arguments shared by all network entry functions (num_classes, weight_decay, is_training) and accessory arguments that belong to each individual network.
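A minimal, framework-free sketch of the idea (the function names below are invented for illustration):

```python
import functools

# Classic currying: break a three-argument function into a chain of
# single-argument functions.
def add3(a):
    def add2(b):
        def add1(c):
            return a + b + c
        return add1
    return add2

# functools.partial achieves a similar effect by pre-binding arguments,
# which is closer to what nets_factory actually does.
def volume(length, width, height):
    return length * width * height

base_area_fixed = functools.partial(volume, 2, 3)  # length=2, width=3 bound

print(add3(1)(2)(3))       # 6
print(base_area_fixed(4))  # 24
```

In both forms, some arguments are fixed early and the remaining ones are supplied later, which is exactly how the factory pre-binds the shared arguments and defers the network-specific ones.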

Let’s again take a look at the wrapper function:

def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
  if name not in networks_map:
    raise ValueError('Name of network unknown %s' % name)
  func = networks_map[name]
  @functools.wraps(func)
  def network_fn(images, **kwargs):
    arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
    with slim.arg_scope(arg_scope):
      return func(images, num_classes, is_training=is_training, **kwargs)
  if hasattr(func, 'default_image_size'):
    network_fn.default_image_size = func.default_image_size

  return network_fn

The inner network_fn() is a closure: it captures name, num_classes, weight_decay and is_training from the enclosing call, while its **kwargs parameter collects any extra, network-specific arguments and forwards them to the underlying func. (The functools.wraps decorator merely copies func's name and docstring onto network_fn.) Note that network_fn takes images positionally, so a typical ResNet call looks like:

name = 'resnet_v2_50'
network_fn = nets_factory.get_network_fn(name, num_classes, weight_decay, is_training)
logits, end_points = network_fn(img_batch, global_pool=True, spatial_squeeze=True)
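The whole pattern, including the default_image_size hoisting, can be reproduced without TensorFlow. Everything below (toy_resnet, its signature, the sample values) is invented for illustration:

```python
import functools

# Toy stand-in for a real Slim network; network-specific args (global_pool)
# sit alongside the shared ones (num_classes, is_training).
def toy_resnet(images, num_classes, is_training=False, global_pool=True):
    return ('resnet', images, num_classes, is_training, global_pool)
toy_resnet.default_image_size = 224

networks_map = {'toy_resnet': toy_resnet}

def get_network_fn(name, num_classes, is_training=False):
    func = networks_map[name]
    @functools.wraps(func)
    def network_fn(images, **kwargs):
        # Shared arguments are captured from the enclosing scope;
        # network-specific extras flow through **kwargs untouched.
        return func(images, num_classes, is_training=is_training, **kwargs)
    # Hoist metadata so callers can size their input pipeline
    # before ever running the network.
    if hasattr(func, 'default_image_size'):
        network_fn.default_image_size = func.default_image_size
    return network_fn

network_fn = get_network_fn('toy_resnet', num_classes=10, is_training=True)
result = network_fn('img_batch', global_pool=False)
print(result)                         # ('resnet', 'img_batch', 10, True, False)
print(network_fn.default_image_size)  # 224
print(network_fn.__name__)            # toy_resnet (thanks to functools.wraps)
```

The closure pre-binds the shared arguments once, and every later call only needs to supply the images plus whatever extras that particular network understands.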