Gradient Checkpointing: Memory Optimization
For a simple feed-forward neural network with $n$ layers, the computation graph for obtaining gradients looks as follows:

The activations of the neural network layers correspond to the nodes marked with an $f$. During the forward pass, all these nodes are evaluated in order. The gradient of the loss with respect to the activations and parameters of these layers is indicated by the nodes marked with $b$. During the backward pass, all these nodes are evaluated in reverse order. The results obtained for the $f$ nodes are needed to compute the $b$ nodes, and hence all $f$ nodes are kept in memory after the forward pass. Only when backpropagation has progressed far enough to have computed all dependencies, or children, of an $f$ node can it be erased from memory. This means that the memory required by simple backprop grows linearly with the number of neural network layers $n$, i.e. it is $O(n)$. We show the order in which these nodes are computed below. The purple shaded circles indicate which of the nodes need to be held in memory at any given time.
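To make the linear memory growth concrete, here is a minimal Python sketch that counts forward evaluations and retained activations for plain backprop on a chain of $n$ layers. It is toy bookkeeping, not a real autograd engine: it assumes (as a simplification) that backward node $b_i$ needs only $f_i$ and the already computed $b_{i+1}$, and the function name standard_backprop is hypothetical.

```python
def standard_backprop(n):
    """Toy simulation of plain backprop on a chain of n layers.

    Assumption for illustration: backward node b_i needs forward activation
    f_i, and f_i can be freed as soon as b_i has been computed.
    Returns (forward_evaluations, peak_retained_activations).
    """
    stored = []  # indices of forward activations currently in memory
    evals = 0
    peak = 0
    # Forward pass: evaluate f_1..f_n in order and keep every result.
    for i in range(1, n + 1):
        stored.append(i)
        evals += 1
        peak = max(peak, len(stored))
    # Backward pass: evaluate b_n..b_1; after b_i, f_i is no longer needed.
    for i in range(n, 0, -1):
        assert stored[-1] == i  # f_i is available without recomputation
        stored.pop()            # free f_i once b_i has been computed
    return evals, peak


for n in (4, 16, 64):
    print(n, standard_backprop(n))  # peak memory equals n: linear growth
```

Every node is evaluated exactly once, but all $n$ activations are held at the moment the backward pass begins.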

Simple backpropagation as described above is optimal in terms of computation: it computes each node only once. However, if we are willing to recompute nodes, we can potentially save a lot of memory. We might, for instance, simply recompute every node from the forward pass each time we need it. The order of execution, and the memory used, then look as follows:

With this strategy, the memory required to compute gradients in our graph is constant in the number of neural network layers $n$, which is optimal in terms of memory. However, the number of node evaluations now scales with $n^2$, whereas it previously scaled as $n$: each of the $n$ nodes is recomputed on the order of $n$ times. The computation graph thus becomes much slower to evaluate for deep networks, which makes this method impractical for use in deep learning.
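The quadratic cost can be counted directly. In this minimal sketch (hypothetical function name recompute_everything; simplifying assumption that $b_i$ needs only $f_i$), only the most recent forward activation is ever retained, and each $b_i$ for $i < n$ triggers a fresh recomputation of $f_1, \dots, f_i$ from the network input.

```python
def recompute_everything(n):
    """Toy simulation of the memory-poor strategy on a chain of n layers.

    Only one forward activation is retained at a time; whenever b_i needs
    f_i, the chain f_1..f_i is recomputed from the network input.
    Returns (forward_evaluations, peak_retained_activations).
    """
    memory = set()  # indices of retained forward activations
    evals = 0
    peak = 0

    def evaluate(i):
        # Compute f_i from f_{i-1}, then drop f_{i-1}: one activation is kept.
        nonlocal evals, peak
        memory.clear()
        memory.add(i)
        evals += 1
        peak = max(peak, len(memory))

    for i in range(1, n + 1):      # initial forward pass; f_n stays for b_n
        evaluate(i)
    for i in range(n - 1, 0, -1):  # b_{n-1}..b_1: rebuild f_i from the input
        for j in range(1, i + 1):
            evaluate(j)
        assert i in memory         # f_i is available for b_i
    return evals, peak


for n in (4, 16, 64):
    print(n, recompute_everything(n))
```

The evaluation count works out to $n + \sum_{i=1}^{n-1} i = n(n+1)/2$, so memory stays constant while compute grows as $O(n^2)$.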
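To strike a balance between memory and computation, we need a strategy that allows nodes to be recomputed, but not too often. The strategy is to mark a subset of the neural net activations as checkpoint nodes. Checkpoint nodes are kept in memory after the forward pass, while the remaining nodes are recomputed at most once and are then kept in memory only until they are no longer required.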

For the simple feed-forward network in our example, the optimal choice is to mark every $\sqrt{n}$-th node as a checkpoint. This way, both the number of checkpoint nodes and the number of nodes between checkpoints are on the order of $\sqrt{n}$, which means that the required memory now also scales with the square root of the number of layers in our network. Since every node is recomputed at most once, the additional computation required by this strategy is equivalent to a single forward pass through the network.
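A toy simulation makes the $\sqrt{n}$ trade-off visible. The sketch below assumes, for illustration, that every $k$-th activation is a checkpoint (with $k$ defaulting to $\sqrt{n}$ rounded) and that backward node $b_i$ needs only $f_i$; the function name checkpointed_backprop is hypothetical. It counts forward evaluations and the peak number of retained activations.

```python
import math


def checkpointed_backprop(n, k=None):
    """Toy simulation of checkpointed backprop on a chain of n layers.

    Every k-th forward activation (default k ~ sqrt(n)) is kept as a
    checkpoint; all other activations are discarded during the forward pass
    and recomputed, one segment at a time, during the backward pass.
    Returns (forward_evaluations, peak_retained_activations).
    """
    if k is None:
        k = max(1, round(math.sqrt(n)))
    memory = set()  # indices of forward activations currently retained
    evals = 0

    # Forward pass: evaluate every node once, retain only the checkpoints.
    for i in range(1, n + 1):
        evals += 1
        if i % k == 0:
            memory.add(i)
    peak = len(memory)

    # Backward pass: walk the segments (s, e] from the last one to the first.
    # The segment start s is a checkpoint (or the network input when s == 0).
    for s in range(((n - 1) // k) * k, -1, -k):
        e = min(s + k, n)
        # Recompute the non-checkpoint nodes of this segment from f_s.
        for i in range(s + 1, e + 1):
            if i not in memory:
                evals += 1
                memory.add(i)
        peak = max(peak, len(memory))
        # Evaluate b_e..b_{s+1}; each f_i is freed once b_i is done.
        for i in range(e, s, -1):
            memory.discard(i)
    return evals, peak


for n in (16, 100):
    print(n, checkpointed_backprop(n))
```

For 16 layers the sketch performs 28 forward evaluations (16 plus 12 recomputations) while retaining at most 7 activations, and for 100 layers it performs 190 evaluations with at most 19 retained. Passing k=1 makes every activation a checkpoint, which reduces it to plain backprop.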
This technique fundamentally trades computational cycles for memory capacity.
- Memory Savings: By storing only $\sqrt{n}$ checkpoints (where $n$ is the number of layers), the memory complexity drops from $O(n)$ to approximately $O(\sqrt{n})$.
- Compute Penalty: Because the forward pass for the non-checkpointed layers is executed twice (once during the initial forward pass, and once when their activations are regenerated during the backward pass), training becomes slower. Typically, this results in a 20% to 30% increase in training time.
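The two bullets can be turned into rough numbers. The sketch below is back-of-the-envelope arithmetic under stated assumptions: a checkpoint spacing of $k = \lfloor\sqrt{n}\rfloor$, retained activations estimated as the checkpoints plus one recomputed segment, and a backward pass that costs about twice a forward pass (a common rule of thumb, not a measurement). The function name checkpointing_tradeoff is hypothetical.

```python
import math


def checkpointing_tradeoff(n_layers, backward_to_forward_ratio=2.0):
    """Rough estimate of activation memory and extra compute for sqrt(n)
    checkpointing. All cost ratios here are illustrative assumptions."""
    k = math.isqrt(n_layers)                 # checkpoint spacing ~ sqrt(n)
    retained = n_layers // k + k             # checkpoints + one recomputed segment
    memory_ratio = retained / n_layers       # relative to keeping all n activations
    # One extra forward pass on top of (forward + backward) per training step.
    compute_overhead = 1.0 / (1.0 + backward_to_forward_ratio)
    return memory_ratio, compute_overhead


mem, extra = checkpointing_tradeoff(100)
print(f"activation memory: {mem:.0%} of baseline, extra compute: {extra:.0%}")
# activation memory: 20% of baseline, extra compute: 33%
```

Under these assumptions the compute estimate lands somewhat above the 20% to 30% range quoted above; the actual overhead depends on the model, the hardware, and where the checkpoints are placed.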
Code Example (PyTorch)
In PyTorch, we do not need to implement the recomputation logic manually. We can wrap specific modules (such as Transformer blocks) with torch.utils.checkpoint.checkpoint. In the following implementation, each block of a custom model is wrapped so that its internal activations are recomputed during the backward pass instead of being stored:
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class HeavyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        # A large block whose intermediate activations are expensive to store
        self.linear1 = nn.Linear(4096, 4096)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(4096, 4096)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))


class CheckpointedModel(nn.Module):
    def __init__(self, num_layers=10):
        super().__init__()
        self.layers = nn.ModuleList([HeavyLayer() for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            # Instead of layer(x), we use checkpoint(layer, x).
            # Only the block's input is saved; the activations inside 'layer'
            # are discarded and recomputed during the backward pass.
            x = checkpoint(layer, x, use_reentrant=False)
        return x


# Usage
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CheckpointedModel().to(device)
# Create the input directly on the target device so it remains a leaf tensor
input_data = torch.randn(32, 4096, device=device, requires_grad=True)

# This forward pass keeps only each block's input instead of every
# intermediate activation, so it needs less activation memory than layer(x)
output = model(input_data)
output.sum().backward()
```
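The recompute-then-backpropagate idea can also be written out by hand without PyTorch. The dependency-free sketch below uses a hypothetical scalar network, $x_{i+1} = \tanh(w_i x_i)$ with the final activation as the loss, and compares the gradients from plain backprop with those from a version that keeps only every $k$-th activation and recomputes each segment during the backward pass. All function names are hypothetical, and the sketch is not meant to mirror the internals of torch.utils.checkpoint.

```python
import math


def layer(w, x):
    """Hypothetical one-parameter layer: x_out = tanh(w * x)."""
    return math.tanh(w * x)


def layer_backward(w, x_in, x_out, grad_out):
    """Return (dL/dw, dL/dx_in) given dL/dx_out, using tanh' = 1 - tanh^2."""
    g = grad_out * (1.0 - x_out * x_out)
    return g * x_in, g * w


def grads_full(ws, x0):
    """Plain backprop: keep every activation from the forward pass."""
    xs = [x0]
    for w in ws:
        xs.append(layer(w, xs[-1]))
    grad_x, grads = 1.0, [0.0] * len(ws)  # loss = final activation, dL/dx_n = 1
    for i in range(len(ws) - 1, -1, -1):
        grads[i], grad_x = layer_backward(ws[i], xs[i], xs[i + 1], grad_x)
    return grads


def grads_checkpointed(ws, x0, k):
    """Checkpointed backprop: keep only x_0, x_k, x_2k, ... and recompute
    each segment's activations from its starting checkpoint when needed."""
    n = len(ws)
    checkpoints = {0: x0}  # activation index -> stored value
    x = x0
    for i, w in enumerate(ws, start=1):
        x = layer(w, x)
        if i % k == 0:
            checkpoints[i] = x
    grad_x, grads = 1.0, [0.0] * n
    # Walk the segments (s, e] from last to first.
    for s in range(((n - 1) // k) * k, -1, -k):
        e = min(s + k, n)
        seg = [checkpoints[s]]  # seg[j] holds x_{s+j}
        for i in range(s, e):
            seg.append(layer(ws[i], seg[-1]))  # recompute x_{s+1}..x_e
        for i in range(e - 1, s - 1, -1):
            grads[i], grad_x = layer_backward(ws[i], seg[i - s], seg[i - s + 1], grad_x)
        # seg is replaced on the next iteration, so only one segment is held at a time
    return grads


ws = [0.5, -1.2, 0.8, 1.5, -0.3, 0.9, 1.1, -0.7, 0.4]
full = grads_full(ws, x0=0.7)
ckpt = grads_checkpointed(ws, x0=0.7, k=3)  # k = sqrt(9)
print(all(math.isclose(a, b) for a, b in zip(full, ckpt)))  # True
```

Because the recomputed segments repeat exactly the same floating-point operations as the original forward pass, the two gradient lists agree; only the number of activations held at once differs.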