七難ハック
CUDA OOM on large model training with PyTorch and multiple GPUs
Last updated: 2021/09/22

When training large models on multiple GPUs with PyTorch, we often run into CUDA OOM errors. The problem appears suddenly even though there seems to be plenty of memory available. The problem reported in this issue also looks similar.

I ran into this problem while fine-tuning the mBART model (facebook/mbart-large-cc25).
(Or rather, I'm still facing it right now 🤣)


Where the error happens

The OOM error occurred in the gather step. (L235 of torch/nn/parallel/comm.py in PyTorch 1.9.0)

    tensors = [_handle_complex(t) for t in tensors]
    if out is None:
        if destination == -1:
            warnings.warn(
                'Using -1 to represent CPU tensor is deprecated. Please use a '
                'device object or string instead, e.g., "cpu".')
        destination = _get_device_index(destination, allow_cpu=True, optional=True)
        return torch._C._gather(tensors, dim, destination)  # here

This step collects the results computed on the multiple GPUs, and by default they are all gathered onto a single GPU. I suspect this concentrates the memory load on that one GPU, which is what causes the OOM. Could this be resolved by using DDP (DistributedDataParallel) instead of DataParallel? I haven't tried it yet, so I don't know.
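
For reference, an untried, minimal DDP sketch would look roughly like this (a toy model and random data stand in for the real mBART training code; launch it with torchrun --nproc_per_node=<num_gpus>). Each rank keeps its outputs on its own GPU and only gradients are all-reduced, so there is no single-GPU gather step.

    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    def main():
        # One process per GPU; LOCAL_RANK is set by torchrun.
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)

        # Toy model standing in for mBART, wrapped in DDP.
        model = nn.Linear(1024, 1024).cuda(local_rank)
        ddp_model = DDP(model, device_ids=[local_rank])
        optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)

        # Each rank processes its own shard of data; outputs stay on that
        # rank's GPU, and only gradients are synchronized across ranks.
        x = torch.randn(8, 1024, device=local_rank)
        loss = ddp_model(x).pow(2).mean()
        loss.backward()
        optimizer.step()

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()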

Temporary solution

This is by no means a clean solution, but if the OOM comes from the load being concentrated on one GPU, it can probably be avoided by running the gather step on the CPU.

    tensors = [_handle_complex(t) for t in tensors]
    if out is None:
        if destination == -1:
            warnings.warn(
                'Using -1 to represent CPU tensor is deprecated. Please use a '
                'device object or string instead, e.g., "cpu".')
        # force using cpu
        destination = 'cpu'
        destination = _get_device_index(destination, allow_cpu=True, optional=True)
        return torch._C._gather(tensors, dim, destination)

The gather function in comm.py takes a destination argument, so in principle all you have to do is specify it. But I'm using the transformers Trainer and there was no way to pass an argument down to this gather call, so I forcibly patched the installed source like this and ran it 🤫
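
If you'd rather not edit the installed PyTorch source directly, the same hack can probably be done as a runtime monkey patch. This is just a sketch, assuming (as above) that the Trainer's DataParallel path ends up calling torch.nn.parallel.comm.gather; the wrapper name is mine.

    import torch.nn.parallel.comm as comm

    _original_gather = comm.gather

    def _gather_on_cpu(tensors, dim=0, destination=None, *, out=None):
        # Ignore the requested destination and always gather on the CPU,
        # exactly like the patch above.
        if out is None:
            return _original_gather(tensors, dim=dim, destination="cpu")
        return _original_gather(tensors, dim=dim, out=out)

    # Apply the patch before the Trainer builds its DataParallel wrapper.
    comm.gather = _gather_on_cpu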

After that

Of course, since the gather now runs on the CPU, training gets a little slower, even if it is only that one step. But the OOM no longer occurs and training proceeds smoothly...
Or so you'd think. Look at the GPU memory utilization chart.

Memory is leaking. 💣
Why? Where is it leaking?
Multi-GPU training with PyTorch really is hard. 🤮 It's going to take time to solve this.
It would probably be easier with TensorFlow, don't you think?

I'll add an update if I find anything. Or if anyone knows a solution, please let me know on Twitter. 😇

P.S.

After running training over and over, I noticed that memory was not actually leaking.
The input sizes were the problem: I had misconfigured the tokenizer, so inputs were not being truncated. Whenever an unusually long input came in, memory usage spiked.
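
For reference, roughly what the corrected tokenizer settings look like; max_length=512 here is only an illustrative value, not necessarily the one I actually used.

    from transformers import MBartTokenizer

    tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")

    # Truncate every example to a fixed maximum length so that a single
    # very long input cannot spike GPU memory. max_length is illustrative.
    batch = tokenizer(
        ["a very long source sentence ..."],
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )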

In the end, the OOM in the gather step was solved by running it on the CPU, as described at the beginning.