情報系の手考ノート

数学とか情報系の技術とか調べたり勉強したりしてメモしていきます.

GPUメモリ不足時超初歩的トラブルシューティング:PyTorch

2023年現在,PyTorch や TensorFlow があるおかげで,かなり簡単にニューラルネットワーク関連のプログラムが作れるようになっています. 特に PyTorch はカスタマイズ性が高く,研究用途等で非常に有用です. しかし研究で使うぞとなっても,なんとなくの仕組みも知らない状態で使うとメモリ効率が非常に悪いプログラムになったりします. そんな効率が悪いプログラムは,学部や修士の研究において「とりあえず実装しよう」というような思考になっても実装されたりするでしょう.

実際,私も PyTorch で書いた効率の悪いプログラムで研究を進めていた1人でした. しかし実験の過程でメモリ不足となったことでプログラムの効率化を図る機会がありました. その際に行った修正は非常に初歩的なものではあるものの,意識しておかないと簡単にメモリ不足へと陥いる要因となります. そこで当時私が確認した内容を記すことで,後学の徒へ向けた記事とします.

事前知識

最低限 PyTorch について知っているべき(?)内容についてここで述べます.

PyTorch は機械学習(特にニューラルネットワーク)の実装によく用いられている Python のライブラリです. 特に研究用途で用いられることが多いように感じます(偏見). 機械学習を実装する上で頻繁に扱うのが多次元配列と勾配の計算です. PyTorch ではテンソルという名前で多次元配列を扱っており,基本的にはテンソル同士の演算によって機械学習のモデルを実装します. また近年よく話題になる人工知能であるニューラルネットワークは,勾配を計算することで学習を行います. PyTorch では勾配を計算するための仕組みとして計算グラフというものを用いています. 特に気をつけずに実装していた場合,GPU 上のメモリが不足する理由の上位を独占しているのがこの計算グラフでしょう. 私もこれが原因でした. 以下で示すのは,いかにして不要な計算グラフを保持しないようにするか,という内容になります.

超初歩的トラブルシューティング

まずは計算グラフを切る方法です. 計算グラフはただ値を計算するのに加えて,どうやって値が計算されたかも保持しています. その分余計にメモリを必要とするわけです. となれば使わない計算グラフは持たない方がメモリを無駄にすることはなくなります. これは計算グラフを切るという操作によって実現されます. 下のソースコードにおいて detach_flag の有無を切り替えて実行すると,実際にメモリ使用量に差が出ます.

import torch
from torchvision.models import alexnet

def print_allocated_memory():
    print("{:.2f} GB".format(torch.cuda.memory_allocated() / 1024 ** 3))

detach_flag = True

model = alexnet().cuda()

x = torch.rand(1024, 3, 224, 224, requires_grad=False).cuda()
y = model(x)
if detach_flag:
    y = y.detach()
print_allocated_memory()

私の環境では detach_flag が真のときは 0.85 GB,偽のときは 3.76 GB となりました. どれだけ省メモリ化できるかはモデルによって変わりますが,実際に detach によって省メモリ化できていることは確認できたかと思います. ただし,detach した y からは勾配を計算してモデルのパラメータ更新はできなくなります. なので勾配がいらない値は適宜 detach することで無駄にメモリを食うようなことは避けられるでしょう.

また下記プログラムでも detach と同様の効果があります.

import torch
from torchvision.models import alexnet

def print_allocated_memory():
    print("{:.2f} GB".format(torch.cuda.memory_allocated() / 1024 ** 3))

model = alexnet().cuda()

x = torch.rand(1024, 3, 224, 224, requires_grad=False).cuda()
with torch.no_grad():
    y = model(x)
print_allocated_memory()

no_grad を使う場合は計算全体で計算グラフが構築されなくなります. そのため学習ではなくテストデータの評価時に使うことが多くなるでしょう. 対して detach では計算グラフ自体は計算されます. 私は学習中に計算した損失関数の値を Tensor でとっておきたいときに多用しました. 学習の過程で損失関数の値を累積してとっておきたい,というような状況で下のプログラムの y_sum のように累積させていくことはあると思います. このような状況で detach は効果を発揮します. まぁ損失関数の累積なら Tensor である必要が無いんですけどね.

下のプログラムで detach_flag が真のときは 0.81 GB,偽のときはメモリ不足となりました(GPU のメモリは 8GB). detach しない場合は,10 回のループ全てにおいて計算グラフを保持する必要があるため膨大なメモリが必要になります. しかし detach しておけば過去の計算グラフは不要になるため,1回の計算分のみ計算グラフを保持しておけば良くなります. このように学習中に Tensor の値を累積させて保持したい場合には detach を使うことでメモリを効率的に使えます.

import torch
from torchvision.models import alexnet

def print_allocated_memory():
    print("{:.2f} GB".format(torch.cuda.memory_allocated() / 1024 ** 3))

detach_flag = True

model = alexnet().cuda()

y_sum = 0
for i in range(10):
    x = torch.rand(1024, 3, 224, 224, requires_grad=False).cuda()
    y = model(x)
    if detach_flag:
        y_sum = (y_sum + y).detach()
    else:
        y_sum = y_sum + y
    del y
print_allocated_memory()

ちなみにプログラム中の del は python のスコープの関係でローカル変数 y を破棄するために使っています. このような変数のスコープを意識するのも無駄なメモリの使用を避ける方法の1つです.

まとめ

以上になります. 以上といっても項目的には1つしかなかったですが,PyTorch を使い始めた状態では意識できないポイントだったと思うので,初学者の方には十分参考になるんじゃないでしょうか. ここで述べたことに限らずメモリをうまく使う方法はあります. ただ PyTorch を使う時にメモリ不足に悩まされたら,無駄な計算グラフを作っていないか,無駄な計算グラフをいつまでも大切に持っておいていないか,そういう所を意識すると救われる場合があると思います. むしろそういう場合がほとんどだと思います. とりあえずこれだけでも意識しておくと良いのかなと個人的には思います.