Base Distributed Module¶
Overview¶
The Module container extends the PyTorch torch.nn.Module container and defines the interface required for a DistDL distributed layer to perform any necessary setup or teardown operations.
Note
This container inherits from torch.nn.Module and should be used as the base class for any DistDL distributed layer.
Motivation¶
DistDL aims to preserve PyTorch-like interfaces, so no information about the input tensor, other than the partition functions, is required when a layer is instantiated. Consequently, some layer state, such as the intermediate communication buffers used by the distdl.nn.HaloExchange and distdl.nn.Repartition layers, can only be determined when the layer is evaluated.
The interfaces defined here allow those properties to be set up (and torn down) safely when the layer is evaluated.
Implementation¶
The setup phase carries a significant cost. To avoid paying it on every call, setup is performed only when the structure of the global input tensor changes. We currently define the structure to be the tensor's shape and its requires_grad status. In the future, the dtype will also be part of this determination.
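As a minimal, torch-free sketch of this idea, a tensor's structure can be summarized as a comparable key; `structure` is a hypothetical helper, not part of the DistDL API:

```python
def structure(shape, requires_grad):
    # Hypothetical helper: a tensor's "structure" is its shape plus its
    # requires_grad status (per this document, dtype may be added later).
    return (tuple(shape), requires_grad)

old = structure((8, 3, 32, 32), False)
new = structure((16, 3, 32, 32), False)  # batch size changed

# A change in structure is what triggers the setup phase again.
needs_setup = old != new
```

Note that changing only requires_grad, with the shape held fixed, also counts as a structure change under this definition.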
Warning
Each worker has knowledge of only its local input tensor, so it is not possible to determine whether the global tensor has changed without a global communication. Practically, this means that the first dimension of the tensor, usually the batch dimension, which should be the same across all workers in a partition, is the primary indicator of change.
It is also possible for the global feature shape to change without changing the local feature shape, so care must be taken.
Warning
This mechanism is primarily designed to allow the batch size to change. However, for performance reasons, frequent changes to the batch size or other aspects of the tensor structure should be avoided.
The check for whether the input tensor has changed is defined per layer, allowing for implementation-specific behavior.
The Module._distdl_forward_pre_hook function is registered as a PyTorch pre-forward hook. This hook checks whether the layer needs to be set up, either from scratch or as a reset, and then calls the setup function.
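The pattern can be sketched with plain PyTorch. This is an illustrative toy module, not DistDL's actual implementation: the pre-forward hook compares the input's structure against the last one seen and re-runs a (stand-in) setup only when it differs.

```python
import torch

class LazySetupModule(torch.nn.Module):
    # Sketch of a module that re-runs an expensive setup phase only
    # when the input tensor's structure changes.
    def __init__(self):
        super().__init__()
        self._input_structure = None  # structure of the last input seen
        self.setup_count = 0
        self.register_forward_pre_hook(LazySetupModule._pre_hook)

    @staticmethod
    def _pre_hook(module, inputs):
        x = inputs[0]
        structure = (tuple(x.shape), x.requires_grad)
        if structure != module._input_structure:
            module._module_setup(x)  # teardown/setup would happen here
            module._input_structure = structure

    def _module_setup(self, x):
        # Stand-in for allocating buffers, building communication plans, etc.
        self.setup_count += 1

    def forward(self, x):
        return x

m = LazySetupModule()
m(torch.zeros(4, 3))
m(torch.zeros(4, 3))  # same structure: setup is not repeated
m(torch.zeros(8, 3))  # new batch size: setup runs again
```

After the three calls above, setup has run exactly twice: once for the initial structure and once for the changed batch size.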
The Module container does not implement any of the following operations, but it does define the required interfaces.
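A hedged sketch of what such an interface might look like follows; the method names here echo the `_distdl_` prefix used by the pre-forward hook above but are illustrative assumptions, so the DistDL source should be consulted for the actual names:

```python
class DistributedModuleInterface:
    # Hypothetical interface for a distributed layer; method names
    # are illustrative, not DistDL's confirmed API.

    def _distdl_module_setup(self, input):
        # Allocate buffers, build communication plans, etc.
        raise NotImplementedError()

    def _distdl_module_teardown(self, input):
        # Release any state created by setup.
        raise NotImplementedError()

    def _distdl_input_changed(self, input):
        # Per-layer check: has the input structure changed?
        raise NotImplementedError()
```

A concrete distributed layer would override each of these; the base class only establishes the contract that the pre-forward hook relies on.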