Data Parallelism in JAX

This post will introduce data parallelism in JAX. Specifically, we will look at parallelizing matrices across multiple devices by leveraging pmap. This post is based on Misha Laskin’s data parallelism blog post.

Parallelism with pmap

JAX offers a range of utilities for distributing models and data across multiple devices. Perhaps the most commonly used is pmap. Like vmap, pmap maps a function over a leading array axis, but instead of vectorizing the computation on a single device, it compiles the function with XLA and runs one copy per device. In effect, pmap segregates the data along the mapped axis into smaller chunks and sends each chunk to a different device for computation.

Below is an example that uses pmap to parallelize the matrix multiplication in a linear layer:

import jax
import jax.numpy as jnp
import numpy as np

xs = jnp.array(np.random.rand(16, 4))  # input of shape (16, 4)
ws = jnp.array(np.random.rand(4))      # weights of shape (4,)
devices = 8                            # number of available devices

x_parts = jnp.stack(jnp.split(xs, devices))  # input split into shape (8, 2, 4)
linear_layer = lambda x, w: jnp.dot(x, w)
out = jax.pmap(linear_layer, in_axes=(0, None))(x_parts, ws)
print(out.shape)  # output of shape (8, 2), i.e. (devices, samples // devices)

There are two things to note above. (1) We split and reshape the input into shape (devices, samples // devices, dims) because pmap requires the size of the mapped leading dimension to be at most the number of available devices. (2) We use in_axes=(0, None) to map over the leading dimension of xs while leaving ws unmapped. This single call therefore distributes the chunks of xs across devices and replicates ws on each of them.
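
The same (devices, samples // devices, dims) layout can also be produced with a single reshape instead of split followed by stack. As a minimal sketch, reusing xs, ws, devices, and linear_layer from the example above:

x_parts = xs.reshape(devices, -1, xs.shape[-1])               # (8, 2, 4), same layout as split + stack
out = jax.pmap(linear_layer, in_axes=(0, None))(x_parts, ws)  # (8, 2)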

A Linear Model

Now that we have a good understanding of pmap and its usage, we use it to train a linear regression model with 3-dimensional inputs. We begin by defining our model and loss function. We then update the model using the update function, with the input samples distributed across devices.

from typing import NamedTuple, Tuple
import functools
import jax
import jax.numpy as jnp
import numpy as np

LEARNING_RATE = 0.1

# class for storing model parameters
class Params(NamedTuple):
  weight: jnp.ndarray
  bias: jnp.ndarray
 
# function for initializing model parameters
def init(rng) -> Params:
  weights_key, bias_key = jax.random.split(rng)
  weight = jax.random.normal(weights_key, (3,))  # weight vector for 3-dimensional inputs
  bias = jax.random.normal(bias_key, ())         # scalar bias
  return Params(weight, bias)
 
# function for computing the MSE loss
def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
  pred = xs @ params.weight + params.bias
  print(xs.shape, params.weight.shape, params.bias.shape)
  print(pred.shape, ys.shape)
  return jnp.mean((pred - ys) ** 2)
 
# function for performing one SGD update step
@functools.partial(jax.pmap, in_axes=(None, 0, 0), axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
  loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
  grads = jax.lax.pmean(grads, axis_name='num_devices')
  loss = jax.lax.pmean(loss, axis_name='num_devices')
  new_params = jax.tree_util.tree_map(
      lambda param, g: param - g * LEARNING_RATE, params, grads)
  return new_params, loss

n = 4
d = 3
devices = jax.local_device_count()  # we assume 4 local devices
xs = jnp.array(np.random.rand(n, d))
ys = jnp.array(np.random.rand(n,))
rng = jax.random.PRNGKey(42)
params = init(rng)
params, loss = update(params, xs, ys)

In the above example, we again map over the leading dimension of xs, and this time also over ys, so each device receives one (x, y) pair while the parameters are replicated. The axis_name='num_devices' argument gives the mapped device axis a name so that collectives such as jax.lax.pmean can refer to it; in principle any string works, as long as it matches the name passed to the collective. Finally, we parallelize across devices = 4 because the leading dimension of our input matrix is 4.
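
Because the gradients and loss are averaged with jax.lax.pmean, every device applies the same update. As a quick check, building on the example above: pmap stacks its outputs along a leading device axis, so the returned parameters hold one identical copy per device.

print(params.weight.shape)  # (4, 3): one copy of the updated weights per device
print(loss.shape)           # (4,): the same averaged loss reported by each device
assert np.allclose(params.weight[0], params.weight[1])  # replicas agree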

The print statements inside the loss function show the per-device shapes of the arrays. On each device, xs has shape (3,), params.weight has shape (3,), params.bias is a scalar, and pred and ys are scalars, i.e. each device computes the loss for a single sample. Since the original xs of shape (4, 3) is split into 4 equal parts, we have achieved data parallelism with pmap when updating the model.

What if we have a large dataset where the number of samples is greater than the number of devices? We would then split and reshape our data as we did in the previous section, so that each device processes a batch of samples rather than a single one; see the sketch below.
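
As a minimal sketch, assuming a hypothetical dataset of 32 samples and the same 4 devices, we reshape both xs and ys to carry a leading device axis and reuse the same update function, so each device now handles a per-device batch of 8 samples:

n_big = 32  # hypothetical larger dataset
xs_big = jnp.array(np.random.rand(n_big, d)).reshape(devices, -1, d)  # (4, 8, 3)
ys_big = jnp.array(np.random.rand(n_big)).reshape(devices, -1)        # (4, 8)
params = init(rng)                             # fresh, unreplicated parameters
params, loss = update(params, xs_big, ys_big)  # each device computes gradients on 8 samples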