TensorFlow bir dosyaya kaydetme / bir dosyadan grafik yükleme


98

Şimdiye kadar topladığım kadarıyla, bir TensorFlow grafiğini bir dosyaya dökmenin ve ardından başka bir programa yüklemenin birkaç farklı yolu var, ancak nasıl çalıştıklarına dair net örnekler / bilgiler bulamadım. Zaten bildiğim şey şu:

  1. Modelin değişkenlerini bir kontrol noktası dosyasına (.ckpt) kaydedin tf.train.Saver()ve daha sonra geri yükleyin ( kaynak )
  2. Bir modeli .pb dosyasına kaydedin ve tf.train.write_graph()ve tf.import_graph_def()( kaynak ) kullanarak tekrar yükleyin.
  3. .Pb dosyasından bir model yükleyin, yeniden eğitin ve Bazel ( kaynak ) kullanarak modeli yeni bir .pb dosyasına aktarın
  4. Grafiği ve ağırlıkları birlikte kaydetmek için grafiği dondurun ( kaynak )
  5. as_graph_def()Modeli kaydetmek için kullanın ve ağırlıklar / değişkenler için bunları sabitlerle eşleyin ( kaynak )

Ancak, bu farklı yöntemlerle ilgili birkaç soruyu çözemedim:

  1. Kontrol noktası dosyalarıyla ilgili olarak, bir modelin yalnızca eğitilmiş ağırlıklarını mı kaydediyorlar? Kontrol noktası dosyaları yeni bir programa yüklenebilir ve modeli çalıştırmak için kullanılabilir mi, yoksa belirli bir zamanda / aşamada bir modeldeki ağırlıkları kaydetmenin bir yolu olarak mı hizmet ediyorlar?
  2. Bununla ilgili tf.train.write_graph()olarak, ağırlıklar / değişkenler de kaydediliyor mu?
  3. Bazel ile ilgili olarak, yeniden eğitim için yalnızca .pb dosyalarına kaydedebilir / buradan yükleyebilir mi? Bir grafiği .pb'ye dökmek için basit bir Bazel komutu var mı?
  4. Donma ile ilgili olarak, donmuş bir grafik kullanılarak yüklenebilir tf.import_graph_def()mi?
  5. TensorFlow için Android demosu, Google'ın Inception modelinde bir .pb dosyasından yüklenir. Kendi .pb dosyamı değiştirmek isteseydim, bunu nasıl yapardım? Herhangi bir yerel kodu / yöntemi değiştirmem gerekir mi?
  6. Genel olarak, tüm bu yöntemler arasındaki fark tam olarak nedir? Veya daha genel olarak, as_graph_def()/.ckpt/.pb arasındaki fark nedir?

Kısacası, aradığım şey hem bir grafiği (çeşitli işlemlerde olduğu gibi) hem de ağırlıklarını / değişkenlerini bir dosyaya kaydetmek için bir yöntemdir, bu daha sonra grafiği ve ağırlıkları başka bir programa yüklemek için kullanılabilir. , kullanım için (mutlaka devam etmek / yeniden eğitmek gerekmez).

Bu konuyla ilgili dokümantasyon çok basit değildir, bu nedenle herhangi bir cevap / bilgi çok takdir edilecektir.


2
En yeni / en eksiksiz API, size üçünü de aynı anda kaydetmenin bir yolunu sunan meta grafiktir - 1) grafik 2) parametre değerleri 3) koleksiyonlar: tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Yaroslav Bulatov

Yanıtlar:


80

