Joy

Reputation: 41

How to use Model Parallelism with a custom Tensorflow 2.0 model on TPUs?

To replicate Multimodal Few-Shot Learning with Frozen Language Models, I am trying to train a ~7B parameter subclassed TF2 model on a TPUv3-32. Out of the 7B parameters, roughly 6B parameters are frozen.

I want to use model and data parallelism to train it as efficiently as possible. As far as I know, MeshTensorflow can only be used for models written in TF1.

I tried using experimental_device_assignment from TPUStrategy, but it placed all of the variables on only the first (0th) TPU core, which quickly ran out of memory.

Using TPUStrategy on a TPUv3-8, I tried setting computation_shape = [2, 2, 1, 2] and [1, 1, 1, 2] with num_replicas = 1, but neither configuration worked (a sketch of the setup is below).
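Roughly what I tried, as a minimal sketch; the TPU name and the model class are placeholders:

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")  # placeholder name
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)

# One replica spanning all 8 cores of the v3-8
# (computation_shape is [x, y, z, cores_per_chip]).
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology,
    computation_shape=[2, 2, 1, 2],
    num_replicas=1)

strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)

with strategy.scope():
    model = MyFrozenModel()  # placeholder for the ~7B parameter subclassed model
```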

I am also open to using GPUs to train it.

Upvotes: 0

Views: 595

Answers (1)

Mike Holcomb

Reputation: 413

According to the Cloud TPU documentation, there is no official support:

Does Cloud TPU support model parallelism?

Model parallelism (or executing non-identical programs on the multiple cores within a single Cloud TPU device) is not currently supported.

https://cloud.google.com/tpu/docs/faq

The underlying issue may be that TPUStrategy does not automatically shard the computation graph, so the whole graph is placed on a single device unless, in the model code, you manually assign device placements for weights and operations to the logical devices created by DeviceAssignment.build and carefully handle communication across those devices.
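For illustration only, a rough sketch of what that manual assignment could look like with the experimental_assign_to_logical_device hook that more recent TF 2 releases expose on TPUStrategy; the TPU name and the two sub-model builders are placeholders, not a tested recipe:

```python
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")  # placeholder name
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)

# On a v3-8: 4 replicas, each spanning 2 cores, i.e. 2 logical devices per replica.
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, computation_shape=[1, 1, 1, 2], num_replicas=4)

strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)

with strategy.scope():
    frozen_lm = build_frozen_lm()            # placeholder: the frozen ~6B parameters
    trainable_head = build_trainable_head()  # placeholder: the trainable part

@tf.function
def train_step(inputs):
    def step_fn(x):
        # Pin the frozen language model's output to logical device 0 ...
        h = strategy.experimental_assign_to_logical_device(frozen_lm(x), 0)
        # ... and the trainable head's output to logical device 1.
        return strategy.experimental_assign_to_logical_device(trainable_head(h), 1)
    return strategy.run(step_fn, args=(inputs,))
```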

That said, there is another TF2-compatible library (also from Google) that could help if you are building a large Transformer and want layers that are friendly to graph sharding: Lingvo. Its GitHub repository includes an example of sharding a model across a TPU v3-512 node. The library also ships Google's open-sourced GPipe, which can speed up model-parallel training loops. Lingvo should work on GPUs as well.

Upvotes: 2
