Distributed Training with Uneven Inputs Using the Join Context Manager¶
Author: Andrew Gu
Note
Join is introduced in PyTorch 1.10 as a prototype feature. This API is subject to change.
In this tutorial, you will see:
- An overview of the Join context manager.
- An example of how to use the context manager with DistributedDataParallel.
- An example of how to use the context manager with both DistributedDataParallel and ZeroRedundancyOptimizer.
- An example of passing in keyword arguments to the context manager.
- A dive into how the Join context manager works.
- An example showing how to make a toy class compatible with the context manager.
Requirements¶
What is Join?¶
In Getting Started with Distributed Data Parallel - Basic Use Case, you saw the general skeleton for using DistributedDataParallel to perform data parallel training. This implicitly schedules all-reduces in each backward pass to synchronize gradients across ranks. Such collective communications require participation from all ranks in the process group, so if a rank has fewer inputs, then the other ranks will hang or error (depending on the backend). More generally, this problem persists for any class that performs per-iteration synchronous collective communications.
Join is a context manager to be used around your per-rank training loop to facilitate training with uneven inputs. The context manager allows the ranks that exhaust their inputs early (i.e. join early) to shadow the collective communications performed by those that have not yet joined. The ways in which the communications are shadowed are specified by hooks.
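To make the problem concrete, here is a minimal pure-Python sketch (no PyTorch involved) that counts how many ranks enter the collective in each training round. The function name participants_per_round and the flag shadow_after_join are hypothetical and exist only for this illustration; the model assumes exactly one collective per round.

```python
def participants_per_round(inputs_per_rank, shadow_after_join):
    """Count how many ranks enter the collective in each training round.

    inputs_per_rank: number of inputs held by each rank (index = rank).
    shadow_after_join: if True, ranks that ran out of inputs still enter
        the collective, which is what shadowing via hooks achieves.
    """
    world_size = len(inputs_per_rank)
    counts = []
    for round_idx in range(max(inputs_per_rank)):
        # Ranks that still have an input for this round.
        active = sum(1 for n in inputs_per_rank if round_idx < n)
        counts.append(world_size if shadow_after_join else active)
    return counts


inputs_per_rank = [5, 6]  # rank 0 has 5 inputs, rank 1 has 6
world_size = len(inputs_per_rank)

without_join = participants_per_round(inputs_per_rank, shadow_after_join=False)
with_join = participants_per_round(inputs_per_rank, shadow_after_join=True)

# Rounds (0-based) in which a collective would be missing a participant.
stalled_rounds = [i for i, c in enumerate(without_join) if c < world_size]

print(without_join)    # [2, 2, 2, 2, 2, 1]
print(with_join)       # [2, 2, 2, 2, 2, 2]
print(stalled_rounds)  # [5]
```

In this model, the sixth round is the one where rank 1 would wait on a collective that rank 0 never enters. With shadowing, every round has a full set of participants.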
Using Join with DistributedDataParallel¶
PyTorch’s DistributedDataParallel works out-of-the-box with the Join context manager. Here is an example usage:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()
This produces the following output (where the print()s from rank 0 and rank 1 may be arbitrarily ordered):
Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!
Note
DistributedDataParallel provided its own join() context manager prior to the introduction of this generic Join context manager. In the above example, using with Join([model]): is equivalent to using with model.join():. One limitation of the existing DistributedDataParallel.join() is that it does not allow multiple participating classes, e.g. DistributedDataParallel and ZeroRedundancyOptimizer together.
Using Join with DistributedDataParallel and ZeroRedundancyOptimizer¶
The Join context manager works not only with a single class but also with multiple classes together. PyTorch’s ZeroRedundancyOptimizer is also compatible with the context manager, so here, we examine how to modify the previous example to use both DistributedDataParallel and ZeroRedundancyOptimizer:
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
This will yield the same output as before. The notable change was additionally passing in the ZeroRedundancyOptimizer instance into Join().
Passing Keyword Arguments¶
Classes may provide keyword arguments that modify their behavior in the context manager at run time. For example, DistributedDataParallel provides an argument divide_by_initial_world_size, which determines if gradients are divided by the initial world size or by the effective world size (i.e. number of non-joined ranks). Such keyword arguments can be passed directly into the context manager.
with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ...
Warning
The keyword arguments passed into the context manager are shared across all participating classes. This should not be a limitation since we do not expect cases where multiple Joinables need differing settings of the same argument. Nonetheless, this is something to keep in mind.
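The effect of divide_by_initial_world_size can be illustrated with plain arithmetic. The sketch below is a simplified model, not DistributedDataParallel's implementation: it assumes one scalar gradient per rank and that joined ranks contribute zeros to the all-reduce. The helper averaged_gradient and its parameters are hypothetical names used only here.

```python
def averaged_gradient(local_grads, has_input, divide_by_initial_world_size):
    """Model one gradient all-reduce where joined ranks contribute zeros.

    local_grads: the gradient each rank would compute (index = rank).
    has_input: whether each rank still has an input this iteration.
    """
    initial_world_size = len(local_grads)
    effective_world_size = sum(1 for active in has_input if active)
    if effective_world_size == 0:
        raise ValueError("at least one rank must still have inputs")
    # Joined ranks shadow the all-reduce with a zero contribution.
    total = sum(g if active else 0.0 for g, active in zip(local_grads, has_input))
    if divide_by_initial_world_size:
        return total / initial_world_size
    return total / effective_world_size


# Rank 0 has already joined; only rank 1 still has an input.
local_grads = [2.0, 4.0]
has_input = [False, True]

by_initial = averaged_gradient(local_grads, has_input, divide_by_initial_world_size=True)
by_effective = averaged_gradient(local_grads, has_input, divide_by_initial_world_size=False)

print(by_initial)    # 2.0 (4.0 divided by the initial world size of 2)
print(by_effective)  # 4.0 (4.0 divided by the 1 non-joined rank)
```

Dividing by the initial world size shrinks the update in iterations where some ranks have joined, while dividing by the effective world size keeps the average over only the ranks that actually contributed.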
How Does Join Work?¶
Now that we have seen some preliminary examples of how to use the Join context manager, let us delve deeper into how it works. This will provide a greater insight into the full capability that it offers and prepare you to make your own custom classes compatible. Here, we will go over the Join class as well as the supporting classes Joinable and JoinHook.
Joinable¶
To begin, classes compatible with the Join context manager must inherit from the abstract base class Joinable. In particular, a Joinable must implement:
join_hook(self, **kwargs) -> JoinHook
This returns the JoinHook instance for the Joinable, determining how joined processes should shadow the per-iteration collective communications performed by the Joinable.
join_device(self) -> torch.device
This returns a device to be used by the Join context manager to perform collective communications, e.g. torch.device("cuda:0") or torch.device("cpu").
join_process_group(self) -> ProcessGroup
This returns the process group to be used by the Join context manager to perform collective communications.
In particular, the join_device and join_process_group are required attributes to ensure that the context manager can schedule collective communications between joined and non-joined processes. One usage is to count the number of non-joined processes on each iteration using an all-reduce. Another usage is for implementing the mechanism required for throw_on_early_termination=True, which we will explain later below.
DistributedDataParallel and ZeroRedundancyOptimizer already inherit from Joinable and implement the above methods, which is why we could directly use them in the previous examples.
Joinable classes should make sure to call the Joinable constructor since it initializes a JoinConfig instance, which is used internally by the context manager to ensure correctness. This will be saved in each Joinable as a field _join_config.
JoinHook¶
Next, let us break down the JoinHook class. A JoinHook provides two entry points into a context manager:
main_hook(self) -> None
This hook is called repeatedly by each joined rank while there exists a rank that has not yet joined. It is meant to shadow the collective communications performed by the Joinable in each training iteration (e.g. in one forward pass, backward pass, and optimizer step).
post_hook(self, is_last_joiner: bool) -> None
This hook is called once all ranks have joined. It is passed an additional bool argument is_last_joiner, which indicates if the rank was one of the last to join. The argument may be useful for synchronization.
To give concrete examples of what these hooks may look like, the provided ZeroRedundancyOptimizer main hook performs an optimizer step per normal since the joined rank is still responsible for updating and synchronizing its shard of the parameters, and the provided DistributedDataParallel post-hook broadcasts the final updated model from one of the last joining ranks to ensure that it is the same across all ranks.
Join¶
Finally, let us examine how these fit into the Join class itself.
__init__(self, joinables: List[Joinable], enable: bool = True, throw_on_early_termination: bool = False)
As we saw in the previous examples, the constructor takes in a list of the Joinables that participate in the training loop. These should be the classes that perform collective communications in each iteration.
enable is a bool that can be set to False if you know that there will not be uneven inputs, in which case the context manager becomes vacuous similar to contextlib.nullcontext(). This also may disable join-related computation in the participating Joinables.
throw_on_early_termination is a bool that can be set to True to have each rank raise an exception the moment that uneven inputs are detected. This is useful for cases that do not conform to the context manager’s requirements, which is most typically when there are collective communications from different classes that may be arbitrarily interleaved, such as when using DistributedDataParallel with a model that has SyncBatchNorm layers. In such cases, this argument should be set to True so that the application logic can catch the exception and determine how to proceed.
- The core logic occurs in the __exit__() method, which loops while there exists a non-joined rank, calling each Joinable’s main hook, and then once all ranks have joined, calls their post-hooks. Both the main hooks and post-hooks are iterated over in the order that the Joinables are passed in.
- The context manager requires a heartbeat from non-joined processes. As such, each Joinable class should make a call to Join.notify_join_context() before its per-iteration collective communications. The context manager will ensure that only the first Joinable passed in actually sends the heartbeat.
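The hook ordering described above can be sketched in plain Python. This is a simplified single-process model, not the actual Join implementation: the real context manager discovers whether any rank is still non-joined through collective communication, whereas here the number of remaining rounds is passed in directly. RecordingHook and run_exit_logic are hypothetical names for this illustration.

```python
class RecordingHook:
    """Hypothetical stand-in for a JoinHook that only records its calls."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def main_hook(self):
        self.log.append(f"{self.name}.main_hook")

    def post_hook(self, is_last_joiner):
        self.log.append(f"{self.name}.post_hook(is_last_joiner={is_last_joiner})")


def run_exit_logic(hooks, rounds_until_all_joined, is_last_joiner):
    """Model what one joined rank does after leaving its training loop.

    rounds_until_all_joined: assumed number of iterations that other ranks
        still run; the real context manager learns this one round at a time.
    """
    for _ in range(rounds_until_all_joined):
        # One main hook per Joinable, in the order the Joinables were passed.
        for hook in hooks:
            hook.main_hook()
    # Once every rank has joined, run each post-hook in the same order.
    for hook in hooks:
        hook.post_hook(is_last_joiner)


log = []
hooks = [RecordingHook("ddp", log), RecordingHook("zero", log)]
# Rank 0 in the earlier examples: one extra round remains on rank 1.
run_exit_logic(hooks, rounds_until_all_joined=1, is_last_joiner=False)

for entry in log:
    print(entry)
```

Each hook runs to completion before the next one starts, which is why the hooks of different Joinables cannot overlap in this scheme.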
Warning
As mentioned above regarding throw_on_early_termination, the Join context manager is not compatible with certain compositions of classes. The Joinable’s JoinHooks must be serializable since each hook is fully executed before proceeding to the next. In other words, two hooks cannot overlap. Moreover, currently, both the main hooks and post-hooks are iterated over in the same deterministic order. If this appears to be a major limitation, we may modify the API to permit a customizable ordering.
Making a Toy Class Work with Join¶
Since the previous section introduced several concepts, let us see them in practice with a toy example. Here, we will implement a class that counts the number of inputs that are seen across all ranks before its rank joins. This should provide a basic idea of how you may make your own class compatible with the Join context manager.
Specifically, the following code has each rank print out (1) the number of inputs across all ranks that are seen before it joins and (2) the total number of inputs across all ranks.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()
Since rank 0 sees 5 inputs and rank 1 sees 6, this yields the output:
10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!
Some key points to highlight:
- A Counter instance performs a single all-reduce per iteration, so the main hook performs a single all-reduce as well to shadow it.
- The Counter class makes a call to Join.notify_join_context() at the beginning of its __call__() method since that is a place before its per-iteration collective communications (i.e. its all-reduce).
- The is_last_joiner argument is used to determine the broadcast source in the post-hooks.
- We pass in the sync_max_count keyword argument to the context manager, which is then forwarded to Counter’s join hook.
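To see where the numbers 10 and 11 in the output come from, the counting logic can be replayed without any distributed setup. The following is a minimal pure-Python sketch under the assumption that each all-reduce simply sums one contribution per rank (1 from a non-joined rank, 0 from a joined rank shadowing it); simulate_counter is a hypothetical helper, not part of the tutorial's code.

```python
def simulate_counter(inputs_per_rank):
    """Replay the toy Counter's arithmetic for the given per-rank input counts.

    Returns (counts, max_count): the final count on each rank and the value
    that would be broadcast when sync_max_count=True.
    """
    world_size = len(inputs_per_rank)
    counts = [0] * world_size
    for round_idx in range(max(inputs_per_rank)):
        # Non-joined ranks all-reduce a 1; joined ranks shadow with a 0.
        reduced = sum(1 if round_idx < n else 0 for n in inputs_per_rank)
        for rank, n in enumerate(inputs_per_rank):
            if round_idx < n:
                # Only ranks still inside their loop update their count.
                counts[rank] += reduced
    # Post-hook: the broadcast source is the highest rank among last joiners.
    most_inputs = max(inputs_per_rank)
    common_rank = max(r for r, n in enumerate(inputs_per_rank) if n == most_inputs)
    return counts, counts[common_rank]


counts, max_count = simulate_counter([5, 6])  # NUM_INPUTS + rank for 2 ranks
print(counts)     # [10, 11]
print(max_count)  # 11
```

For the first five rounds both ranks contribute, so each count grows by 2 per round. In the sixth round only rank 1 contributes, which brings its count to 11 while rank 0 stays at 10.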