The Keras distribution API is a new interface designed to facilitate
distributed deep learning across a variety of backends like JAX,
TensorFlow and PyTorch. It introduces a suite of tools for data and
model parallelism, allowing deep learning models to scale efficiently
across multiple accelerators and hosts. Whether you are using GPUs or
TPUs, the API provides a streamlined approach to initializing
distributed environments, defining device meshes, and orchestrating the
layout of tensors across computational resources. Through classes like
DataParallel and ModelParallel, it abstracts away the complexity of
parallel computation, making it easier for developers to accelerate
their machine learning workflows.
The Keras distribution API provides a global programming model that allows developers to compose applications that operate on tensors in a global context (as if working with a single device) while automatically managing distribution across many devices. The API leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives through a procedure called single program, multiple data (SPMD) expansion.
By decoupling the application from sharding directives, the API enables running the same application on a single device, multiple devices, or even multiple clients, while preserving its global semantics.
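Here is a minimal base R sketch of the SPMD idea. It does not use Keras or JAX. One program is written once, each of 8 simulated devices runs it on its own shard of a global vector, and a final sum stands in for an all-reduce to recover the result a single device would compute. The names n_devices, shards and program are illustrative only. Real SPMD expansion is performed by the backend compiler, not by R code like this.

```r
# Hypothetical SPMD sketch in base R (not the Keras API): the same program
# runs on every shard of a global tensor, and a collective step (a sum,
# playing the role of an all-reduce) recovers the single-device result.
n_devices <- 8
x <- as.numeric(1:64)  # the "global" tensor

# Shard x evenly along its only axis: each of the 8 shards holds 8 elements.
shards <- split(x, rep(seq_len(n_devices), each = length(x) / n_devices))

# One program, written once and applied to whatever local data a device holds.
program <- function(local) sum(local^2)
partials <- vapply(shards, program, numeric(1))

# Combine the per-device partial results and compare with one device.
spmd_result <- sum(partials)
single_device_result <- program(x)
c(spmd = spmd_result, single_device = single_device_result)
```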
# This guide assumes there are 8 GPUs available for testing. If you don't have
# 8 GPUs available locally, you can set the following environment variables to
# make XLA initialize the CPU as 8 devices, to enable local testing.
Sys.setenv("CUDA_VISIBLE_DEVICES" = "")
Sys.setenv("XLA_FLAGS" = "--xla_force_host_platform_device_count=8")

library(keras3)

# This guide uses the JAX backend.
use_backend("jax")
jax <- reticulate::import("jax")

library(tfdatasets, exclude = "shape") # For dataset input.

DeviceMesh and TensorLayout

The keras$distribution$DeviceMesh class in Keras
distribution API represents a cluster of computational devices
configured for distributed computation. It aligns with similar concepts
in jax.sharding.Mesh
and tf.dtensor.Mesh,
where it’s used to map the physical devices to a logical mesh
structure.
The TensorLayout class then specifies how tensors are
distributed across the DeviceMesh, detailing the sharding
of tensors along specified axes that correspond to the names of the axes
in the DeviceMesh.
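The following sketch shows how a mesh and a layout fit together. It uses a plain integer matrix in place of a 2 x 4 DeviceMesh and computes which devices hold each shard of a tensor. It assumes two things: the devices fill the mesh in row-major order (as NumPy's reshape would), and a tensor is replicated along any mesh axis it is not sharded over. The helper devices_for_shard is hypothetical and is not part of Keras.

```r
# Hypothetical base R sketch (not the Keras API). Assumption: the 8 devices
# fill a 2 x 4 mesh in row-major order. Mesh rows index the "data" axis and
# mesh columns index the "model" axis.
mesh_ids <- matrix(0:7, nrow = 2, ncol = 4, byrow = TRUE)
# Row 1 holds devices 0-3, row 2 holds devices 4-7.

# Devices holding shard `i` of a tensor dimension that is sharded over mesh
# axis `axis`. The tensor is replicated over the other mesh axis, so every
# device along that slice keeps a copy of the shard.
devices_for_shard <- function(mesh_ids, axis, i) {
  switch(axis,
         data  = mesh_ids[i, ],
         model = mesh_ids[, i],
         stop("unknown mesh axis: ", axis))
}

# Layout (NULL, "model"): the tensor's 2nd dimension is split into 4 shards.
model_sharded <- lapply(1:4, function(j) {
  devices_for_shard(mesh_ids, "model", j)
})
# Layout ("data", NULL): the tensor's 1st dimension is split into 2 shards.
data_sharded <- lapply(1:2, function(i) {
  devices_for_shard(mesh_ids, "data", i)
})

str(model_sharded)
str(data_sharded)
```

Under these assumptions, a tensor with layout (NULL, "model") has its four column shards on devices {0, 4}, {1, 5}, {2, 6} and {3, 7}. A tensor with layout ("data", NULL) has its two row shards on devices 0-3 and 4-7.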
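You can find more detailed concept explainers in the TensorFlow DTensor guide.

# Retrieve the locally available devices (here, the 8 emulated CPU devices).
devices <- jax$devices()
str(devices)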
## List of 8
##  $ :TFRT_CPU_0
##  $ :TFRT_CPU_1
##  $ :TFRT_CPU_2
##  $ :TFRT_CPU_3
##  $ :TFRT_CPU_4
##  $ :TFRT_CPU_5
##  $ :TFRT_CPU_6
##  $ :TFRT_CPU_7

# Define a 2x4 device mesh with data and model parallel axes
mesh <- keras$distribution$DeviceMesh(
  shape = shape(2, 4),
  axis_names = list("data", "model"),
  devices = devices
)
# A 2D layout, which describes how a tensor is distributed across the
# mesh. The layout can be visualized as a 2D grid with "model" as rows and
# "data" as columns, and it is a [4, 2] grid when it is mapped to the physical
# devices on the mesh.
layout_2d <- keras$distribution$TensorLayout(
  axes = c("model", "data"),
  device_mesh = mesh
)
# A 4D layout which could be used for data parallelism of an image input.
replicated_layout_4d <- keras$distribution$TensorLayout(
  axes = list("data", NULL, NULL, NULL),
  device_mesh = mesh
)

The Distribution class in Keras serves as a foundational
abstract class designed for developing custom distribution strategies.
It encapsulates the core logic needed to distribute a model’s variables,
input data, and intermediate computations across a device mesh. As an
end user, you won’t have to interact directly with this class; instead,
you will use its subclasses, such as DataParallel or ModelParallel.
The DataParallel class in the Keras distribution API is
designed for the data parallelism strategy in distributed training,
where the model weights are replicated across all devices in the
DeviceMesh, and each device processes a portion of the
input data.
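A small numerical check shows why no manual aggregation is needed. This base R sketch is not Keras code. It splits a 16-example batch evenly across 8 simulated devices. On each shard, it computes the mean-squared-error gradient of a linear model using the same replicated weights, then averages the per-device gradients. The shards are equal in size, so the average equals the gradient on the full batch. The linear model and the mse_grad helper are assumptions chosen for illustration.

```r
# Hypothetical sketch (base R, not the Keras implementation): with equal-sized
# shards, averaging per-device gradients of a mean loss reproduces the
# gradient computed on the full batch.
set.seed(1)
n_devices <- 8
x <- matrix(rnorm(16 * 3), nrow = 16)  # global batch: 16 examples, 3 features
y <- rnorm(16)
w <- c(0.5, -1, 2)                     # weights replicated on every device

# Gradient of mean squared error for a linear model y_hat = x %*% w.
mse_grad <- function(x, y, w) {
  2 * drop(crossprod(x, x %*% w - y)) / nrow(x)
}

# Split the batch dimension evenly: each device gets 16 / 8 = 2 rows.
device_of_row <- rep(seq_len(n_devices), each = nrow(x) / n_devices)
per_device_grads <- lapply(seq_len(n_devices), function(d) {
  rows <- device_of_row == d
  mse_grad(x[rows, , drop = FALSE], y[rows], w)
})

# Mean across devices (an all-reduce), compared with the global gradient.
averaged_grad <- Reduce(`+`, per_device_grads) / n_devices
global_grad <- mse_grad(x, y, w)
all.equal(averaged_grad, global_grad)
```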
Here is a sample usage of this class.
# Create DataParallel with a list of devices.
# As a shortcut, the devices can be omitted,
# and Keras will detect all locally available devices.
# E.g. data_parallel <- keras$distribution$DataParallel()
data_parallel <- keras$distribution$DataParallel(devices = devices)
# Or you can choose to create DataParallel with a 1D `DeviceMesh`.
mesh_1d <- keras$distribution$DeviceMesh(
  shape = shape(8),
  axis_names = list("data"),
  devices = devices
)
data_parallel <- keras$distribution$DataParallel(device_mesh = mesh_1d)
inputs <- random_normal(c(128, 28, 28, 1))
labels <- random_normal(c(128, 10))
dataset <- tensor_slices_dataset(c(inputs, labels)) |>
  dataset_batch(16)
# Set the global distribution.
keras$distribution$set_distribution(data_parallel)
# Note that all the model weights from here on are replicated to
# all the devices of the `DeviceMesh`. This includes the RNG
# state, optimizer states, metrics, etc. The dataset fed into `model |> fit()` or
# `model |> evaluate()` will be split evenly on the batch dimension, and sent to
# all the devices. You don't have to do any manual aggregation of losses,
# since all the computation happens in a global context.
inputs <- keras_input(shape = c(28, 28, 1))
outputs <- inputs |>
  layer_flatten() |>
  layer_dense(units = 200, use_bias = FALSE, activation = "relu") |>
  layer_dropout(0.4) |>
  layer_dense(units = 10, activation = "softmax")
model <- keras_model(inputs = inputs, outputs = outputs)
model |> compile(loss = "mse")
model |> fit(dataset, epochs = 3)
## Epoch 1/3
## 8/8 - 0s - 40ms/step - loss: 1.1536
## Epoch 2/3
## 8/8 - 0s - 5ms/step - loss: 1.0540
## Epoch 3/3
## 8/8 - 0s - 6ms/step - loss: 1.0072
model |> evaluate(dataset)
## 8/8 - 0s - 9ms/step - loss: 0.9620
## $loss
## [1] 0.9620273

ModelParallel and LayoutMap

ModelParallel will be mostly useful when model weights
are too large to fit on a single accelerator. This setting allows you to
split your model weights or activation tensors across all the devices on
the DeviceMesh, enabling horizontal scaling for
large models.
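Splitting the weights still gives the same answer. To see why, consider a dense layer whose kernel is sharded by columns over a 4-way "model" axis. In this base R sketch, each simulated device multiplies the full input by its own 50-column slice of a 784 x 200 kernel. Concatenating the partial outputs along the last axis reproduces the unsharded matmul. This is an illustration, not how Keras or XLA implement it. The shapes mirror the d1 layer used later in this guide, and the variable names are hypothetical.

```r
# Hypothetical sketch (base R, not the Keras implementation): shard a dense
# kernel by columns over a 4-way "model" axis. Each device multiplies the
# full input by its own column slice; concatenating the partial outputs
# matches the unsharded layer.
set.seed(42)
n_model <- 4
x <- matrix(rnorm(16 * 784), nrow = 16)         # 16 flattened 28x28 images
kernel <- matrix(rnorm(784 * 200), nrow = 784)  # a 784 x 200 dense kernel

# Each device stores only 200 / 4 = 50 columns of the kernel.
col_groups <- split(seq_len(ncol(kernel)),
                    rep(seq_len(n_model), each = ncol(kernel) / n_model))
kernel_shards <- lapply(col_groups, function(cols) {
  kernel[, cols, drop = FALSE]
})

# Local matmuls on each device, then a gather along the last axis.
partial_outputs <- lapply(kernel_shards, function(k) x %*% k)
sharded_result <- do.call(cbind, unname(partial_outputs))
full_result <- x %*% kernel
all.equal(sharded_result, full_result)
```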
Unlike DataParallel, where all weights are
fully replicated, the weight layout under ModelParallel
usually needs some customization for best performance. We introduce
LayoutMap to let you specify the TensorLayout
for any weights and intermediate tensors from a global perspective.
LayoutMap is a dict-like object that maps a string to
TensorLayout instances. It behaves differently from a
normal dict in that the string key is treated as a regex when retrieving
the value. The class allows you to define the naming schema of
TensorLayout and then retrieve the corresponding
TensorLayout instance. Typically, the key used to query is
the variable$path attribute, which is the identifier of the
variable. As a shortcut, a list of axis names is also allowed when
inserting a value, and it will be converted to
TensorLayout.
The LayoutMap can also optionally contain a
DeviceMesh to populate the
TensorLayout$device_mesh if it is not set. When retrieving
a layout with a key, and if there isn’t an exact match, all existing
keys in the layout map will be treated as regex and matched against the
input key again. If there are multiple matches, a
ValueError is raised. If no matches are found,
NULL is returned.
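The lookup rules above fit in a few lines of base R. The sketch below is a hypothetical re-implementation on a named R list, not the Keras source. Layouts are stored as character vectors of axis names, with NA meaning not sharded. An exact key match wins. Otherwise, each stored key is treated as a regular expression. The sketch assumes unanchored matching (as grepl() does), so a pattern may match anywhere in the variable path.

```r
# Hypothetical sketch of LayoutMap-style lookup (not the Keras source).
# Layouts are plain character vectors of axis names; NA means "not sharded".
lookup_layout <- function(layouts, key) {
  # 1. An exact key match wins.
  if (key %in% names(layouts)) return(layouts[[key]])
  # 2. Otherwise treat every stored key as a regex and search the query key.
  is_hit <- vapply(names(layouts), function(pattern) grepl(pattern, key),
                   logical(1))
  hits <- names(layouts)[is_hit]
  # 3. More than one match is ambiguous, so signal an error.
  if (length(hits) > 1) {
    stop("multiple layouts match '", key, "': ", paste(hits, collapse = ", "))
  }
  # 4. No match at all returns NULL.
  if (length(hits) == 0) return(NULL)
  layouts[[hits]]
}

toy_layouts <- list(
  "d1/kernel" = c(NA, "model"),
  "d1/bias"   = "model",
  ".*kernel"  = c(NA, NA)
)

lookup_layout(toy_layouts, "d1/kernel")  # exact match
lookup_layout(toy_layouts, "d2/kernel")  # regex match on ".*kernel"
lookup_layout(toy_layouts, "d2/bias")    # no match: NULL
```

With this toy map, the four lookups behave as follows:

- "d1/kernel" matches exactly.
- "d2/kernel" falls through to the ".*kernel" pattern.
- "d2/bias" matches nothing and returns NULL.
- "model/d1/kernel" matches two patterns and raises an error.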
mesh_2d <- keras$distribution$DeviceMesh(
  shape = shape(2, 4),
  axis_names = c("data", "model"),
  devices = devices
)
layout_map <- keras$distribution$LayoutMap(mesh_2d)
# The rules below shard any weights matching d1/kernel along the "model"
# mesh axis (4 devices), and likewise for d1/bias. (In the model below, d1 is
# built with use_bias = FALSE, so the d1/bias rule has no effect.)
# All other weights will be fully replicated.
layout_map["d1/kernel"] <- tuple(NULL, "model")
layout_map["d1/bias"] <- tuple("model")
# You can also set the layout for a layer's output:
layout_map["d2/output"] <- tuple("data", NULL)
model_parallel <- keras$distribution$ModelParallel(
  layout_map = layout_map, batch_dim_name = "data"
)
keras$distribution$set_distribution(model_parallel)
inputs <- keras_input(shape = c(28, 28, 1))
outputs <- inputs |>
  layer_flatten() |>
  layer_dense(units = 200, use_bias = FALSE,
              activation = "relu", name = "d1") |>
  layer_dropout(0.4) |>
  layer_dense(units = 10,
              activation = "softmax",
              name = "d2")
model <- keras_model(inputs = inputs, outputs = outputs)

We can visualize how individual weights will be sharded:

for (weight in model$weights) {
  jax$debug$visualize_array_sharding(weight$value)
}
## ┌───────┬───────┬───────┬───────┐
## │       │       │       │       │
## │       │       │       │       │
## │       │       │       │       │
## │       │       │       │       │
## │CPU 0,4│CPU 1,5│CPU 2,6│CPU 3,7│
## │       │       │       │       │
## │       │       │       │       │
## │       │       │       │       │
## │       │       │       │       │
## └───────┴───────┴───────┴───────┘

## ┌───────────────────┐
## │                   │
## │                   │
## │                   │
## │                   │
## │CPU 0,1,2,3,4,5,6,7│
## │                   │
## │                   │
## │                   │
## │                   │
## └───────────────────┘

## ┌───────────────────┐
## │CPU 0,1,2,3,4,5,6,7│
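## └───────────────────┘

# Take the first batch of inputs and check how the model output is sharded.
# (The `_[[1]]` placeholder extraction in a native pipe requires R >= 4.3.)
x_batch <- dataset |>
  as_iterator() |> iter_next() |>
  _[[1]] |> op_convert_to_tensor()
output_array <- model(x_batch)
output_array |> jax$debug$visualize_array_sharding()
## ┌─────────────┐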
## │             │
## │ CPU 0,1,2,3 │
## │             │
## │             │
## ├─────────────┤
## │             │
## │ CPU 4,5,6,7 │
## │             │
## │             │
## └─────────────┘

# The data will be sharded across the "data" dimension of the mesh, which
# has 2 devices.
model |> compile(loss = "mse")
model |> fit(dataset, epochs = 3)
## Epoch 1/3
## 8/8 - 0s - 46ms/step - loss: 1.1676
## Epoch 2/3
## 8/8 - 0s - 4ms/step - loss: 1.1134
## Epoch 3/3
## 8/8 - 0s - 5ms/step - loss: 1.1034
model |> evaluate(dataset)
## 8/8 - 0s - 8ms/step - loss: 1.0676
## $loss
## [1] 1.067567

It is also easy to change the mesh structure to shift the balance between
data parallelism and model parallelism. You can do this by adjusting the
shape of the mesh; no changes are needed in any other code.
full_data_parallel_mesh <- keras$distribution$DeviceMesh(
  shape = shape(8, 1),
  axis_names = list("data", "model"),
  devices = devices
)
more_data_parallel_mesh <- keras$distribution$DeviceMesh(
  shape = shape(4, 2),
  axis_names = list("data", "model"),
  devices = devices
)
more_model_parallel_mesh <- keras$distribution$DeviceMesh(
  shape = shape(2, 4),
  axis_names = list("data", "model"),
  devices = devices
)
full_model_parallel_mesh <- keras$distribution$DeviceMesh(
  shape = shape(1, 8),
  axis_names = list("data", "model"),
  devices = devices
)
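A little arithmetic shows what each of these meshes trades off. The base R sketch below makes three assumptions: the same 16-example batches, the (NULL, "model") layout for the 784 x 200 d1 kernel used above, and no accounting for replicated weights, optimizer state or activations. For each mesh, it reports how many examples each device processes and how many kernel columns each device stores. It also assumes the batch divides evenly by the size of the "data" axis.

```r
# Hypothetical per-device workload for each candidate (data, model) mesh.
mesh_shapes <- list(c(8, 1), c(4, 2), c(2, 4), c(1, 8))
batch_size <- 16    # examples per global batch, as in the dataset above
kernel_cols <- 200  # output columns of the d1 kernel

mesh_tradeoffs <- do.call(rbind, lapply(mesh_shapes, function(s) {
  data.frame(
    data = s[1],
    model = s[2],
    examples_per_device = batch_size / s[1],     # split on the "data" axis
    kernel_cols_per_device = kernel_cols / s[2]  # split on the "model" axis
  )
}))
mesh_tradeoffs
```

Moving from the full data parallel mesh toward the full model parallel mesh, each device processes more examples per batch but stores a smaller slice of the sharded kernel.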