
A PyTorch GPU Memory Leak Example

I ran into this GPU memory leak issue when building a PyTorch training pipeline. After spending quite some time debugging, I finally narrowed it down to the following minimal reproducible example.

import torch

class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

device: torch.device = torch.device("cuda:0")
model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=False)
model.to(device)
model.train()

data = torch.zeros(size=[1, 3, 128, 128], device=device, dtype=torch.float)
loss_avg = AverageMeter()
for i in range(1000):
    outputs = model(data)
    loss = outputs.mean()
    loss_avg.update(loss)
    if i % 10 == 0:
        print('Loss, current batch {:.3f}, moving average {:.3f}'.format(loss_avg.val, loss_avg.avg))
        print('GPU Memory Allocated {} MB'.format(torch.cuda.memory_allocated(device=device) / 1024. / 1024.))

Kicking off the training, I can see the allocated GPU memory constantly increasing:

Using cache found in /home/haoxiangli/.cache/torch/hub/pytorch_vision_v0.9.0
Loss, current batch -0.001, moving average -0.001
GPU Memory Allocated 51.64306640625 MB
Loss, current batch -0.001, moving average -0.001
GPU Memory Allocated 119.24072265625 MB
Loss, current batch -0.001, moving average -0.001
GPU Memory Allocated 186.83837890625 MB
Loss, current batch -0.001, moving average -0.001
GPU Memory Allocated 254.43603515625 MB
Loss, current batch -0.001, moving average -0.001
GPU Memory Allocated 322.03369140625 MB
Loss, current batch -0.001, moving average -0.001
GPU Memory Allocated 389.63134765625 MB
Loss, current batch -0.001, moving average -0.001

This “AverageMeter” has been used in many popular repositories (e.g., https://github.com/facebookresearch/moco). By design it tracks the running average of a given value, and it is commonly used to monitor things like training speed and the loss value.

class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

It came along with some existing training pipeline code, but once it was in my codebase I started using it elsewhere as well.

The implementation itself is straightforward and bug-free, but it turns out there is something tricky about how it is used.

The following is a modified version without the GPU memory leak:

import torch

class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

device: torch.device = torch.device("cuda:0")
model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=False)
model.to(device)
model.train()

data = torch.zeros(size=[1, 3, 128, 128], device=device, dtype=torch.float)
loss_avg = AverageMeter()
for i in range(1000):
    outputs = model(data)
    loss = outputs.mean()
    loss_avg.update(loss.item()) # <----- convert the loss to a plain Python float before tracking it
    if i % 10 == 0:
        print('Loss, current batch {:.3f}, moving average {:.3f}'.format(loss_avg.val, loss_avg.avg))
        print('GPU Memory Allocated {} MB'.format(torch.cuda.memory_allocated(device=device) / 1024. / 1024.))

The annotated line is the little nuance. In the original version, the loss is still attached to the computation graph, and the “AverageMeter” keeps a reference to it through self.sum. As long as that tensor stays alive, the computation graph of every iteration, together with the intermediate activations it references, stays alive too, so PyTorch cannot release the corresponding GPU memory and the allocations keep piling up. The fix is to convert the loss into a plain Python number with .item() before tracking it.
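To make the mechanism more concrete, here is a tiny standalone sketch (my own illustration, not part of the pipeline above, and it does not even need a GPU). Accumulating a graph-attached tensor produces another tensor with a grad_fn, which keeps the whole graph reachable, while .item() yields a plain Python float that holds no such reference:

import torch

x = torch.randn(3, requires_grad=True)
loss = (x * 2).mean()     # 0-dim tensor, still attached to the computation graph
print(loss.grad_fn)       # prints a backward node, e.g. <MeanBackward0 ...>

total = 0
total += loss             # total is now a tensor; the graph stays reachable through it
print(type(total), total.grad_fn)

total = 0
total += loss.item()      # plain Python float; nothing from the graph is retained
print(type(total))        # <class 'float'>

Note that loss.detach() would also drop the reference to the graph, but it still keeps a small tensor alive on the GPU, while .item() copies the value out as a plain Python number.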

This pattern may show up in a codebase in other forms. Once utility classes like this are in place, it is quite easy to accidentally use them on trainable PyTorch tensors that are still attached to the computation graph. I do feel this is a bug in PyTorch (at least <=1.8.0), though.
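If you want to guard against this kind of mistake, one possible safeguard (just a sketch of mine, not something from the original pipeline) is to have the meter convert tensors into plain numbers itself:

import torch

class SafeAverageMeter(object):
    """
    Like AverageMeter, but converts 0-dim tensors to plain Python numbers
    so that no computation graph is ever retained.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        if torch.is_tensor(val):
            val = val.item()  # drop any reference to the computation graph
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

With this variant, passing the loss tensor directly no longer leaks, although calling loss.item() explicitly at the call site is arguably clearer.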
