PyTorch'da neden zero_grad () çağırmamız gerekiyor?


Yanıtlar:


148

İçinde PyTorch, geri yayılım yapmaya başlamadan önce degradeleri sıfıra ayarlamamız gerekir çünkü PyTorch , sonraki geri geçişlerde degradeleri biriktirir . Bu, RNN'leri eğitirken kullanışlıdır. Bu nedenle, varsayılan eylem her çağrıda gradyanları toplamaktır (yani toplamaktır)loss.backward() .

Bu nedenle, egzersiz döngünüzü başlattığınızda, ideal zero out the gradientsolarak parametre güncellemesini doğru şekilde yapmanız gerekir . Aksi takdirde, gradyan minimuma (veya maksimizasyon hedefleri olması durumunda maksimuma) doğru amaçlanan yönden başka bir yöne işaret eder .

İşte basit bir örnek:

import torch
from torch.autograd import Variable
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in zip(data, targets):
    # clear out the gradients of all Variables 
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()
    optimizer.step()

Alternatif olarak, vanilya gradyan inişi yapıyorsanız , o zaman:

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

for sample, target in zip(data, targets):
    # clear out the gradients of Variables 
    # (i.e. W, b)
    W.grad.data.zero_()
    b.grad.data.zero_()

    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()

    W -= learning_rate * W.grad.data
    b -= learning_rate * b.grad.data

Not : Gradyanların birikimi (yani toplamı ) tensör .backward()çağrıldığındaloss gerçekleşir .


3
çok teşekkür ederim, bu gerçekten yardımcı oldu! Tensorflow'un davranışa sahip olup olmadığını biliyor musunuz?
layser

Emin olmak için .. eğer bunu yapmazsanız, patlayan bir gradyan problemiyle karşılaşacaksınız, değil mi?
zwep

2
@zwep Degradeleri biriktirirsek, bu onların büyüklüklerinin arttığı anlamına gelmez: Degradenin işareti sürekli değişmeye devam ederse bir örnek olabilir. Yani patlayan gradyan problemiyle karşılaşacağınızı garanti etmez. Ayrıca, doğru şekilde sıfırlasanız bile patlayan gradyanlar mevcuttur.
Tom Roth

Vanilya gradyan inişini çalıştırdığınızda, ağırlıkları güncellemeye çalıştığınızda "grad gerektiren bir yaprak Değişkeni yerinde işlemde kullanılmış" hatası almıyor musunuz?
MUAS

1

zero_grad (), hatayı (veya kayıpları) azaltmak için gradyan yöntemini kullanırsanız, son adımdan kayıp olmadan yeniden döngüyü yeniden başlatır

zero_grad () kullanmazsanız, kayıp gerektiği kadar artmaz

örneğin zero_grad () kullanırsanız aşağıdaki çıktıyı bulacaksınız:

model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2

zero_grad () kullanmazsanız aşağıdaki çıktıyı bulacaksınız:

model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5
Sitemizi kullandığınızda şunları okuyup anladığınızı kabul etmiş olursunuz: Çerez Politikası ve Gizlilik Politikası.
Licensed under cc by-sa 3.0 with attribution required.