Distributed training with TensorFlow

When a large number of computational resources is available, we can put them to use with a suitable distributed strategy and significantly reduce the time spent on model training. For different usage scenarios, TensorFlow provides several distributed strategies under tf.distribute.Strategy that allow us to train models more efficiently.

Training on a single machine with multiple GPUs: MirroredStrategy

MirroredStrategy is a simple and high-performance, data-parallel, synchronous distributed strategy that supports training on multiple GPUs of the same machine. To use this strategy, we simply instantiate a MirroredStrategy strategy:

strategy = tf.distribute.MirroredStrategy()

and place the model construction code in the context of strategy.scope():

with strategy.scope():
    # Model construction code

Tip

You can specify the devices to use via the devices parameter, for example:

strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])

That is, only GPUs 0 and 1 are set to participate in the distributed strategy.
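
If you are unsure which GPUs the machine has, you can list the physical devices first (a quick check using tf.config.list_physical_devices):

import tensorflow as tf

# List the GPUs visible to TensorFlow,
# e.g. [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), ...]
gpus = tf.config.list_physical_devices('GPU')
print(gpus)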

The following code demonstrates using the MirroredStrategy strategy to train MobileNetV2 using Keras on some of the image datasets in TensorFlow Datasets.

import tensorflow as tf
import tensorflow_datasets as tfds

num_epochs = 5
batch_size_per_replica = 64
learning_rate = 0.001

strategy = tf.distribute.MirroredStrategy()
print('Number of devices: %d' % strategy.num_replicas_in_sync)  # Print the number of devices (replicas)
batch_size = batch_size_per_replica * strategy.num_replicas_in_sync

# Load the dataset and preprocess it
def resize(image, label):
    image = tf.image.resize(image, [224, 224]) / 255.0
    return image, label

# Load the cats_vs_dogs dataset with TensorFlow Datasets; see the "TensorFlow Datasets" chapter for details
dataset = tfds.load("cats_vs_dogs", split=tfds.Split.TRAIN, as_supervised=True)
dataset = dataset.map(resize).shuffle(1024).batch(batch_size)

with strategy.scope():
    model = tf.keras.applications.MobileNetV2(weights=None, classes=2)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[tf.keras.metrics.sparse_categorical_accuracy]
    )

model.fit(dataset, epochs=num_epochs)

In the following test, we used four NVIDIA GeForce GTX 1080 Ti graphics cards on the same machine for multi-GPU training. The number of epochs is 5 in all cases. When using a single machine with no distributed strategy, although the machine still has four graphics cards, the program trains directly without any distributed setting, with the batch size set to 64. When using the distributed training strategy, both a total batch size of 64 (a batch size of 16 per GPU) and a total batch size of 256 (a batch size of 64 per GPU) were tested.

Dataset      | No distributed strategy | Distributed training with 4 GPUs (total batch size 64) | Distributed training with 4 GPUs (total batch size 256)
cats_vs_dogs | 146s/epoch              | 39s/epoch                                               | 29s/epoch
tf_flowers   | 22s/epoch               | 7s/epoch                                                | 5s/epoch

It can be seen that the speed of model training has increased significantly with MirroredStrategy.

MirroredStrategy Process

The steps of MirroredStrategy are as follows.

  • The strategy replicates a complete model on each of the N computing devices before training begins.

  • Each time a batch of data is passed in for training, the data is divided into N copies and passed into N computing devices (i.e. data parallel).

  • N computing devices use local variables (mirror variables) to calculate the gradient of their data separately.

  • Apply all-reduce operations to efficiently exchange and sum gradient data between computing devices, so that each device eventually has the sum of all devices’ gradients.

  • Update local variables (mirror variables) using the results of gradient summation.

  • After all devices have updated their local variables, the next batch is trained (i.e., this parallel strategy is synchronous).

By default, the MirroredStrategy strategy in TensorFlow uses NVIDIA NCCL for All-reduce operations.
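
If NCCL is not available in your environment, a different all-reduce implementation can be selected through the cross_device_ops parameter. A minimal sketch using the built-in tf.distribute.HierarchicalCopyAllReduce:

# Use hierarchical copy instead of NCCL for cross-device gradient aggregation
strategy = tf.distribute.MirroredStrategy(
    cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()
)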

Training on multiple machines: MultiWorkerMirroredStrategy

Multi-machine distributed training in TensorFlow is similar to multi-GPU training in the previous section, just replacing MirroredStrategy with MultiWorkerMirroredStrategy. However, there are some additional settings that need to be made as communication between multiple computers is involved. Specifically, the environment variable TF_CONFIG needs to be set, for example:

import os
import json

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:20000", "localhost:20001"]
    },
    'task': {'type': 'worker', 'index': 0}
})

TF_CONFIG consists of two parts, cluster and task.

  • The cluster describes the structure of the entire multi-machine cluster and the network address (IP + port number) of each machine. The value of cluster is the same for each machine.

  • The task describes the role of the current machine. For example, {'type': 'worker', 'index': 0} indicates that the current machine is the 0th worker in the cluster (i.e. localhost:20000). The value of task needs to be set separately on each machine according to that machine's role (see the example below).
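
For example, the TF_CONFIG of the second machine in the example above (worker 1, i.e. localhost:20001) differs only in the task part:

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:20000", "localhost:20001"]  # identical on every machine
    },
    'task': {'type': 'worker', 'index': 1}  # this machine is worker 1
})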

Once the above is set up, run the training code on each machine one by one. The machines that start first will wait until they are connected to the others; once all machines are connected, training starts on all of them at the same time.

Hint

Please pay attention to the firewall settings on each machine, especially the need to open ports for communication with other machines. As in the example above, worker 0 needs to open port 20000 and worker 1 needs to open port 20001.

The training task in the following example is the same as in the previous section, except that it has been migrated to a multi-machine training environment. Suppose we have two machines; we first deploy the following program on both of them. The only difference is the task part: the first machine is set to {'type': 'worker', 'index': 0} and the second machine to {'type': 'worker', 'index': 1}. Next, run the programs on both machines; when communication is established, the training process begins automatically.
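
A minimal sketch of such a program follows, combining the TF_CONFIG setup above with the training code from the previous section (the localhost addresses are placeholders carried over from the earlier example and would be replaced with each machine's actual IP and port; on TensorFlow versions before 2.4, the strategy class is tf.distribute.experimental.MultiWorkerMirroredStrategy):

import json
import os

import tensorflow as tf
import tensorflow_datasets as tfds

num_epochs = 5
batch_size_per_replica = 64
learning_rate = 0.001
num_workers = 2

# TF_CONFIG must be set before the strategy is instantiated;
# change 'index' to 1 on the second machine
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:20000", "localhost:20001"]
    },
    'task': {'type': 'worker', 'index': 0}
})

strategy = tf.distribute.MultiWorkerMirroredStrategy()
batch_size = batch_size_per_replica * num_workers

# Load and preprocess the cats_vs_dogs dataset, as in the previous section
def resize(image, label):
    image = tf.image.resize(image, [224, 224]) / 255.0
    return image, label

dataset = tfds.load("cats_vs_dogs", split=tfds.Split.TRAIN, as_supervised=True)
dataset = dataset.map(resize).shuffle(1024).batch(batch_size)

with strategy.scope():
    model = tf.keras.applications.MobileNetV2(weights=None, classes=2)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[tf.keras.metrics.sparse_categorical_accuracy]
    )

model.fit(dataset, epochs=num_epochs)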

In the following tests, we build two separate virtual machine instances on Google Cloud Platform (see the appendix for the usage of GCP), each with a single NVIDIA Tesla K80, and report the training time with a single GPU as well as with two virtual machine instances in distributed training. The number of epochs is 5. The batch size is set to 64 when using a single machine with a single GPU; when using two single-GPU machines, both a total batch size of 64 (a batch size of 32 per machine) and a total batch size of 128 (a batch size of 64 per machine) were tested.

Dataset      | No distributed strategy | Distributed training with 2 machines (total batch size 64) | Distributed training with 2 machines (total batch size 128)
cats_vs_dogs | 1622s                   | 858s                                                        | 755s
tf_flowers   | 301s                    | 152s                                                        | 144s

It can be seen that the speed of model training has also increased considerably.