TensorFlow'da bir modeli kaydetme sorununa yaklaşmanın birçok yolu vardır, bu da onu biraz kafa karıştırıcı hale getirebilir. Sırayla alt sorularınızın her birini almak:

  1. Kontrol noktası dosyaları (örneğin saver.save()bir tf.train.Savernesneye çağrı yapılarak üretilir ) yalnızca ağırlıkları ve aynı programda tanımlanan diğer değişkenleri içerir. Bunları başka bir programda kullanmak için tf.import_graph_def(), TensorFlow'a bu ağırlıklarla ne yapacağını söyleyen ilişkili grafik yapısını yeniden oluşturmanız gerekir (örneğin, yeniden oluşturmak için kodu çalıştırarak veya çağırarak ). Çağrının saver.save()aynı zamanda MetaGraphDefbir grafik içeren ve bir kontrol noktasından ağırlıkların bu grafikle nasıl ilişkilendirileceğinin ayrıntılarını içeren a içeren bir dosya oluşturduğuna dikkat edin. Daha fazla ayrıntı için eğiticiye bakın.

  2. tf.train.write_graph()yalnızca grafik yapısını yazar; ağırlıklar değil.

  3. Bazel'in TensorFlow grafiklerini okumak veya yazmakla ilgisi yoktur. (Belki de sorunuzu yanlış anladım: bir yorumda açıklığa kavuşturmaktan çekinmeyin.)

  4. Dondurulmuş bir grafik kullanılarak yüklenebilir tf.import_graph_def(). Bu durumda, ağırlıklar (tipik olarak) grafiğe gömülüdür, bu nedenle ayrı bir kontrol noktası yüklemeniz gerekmez.

  5. Ana değişiklik, modele beslenen tensörlerin adlarını ve modelden alınan tensörlerin adlarını güncellemektir. TensorFlow Android demosunda, bu , iletilen inputNameve outputNamedizelerine karşılık gelir TensorFlowClassifier.initializeTensorFlow().

  6. GraphDefTipik eğitim sürecinde değişmez programı yapıdır. Kontrol noktası, genellikle eğitim sürecinin her adımında değişen bir eğitim sürecinin durumunun anlık görüntüsüdür. Sonuç olarak, TensorFlow bu tür veriler için farklı depolama formatları kullanır ve düşük seviyeli API, bunları kaydetmek ve yüklemek için farklı yollar sunar. Gibi yüksek seviyeli kütüphaneler, MetaGraphDefkütüphaneler, keras ve skflow bu mekanizmalara yapı tasarruf ve bütün bir modeli sağlanması için daha fazla uygun yollar sunmaktır.


Bu , kaydedilen grafiği yükleyip sonra çalıştırabileceğinizi söylediğinde C ++ API belgelerinin yalan söylediği anlamına mı tf.train.write_graph()geliyor?
mnicky

2
C ++ API belgeleri yalan söylemez, ancak birkaç ayrıntı eksiktir. En önemli ayrıntı, GraphDefkaydedilene ek olarak tf.train.write_graph(), grafiği çalıştırırken beslemek ve getirmek istediğiniz tensörlerin adlarını da hatırlamanız gerektiğidir (yukarıdaki madde 5).
mrry

@mrry: Tensorflows DeepDream örneğini kullanmayı denedim. ama pb formatında önceden eğitilmiş modellere ihtiyacı var gibi görünüyor! Cifar10 örneğini çalıştırdım, ancak yalnızca kontrol noktaları oluşturuyor! Herhangi bir pb dosyası veya başka bir şey bulamadım! kontrol noktalarımı deepdream örneğinin kullandığı pb formatına nasıl dönüştürebilirim?
Rika

2
@ Coderx7 Kontrol noktası yalnızca ağırlıkları ve değişkenleri içerdiğinden ve grafiğin yapısı hakkında hiçbir şey bilmediğinden, bir .ckpt'yi .pb'ye dönüştüremeyeceğinizi düşünüyorum
davidivad

1
.pb dosyasını yüklemek ve ardından çalıştırmak için basit bir kod var mı?
Kong

1

Aşağıdaki kodu deneyebilirsiniz:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
Sitemizi kullandığınızda şunları okuyup anladığınızı kabul etmiş olursunuz: Çerez Politikası ve Gizlilik Politikası.
Licensed under cc by-sa 3.0 with attribution required.