PyTorch'da eğitimli bir modeli kaydetmenin en iyi yolu?


193

PyTorch'ta eğitimli bir modeli kurtarmanın alternatif yollarını arıyordum. Şimdiye kadar iki alternatif buldum.

  1. torch.save () bir model ve kaydetmek için torch.load () bir model yüklemek için.
  2. model.state_dict () eğitimli bir model ve kaydetmek için model.load_state_dict () kaydedilmiş modeli yükleme.

Yaklaşım 1 üzerinde yaklaşım 2'nin önerildiği bu tartışmaya rastladım .

Benim sorum, ikinci yaklaşım neden tercih ediliyor? Sadece torch.nn modüllerinin bu iki işlevi olduğu ve bunları kullanmamızın teşvik edildiği için mi?


2
Ben torch.save () geri yayılım kullanımı için ara çıkışlar gibi, tüm ara değişkenleri kaydetmek çünkü düşünüyorum. Ancak sadece ağırlık / sapma vb. Gibi model parametrelerini kaydetmeniz gerekir. Bazen birincisi, ikincisinden çok daha büyük olabilir.
Dawei Yang

2
Test ettim torch.save(model, f)ve torch.save(model.state_dict(), f). Kaydedilen dosyalar aynı boyuttadır. Şimdi kafam karıştı. Ayrıca, model.state_dict () kaydetmek için turşu kullanarak son derece yavaş buldum. Bence en iyi yol, torch.save(model.state_dict(), f)modelin oluşturulmasını ele aldığınız için kullanmaktır ve meşale model ağırlıklarının yüklenmesini idare eder, böylece olası sorunları ortadan kaldırır. Referans: tartış.pytorch.org/t/saving
Dawei Yang

PyTorch'un bu konuya öğreticiler bölümünde biraz daha açık bir şekilde değinmiş gibi görünüyor — bir seferde birden fazla model kaydetmek ve sıcak başlangıç ​​modelleri de dahil olmak üzere burada cevaplarda listelenmeyen birçok iyi bilgi var.
whlteXbread

kullanmanın nesi yanlış pickle?
Charlie Parker

1
@CharlieParker torch.save turşu dayanmaktadır. Aşağıdaki bağlantılı öğretici aşağıdadır: "[torch.save] tüm modülü Python'un turşu modülünü kullanarak kaydedecektir. Bu yaklaşımın dezavantajı, serileştirilmiş verilerin belirli sınıflara ve modelde kullanılan tam dizin yapısına bağlı olmasıdır Bunun nedeni, turşunun model sınıfının kendisini kaydetmemesidir, bunun yerine, yükleme süresi sırasında kullanılan sınıfı içeren dosyaya bir yol kaydeder.Bu nedenle, kodunuz çeşitli şekillerde bozulabilir diğer projelerde veya refactorlardan sonra kullanılır. "
David Miller

Yanıtlar:


215

Ben buldum bu sayfayı kendi github repo, sadece burada içeriği yapıştırın.


Bir modeli kaydetmek için önerilen yaklaşım

Bir modeli serileştirmek ve geri yüklemek için iki ana yaklaşım vardır.

İlk (önerilen) yalnızca model parametrelerini kaydeder ve yükler:

torch.save(the_model.state_dict(), PATH)

Daha sonra:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

İkincisi tüm modeli kaydeder ve yükler:

torch.save(the_model, PATH)

Daha sonra:

the_model = torch.load(PATH)

Ancak bu durumda, serileştirilmiş veriler belirli sınıflara ve kullanılan tam dizin yapısına bağlıdır, bu nedenle diğer projelerde veya bazı ciddi refaktörlerden sonra kullanıldığında çeşitli şekillerde kırılabilir.


8
@Smth Göre tartışma.pytorch.org/ t/ saving- and -loading- a- model -in- pytorch/… modeli varsayılan olarak tren modeline yeniden yüklenir . bu yüzden, yüklemeye başladıktan sonra, çıkarım için yüklüyorsanız, eğitime devam etmemek için manuel olarak the_model.eval () öğesini çağırmanız gerekir.
WillZ

İkinci yöntem windows 10'da stackoverflow.com/questions/53798009/… hatası veriyor
Gulzar

Model sınıfına erişime gerek kalmadan kaydetme seçeneği var mı?
Michael D

Bu yaklaşımla, yük durumu için geçmeniz gereken * args ve ** kwarg'ları nasıl takip edersiniz?
Mariano Kamp

kullanmanın nesi yanlış pickle?
Charlie Parker

144

Bu ne yapmak istediğinize bağlıdır.

Durum # 1: Çıkarım yapmak için modeli kendiniz kullanmak üzere kaydedin: Modeli kaydedersiniz , geri yüklersiniz ve sonra modeli değerlendirme moduna değiştirirsiniz. Bu, genellikle yapımda tren modunda varsayılan olarak BatchNormve Dropoutkatmanlarınız olduğu için yapılır :

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Durum # 2: Eğitime daha sonra devam etmek için modeli kaydedin: Kaydetmek üzere olduğunuz modeli eğitmeye devam etmeniz gerekiyorsa, modelden daha fazlasını kaydetmeniz gerekir. Ayrıca optimize edicinin durumunu, dönemleri, skoru vb. Kaydetmeniz gerekir. Bunu şöyle yapabilirsiniz:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Eğitime devam etmek için şu gibi şeyler yaparsınız: state = torch.load(filepath)ve sonra, her bir nesnenin durumunu geri yüklemek için, şöyle bir şey:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Eğer eğitim sürdürme olduğundan, YAPMAYIN diyoruz model.eval()ne zaman yükleme durumları geri kere.

