Data Parallelism in JAX
This post introduces data parallelism in JAX. Specifically, we will look at parallelizing matrix operations across multiple devices by leveraging `pmap`. This post is based on Misha Laskin's data parallelism blog post.
Parallelism with pmap
JAX offers a range of utilities for distributing models and data across multiple devices. Perhaps the most commonly used is `pmap`. Like `vmap`, `pmap` maps a function over a leading array axis, but instead of vectorizing the computation on a single device, it compiles the function with XLA and runs one copy per device. In effect, `pmap` segregates a data matrix into smaller chunks along the mapped axis and sends each chunk to a different device for computation.
Following is an example of `pmap` usage for parallelizing matrix multiplication in a linear layer:
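The original listing does not appear here, so below is a minimal sketch consistent with the surrounding description. The names `xs` and `ws`, the reshape to `(devices, samples // devices, dims)`, and `in_axes=(0, None)` come from the text; the concrete sizes and the `XLA_FLAGS` trick that simulates 4 devices on a single CPU are assumptions added so the snippet runs anywhere:

```python
import os

# Assumption for illustration: simulate 4 XLA devices on one CPU so the
# example runs anywhere. On real multi-device hardware this is unnecessary.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

devices = jax.device_count()               # 4
samples, dims, hidden = 8, 3, 5            # illustrative sizes

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
xs = jax.random.normal(k1, (samples, dims))
ws = jax.random.normal(k2, (dims, hidden))

# Split the batch across devices: (devices, samples // devices, dims).
xs = xs.reshape(devices, samples // devices, dims)

# in_axes=(0, None): shard xs along its leading axis, replicate ws everywhere.
linear = jax.pmap(lambda x, w: x @ w, in_axes=(0, None))

out = linear(xs, ws)
print(out.shape)  # (4, 2, 5): one (2, 3) @ (3, 5) product per device
```

Each device computes an independent `(samples // devices, dims) @ (dims, hidden)` product, and `pmap` stacks the per-device results along a new leading axis.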
There are two things to note above. (1) We split and reshape the input into shape `(devices, samples // devices, dims)`, since `pmap` requires the size of the mapped leading dimension to be at most the number of available devices. (2) We use `in_axes=(0, None)` to map over the leading dimension of `xs` while leaving `ws` unmapped. This simultaneously distributes the chunks of `xs` across devices and replicates `ws` on each of them.
A Linear Model
Now that we have a good understanding of `pmap` and its usage, we use it to train a linear regression model on 3-dimensional inputs. We begin by defining our model and loss function. We then train the model using the update function, with the input samples distributed across devices.
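The training code itself is missing here, so the following is a minimal sketch consistent with the surrounding description: a `Params` tuple with `weight` of shape `(4, 3)` and `bias` of shape `(4,)`, one input sample per device, `axis_name="num_devices"`, and gradients averaged across devices with `jax.lax.pmean`. The hyperparameters, random data, and the CPU device-count simulation are assumptions:

```python
import os

# Assumption: simulate 4 XLA devices on one CPU so the sketch runs anywhere.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import functools
from typing import NamedTuple

import jax
import jax.numpy as jnp

class Params(NamedTuple):
    weight: jnp.ndarray  # (4, 3): maps a 3-dim input to a 4-dim output
    bias: jnp.ndarray    # (4,)

def model(params, x):
    return params.weight @ x + params.bias

def loss_fn(params, x, y):
    # Per-device shapes: x (3,), weight (4, 3), bias (4,), pred (4,), y (4,).
    pred = model(params, x)
    return jnp.mean((pred - y) ** 2)

# axis_name labels the mapped (device) axis so that collectives such as
# lax.pmean can refer to it; params is replicated via in_axes=None.
@functools.partial(jax.pmap, axis_name="num_devices", in_axes=(None, 0, 0))
def update(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients over devices so every replica takes the same step.
    grads = jax.lax.pmean(grads, axis_name="num_devices")
    lr = 0.05  # assumed learning rate
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

key = jax.random.PRNGKey(0)
kw, kx = jax.random.split(key)
params = Params(jax.random.normal(kw, (4, 3)), jnp.zeros(4))
xs = jax.random.normal(kx, (4, 3))   # 4 samples, one per device
true_w = jnp.arange(12.0).reshape(4, 3)
ys = xs @ true_w.T                   # targets, shape (4, 4)

for step in range(100):
    new_params, loss = update(params, xs, ys)
    # pmean makes all replicas identical; keep the copy from device 0.
    params = jax.tree_util.tree_map(lambda x: x[0], new_params)
```

Because `pmean` synchronizes the gradients, every device computes the same parameter update, and we can safely keep any single replica between steps.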
In the above example, we map the output of `pmap` across the leading dimension of `xs` as before. We pass `axis_name="num_devices"` to label the mapped (leading) axis; in principle any name can be used, since it is just a handle for referring to that axis in collective operations. Finally, we parallelize across `devices = 4`, matching the leading dimension of our input matrix.
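To see in isolation what `axis_name` provides, here is a small sketch (the CPU device-count simulation is an added assumption) where the name labels the device axis for a `lax.pmean` collective:

```python
import os

# Assumption: simulate 4 XLA devices on one CPU.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

# Any string works as axis_name; it labels the mapped (device) axis so that
# collectives inside the pmapped function can refer to it.
mean_over_devices = jax.pmap(
    lambda x: jax.lax.pmean(x, axis_name="num_devices"),
    axis_name="num_devices",
)

out = mean_over_devices(jnp.arange(4.0))  # one scalar per device
print(out)  # every device ends up holding the cross-device mean, 1.5
```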
Print statements inside the loss function show the shapes of the arrays on each device: `xs` has shape `(3,)`, `params.weight` has shape `(4, 3)`, `params.bias` has shape `(4,)`, `preds` has shape `(4,)`, and `ys` has shape `(4,)`. Since the original `xs` is split into 4 equal parts, `pmap` achieves data parallelism when updating the model.
What if we have a large dataset where the number of samples is greater than the number of devices? We would simply split and reshape our data as we did in the previous section, so that the leading dimension equals the device count.
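As a sketch (the sample count and the CPU device-count simulation are illustrative assumptions):

```python
import os

# Assumption: simulate 4 XLA devices on one CPU.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

devices = jax.device_count()
samples, dims = 32, 3  # many more samples than devices

xs = jax.random.normal(jax.random.PRNGKey(0), (samples, dims))

# Reshape so the leading axis matches the device count; each device then
# receives a contiguous chunk of samples // devices examples.
sharded = xs.reshape(devices, samples // devices, dims)
print(sharded.shape)  # (4, 8, 3)
```

If `samples` is not divisible by `devices`, the batch must be padded or truncated before reshaping.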