Customize Process Group Backends Using Cpp Extensions¶
Note
View the source code for this tutorial on GitHub.
Prerequisites:
- PyTorch Distributed Overview
- PyTorch Collective Communication Package
- PyTorch Cpp Extension
- Writing Distributed Applications with PyTorch
This tutorial demonstrates how to implement a custom ProcessGroup
backend and plug it into the PyTorch distributed package using
cpp extensions. This is helpful when you need a specialized software
stack for your hardware, or when you would like to experiment with new
collective communication algorithms.
Basics¶
PyTorch collective communications power several widely adopted distributed
training features, including
DistributedDataParallel,
ZeroRedundancyOptimizer,
and FullyShardedDataParallel.
In order to make the same collective communication API work with
different communication backends, the distributed package abstracts collective
communication operations into a
ProcessGroup
class. Different backends can
then be implemented as subclasses of ProcessGroup
using preferred
third-party libraries. PyTorch distributed comes with three default backends:
ProcessGroupNCCL, ProcessGroupGloo, and ProcessGroupMPI. However,
beyond these three backends, there are also other communication libraries
(e.g., UCC, OneCCL), different types of hardware (e.g., TPU, Trainium),
and emerging communication algorithms (e.g., Herring, Reduction Server).
Therefore, the distributed package exposes extension APIs to allow customizing
collective communication backends.
The four steps below show how to implement a dummy ProcessGroup
backend and use it in Python application code. Please note that this tutorial focuses
on demonstrating the extension APIs, rather than developing a functioning
communication backend. Hence, the dummy backend only covers a subset of the
APIs (all_reduce and all_gather), and simply sets the values of tensors to 0.
Step 1: Implement a Subclass of ProcessGroup¶
The first step is to implement a ProcessGroup subclass that overrides
the target collective communication APIs and runs the custom communication algorithm.
The extension also needs to implement a ProcessGroup::Work subclass, which
serves as a future of communication results and allows asynchronous execution in
application code. If the extension uses third-party libraries, it can
include the headers and call into the library APIs from the ProcessGroupDummy
subclass. The two code snippets below present the implementation of dummy.hpp and
dummy.cpp. See the dummy collectives
repository for the full implementation; a sketch of the remaining Work overrides
appears after the dummy.cpp snippet.
// file name: dummy.hpp
#include <torch/python.h>

#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>

#include <pybind11/chrono.h>

namespace c10d {

class ProcessGroupDummy : public ProcessGroup {
  public:

    class WorkDummy : public ProcessGroup::Work {
      public:
        WorkDummy(
            OpType opType,
            c10::intrusive_ptr<c10::ivalue::Future> future) // future of the output
            : ProcessGroup::Work(
                  -1, // rank, only used by recvAnySource, irrelevant in this demo
                  opType),
              future_(std::move(future)) {}

        // There are several additional helper functions that need to be
        // implemented. Please refer to https://github.com/mrshenli/dummy_collectives
        // for the full implementation.

      private:
        c10::intrusive_ptr<c10::ivalue::Future> future_;
    };

    ProcessGroupDummy(int rank, int size);

    c10::intrusive_ptr<ProcessGroup::Work> allgather(
        std::vector<std::vector<at::Tensor>>& outputTensors,
        std::vector<at::Tensor>& inputTensors,
        const AllgatherOptions& opts = AllgatherOptions()) override;

    c10::intrusive_ptr<ProcessGroup::Work> allreduce(
        std::vector<at::Tensor>& tensors,
        const AllreduceOptions& opts = AllreduceOptions()) override;

    // The collective communication APIs without a custom implementation
    // will error out if invoked by application code.
};

} // namespace c10d
// file name: dummy.cpp
#include "dummy.hpp"

namespace c10d {

// This is a dummy allgather that sets all output tensors to zero
// Modify the implementation to conduct real communication asynchronously
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupDummy::allgather(
        std::vector<std::vector<at::Tensor>>& outputTensors,
        std::vector<at::Tensor>& inputTensors,
        const AllgatherOptions& /* unused */) {
    for (auto& outputTensorVec : outputTensors) {
        for (auto& outputTensor : outputTensorVec) {
            outputTensor.zero_();
        }
    }

    auto future = c10::make_intrusive<c10::ivalue::Future>(
        c10::ListType::create(c10::ListType::create(c10::TensorType::get())));
    future->markCompleted(c10::IValue(outputTensors));
    return c10::make_intrusive<WorkDummy>(OpType::ALLGATHER, std::move(future));
}

// This is a dummy allreduce that sets all output tensors to zero
// Modify the implementation to conduct real communication asynchronously
c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupDummy::allreduce(
        std::vector<at::Tensor>& tensors,
        const AllreduceOptions& opts) {
    for (auto& tensor : tensors) {
        tensor.zero_();
    }

    auto future = c10::make_intrusive<c10::ivalue::Future>(
        c10::ListType::create(c10::TensorType::get()));
    future->markCompleted(c10::IValue(tensors));
    return c10::make_intrusive<WorkDummy>(OpType::ALLREDUCE, std::move(future));
}

} // namespace c10d
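The WorkDummy comment in dummy.hpp defers the remaining helper functions to the dummy collectives repository. As a rough sketch modeled on that repository (the exact set of virtual methods can vary across PyTorch versions, so treat this as illustrative rather than authoritative), the dummy work is complete as soon as it is created, so the remaining ProcessGroup::Work overrides can simply report success and hand back the stored future:

// Sketch only: remaining WorkDummy overrides, modeled on the
// dummy_collectives repository. The dummy work is complete as soon as it is
// created, so the status queries return immediately and getFuture() returns
// the future that the collective stored in the constructor.
bool isCompleted() override {
    return true;
}

bool isSuccess() const override {
    return true;
}

bool wait(std::chrono::milliseconds /* timeout */) override {
    return true;
}

c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
    return future_;
}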
Step 2: Expose The Extension Python APIs¶
The backend constructors are called from the Python side,
so the extension also needs to expose the constructor APIs to Python. This can
be done by adding the following methods. In this example, store and timeout
are ignored by the ProcessGroupDummy instantiation method, as
they are not used in this dummy implementation. However, real-world extensions
should consider using the store to perform rendezvous and supporting the
timeout argument.
class ProcessGroupDummy : public ProcessGroup {
  public:
    static c10::intrusive_ptr<ProcessGroup> createProcessGroupDummy(
        const c10::intrusive_ptr<::c10d::Store>& store,
        int rank,
        int size,
        const std::chrono::duration<float>& timeout);

    static void ProcessGroupDummyConstructor() __attribute__((constructor)) {
        py::object module = py::module::import("torch.distributed");
        py::object register_backend =
            module.attr("Backend").attr("register_backend");
        // torch.distributed.Backend.register_backend will add `dummy` as a
        // new valid backend.
        register_backend("dummy", py::cpp_function(createProcessGroupDummy));
    }
};

c10::intrusive_ptr<ProcessGroup> ProcessGroupDummy::createProcessGroupDummy(
        const c10::intrusive_ptr<::c10d::Store>& /* unused */,
        int rank,
        int size,
        const std::chrono::duration<float>& /* unused */) {
    return c10::make_intrusive<ProcessGroupDummy>(rank, size);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("createProcessGroupDummy", &ProcessGroupDummy::createProcessGroupDummy);
}
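Because ProcessGroupDummyConstructor is marked with __attribute__((constructor)), it runs as soon as the extension's shared library is loaded, so merely importing the compiled module registers the dummy backend. For reference, an equivalent explicit registration from the Python side would look like the sketch below; this is only needed if an extension does not self-register, and it assumes the module is built under the name dummy_collectives, as in Step 3:

import torch.distributed as dist

# Loading the extension exposes the pybind11 binding defined above.
import dummy_collectives

# Explicit Python-side registration; equivalent to what the
# __attribute__((constructor)) function does when the library is loaded.
dist.Backend.register_backend("dummy", dummy_collectives.createProcessGroupDummy)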
Step 3: Build The Custom Extension¶
Now the extension source code files are ready. We can then use
cpp extensions to build it. To do that, create a setup.py file that prepares the
paths and commands, then call python setup.py install to install the extension.
If the extension depends on third-party libraries, you can also specify
library_dirs and libraries to the cpp extension APIs (see the sketch after the
setup.py file below). See the torch-ucc project as a real-world example.
# file name: setup.py
import os
import sys
import torch
from setuptools import setup
from torch.utils import cpp_extension

sources = ["src/dummy.cpp"]
include_dirs = [f"{os.path.dirname(os.path.abspath(__file__))}/include/"]

if torch.cuda.is_available():
    module = cpp_extension.CUDAExtension(
        name = "dummy_collectives",
        sources = sources,
        include_dirs = include_dirs,
    )
else:
    module = cpp_extension.CppExtension(
        name = "dummy_collectives",
        sources = sources,
        include_dirs = include_dirs,
    )

setup(
    name = "Dummy-Collectives",
    version = "0.0.1",
    ext_modules = [module],
    cmdclass={'build_ext': cpp_extension.BuildExtension}
)
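If the backend wraps a third-party communication library, the library_dirs and libraries arguments mentioned above are forwarded to setuptools and control how the extension is linked. A hypothetical variant of the extension definition is sketched below; the paths and the dummycomm library name are placeholders, not part of this tutorial's code:

# Hypothetical example: linking the extension against a third-party
# communication library. All paths and the library name are placeholders.
module = cpp_extension.CppExtension(
    name = "dummy_collectives",
    sources = sources,
    include_dirs = include_dirs + ["/opt/dummycomm/include"],
    library_dirs = ["/opt/dummycomm/lib"],
    libraries = ["dummycomm"],  # links with -ldummycomm
)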
Step 4: Use The Extension in Application¶
After installation, you can conveniently use the dummy backend when calling
init_process_group, as if it were a builtin backend.
import os

import torch
# importing dummy_collectives makes torch.distributed recognize `dummy`
# as a valid backend.
import dummy_collectives

import torch.distributed as dist

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'

dist.init_process_group("dummy", rank=0, world_size=1)

x = torch.ones(6)
dist.all_reduce(x)

# The CUDA path below requires a CUDA build of the extension and an available GPU.
y = x.cuda()
dist.all_reduce(y)

print(f"cpu allreduce: {x}")
print(f"cuda allreduce: {y}")

try:
    dist.broadcast(x, 0)
except RuntimeError:
    print("got RuntimeError as broadcast is not implemented in Dummy ProcessGroup")