Durum # 3: Kodunuza erişimi olmayan bir kişi tarafından kullanılacak model : Tensorflow'da .pb, modelin hem mimarisini hem de ağırlıklarını tanımlayan bir dosya oluşturabilirsiniz . Bu, özellikle kullanırken çok kullanışlıdır Tensorflow serve. Pytorch'ta bunu yapmanın eşdeğer yolu:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Bu şekilde hala kurşun geçirmez değildir ve pitorch hala birçok değişiklik geçirdiğinden, tavsiye etmem.


1
3 vaka için önerilen bir dosya var mı? Yoksa her zaman .pth mi?
Verena Haunschmid

1
Durum # 3'te torch.loadsadece bir OrderedDict döner. Tahmin yapmak için modeli nasıl elde edersiniz?
Alber8295

Merhaba, Bahsedilen "Vaka # 2: Daha sonra eğitime devam etmek için modeli kaydet" i nasıl yapabilirim? Ben modele kontrol noktasından yüklemek başardı, o zaman "(model, ölçüt, optimizer, sched, çağları) model.to (cihaz) modeline = train_model_epoch" çalıştırmak veya benzeri tren modeline sürdüremeyeceğini
dnez

1
Merhaba, çıkarım için olan bir durum için, resmi pytorch belgesinde, çıkarsama veya eğitimi tamamlama için optimize edici state_dict'i kaydetmesi gerektiğini söyleyin. "Çıkarım veya devam ettirme eğitimi için kullanılacak genel bir kontrol noktasını kaydederken, yalnızca modelin durum_denetinden daha fazlasını kaydetmelisiniz. Model tren olarak güncellenen arabellekleri ve parametreleri içerdiğinden, optimize edicinin durum_decini de kaydetmeniz önemlidir. . "
Mohammed Awney

1
3. durumda, model sınıfı bir yerde tanımlanmalıdır.
Michael D

12

Turşu seri ve bir Python nesnesi de-serializing için Python kütüphanesi uygular ikili protokoller.

Siz import torch(veya PyTorch kullandığınızda) bu import picklesizin için olacaktır ve nesneyi kaydetme ve yükleme yöntemleri olan pickle.dump()ve pickle.load()doğrudan aramanıza gerek yoktur .

Aslında, torch.save()ve torch.load()kaydırılır pickle.dump()ve pickle.load()senin için.

A state_dictverilen diğer cevap sadece birkaç not daha hak ediyor.

state_dictPyTorch'un içinde ne var? Aslında iki tane state_dicts var.

PyTorch modelinin öğrenilebilir parametreler (w ve b) alma çağrısı torch.nn.Modulevardır model.parameters(). Bu öğrenilebilir parametreler, rasgele ayarlandığında, öğrendikçe zaman içinde güncellenir. Öğrenilebilir parametreler ilktir state_dict.

İkincisi state_dict, optimize edici durum dikte. Optimize edicinin öğrenilebilir parametrelerimizi geliştirmek için kullanıldığını hatırlıyorsunuz. Ancak optimize edici state_dictdüzeltildi. Orada öğrenecek bir şey yok.

Çünkü state_dictnesneler Python sözlükleri vardır, onlar kolayca PyTorch modelleri ve optimize etmek modülerlik büyük bir ekleyerek, kaydedilen güncellenen, değişmiş ve geri yüklenebilir.

Bunu açıklamak için çok basit bir model oluşturalım:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Bu kod aşağıdakileri verir:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Bunun minimal bir model olduğunu unutmayın. Sıralı yığın eklemeyi deneyebilirsiniz

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Yalnızca öğrenilebilir parametrelere sahip katmanların (evrişimli katmanlar, doğrusal katmanlar, vb.) Ve kayıtlı arabelleklerin (batchnorm katmanları) modellerde girişleri olduğunu unutmayın state_dict.

Öğrenilemeyen şeyler, optimize state_dictedicinin durumu ve kullanılan hiperparametreler hakkında bilgi içeren optimize edici nesnesine aittir .

Hikayenin geri kalanı aynı; çıkarım aşamasında (bu, eğitimden sonra modeli kullandığımız bir aşamadır) tahmin için; öğrendiğimiz parametrelere dayanarak tahmin yapıyoruz. Yani çıkarım için, sadece parametreleri kaydetmemiz gerekiyor model.state_dict().

torch.save(model.state_dict(), filepath)

Ve daha sonra kullanmak üzere model.load_state_dict (torch.load (dosyayolu)) model.eval ()

Not: model.eval()Bu modeli yükledikten sonra çok önemli olan son satırı unutmayın .

Ayrıca kaydetmeye çalışmayın torch.save(model.parameters(), filepath). model.parameters()Sadece jeneratör nesnesidir.

Diğer tarafta, torch.save(model, filepath)model nesnesinin kendisini kaydeder, ancak modelin optimize ediciye sahip olmadığını unutmayın state_dict. Optimizer'ın devlet kararını kaydetmek için @Jadiel de Armas'ın diğer mükemmel cevabını kontrol edin.


Basit bir çözüm olmasa da, sorunun özü derinlemesine analiz edilir! Oyla.
Jason Young

7

Yaygın bir PyTorch kuralı, modelleri .pt veya .pth dosya uzantısı kullanarak kaydetmektir.

Tüm Modeli Kaydet / Yükle Kaydet:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Yük:

Model sınıfı bir yerde tanımlanmalıdır

model = torch.load(PATH)
model.eval()

4

Modeli kaydetmek istiyorsanız ve eğitime daha sonra devam etmek istiyorsanız:

Tek GPU: Kaydet:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Yük:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Çoklu GPU: Kaydet

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Yük:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
Sitemizi kullandığınızda şunları okuyup anladığınızı kabul etmiş olursunuz: Çerez Politikası ve Gizlilik Politikası.
Licensed under cc by-sa 3.0 with attribution required.