TensorFlow'da Değişken ve get_variable arasındaki fark


125

Bildiğim kadarıyla, Variablebir değişken oluşturmak için varsayılan işlemdir ve get_variableesas olarak ağırlık paylaşımı için kullanılır.

Bir yandan, bir değişkene ihtiyaç duyduğunuzda get_variableilkel Variableişlem yerine kullanmayı öneren bazı insanlar var . Öte yandan, get_variableTensorFlow'un resmi belgelerinde ve demolarında herhangi bir şekilde kullanıldığını görüyorum .

Bu nedenle, bu iki mekanizmanın nasıl doğru bir şekilde kullanılacağına dair bazı temel kuralları bilmek istiyorum. Herhangi bir "standart" ilke var mı?


6
get_variable yeni bir yoldur, Değişken eski yöntemdir (sonsuza kadar desteklenebilir) Lukasz'ın dediği gibi (Not: Değişken adı kapsamının çoğunu TF'de yazdı)
Yaroslav Bulatov

Yanıtlar:


90

Her zaman kullanmanızı öneririm tf.get_variable(...)- eğer değişkenleri herhangi bir zamanda paylaşmanız gerekirse, örneğin çoklu gpu ayarında (çoklu gpu CIFAR örneğine bakın) kodunuzu yeniden düzenlemenizi kolaylaştıracaktır. Bunun dezavantajı yok.

Saf tf.Variabledaha düşük seviyelidir; bir noktada tf.get_variable()mevcut değildi, bu yüzden bazı kodlar hala düşük seviyeli yolu kullanıyor.


5
Cevabınız için çok teşekkür ederim. Ama yine de nasıl değiştirileceği konusunda bir sorum var tf.Variableolan tf.get_variableher yerde. Bu, uyuşmuş bir dizi ile bir değişkeni başlatmak istediğimde, yaptığım gibi bunu yapmanın temiz ve verimli bir yolunu bulamıyorum tf.Variable. Bunu nasıl çözersiniz? Teşekkürler.
Lifu Huang

69

tf.Variable bir sınıftır ve tf.Variable'ı oluşturmanın tf.Variable.__init__ve dahil olmak üzere çeşitli yolları vardır tf.get_variable.

tf.Variable.__init__: İnitial_value ile yeni bir değişken oluşturur .

W = tf.Variable(<initial-value>, name=<optional-name>)

tf.get_variable: Bu parametrelerle var olan bir değişkeni alır veya yeni bir tane oluşturur. Başlatıcıyı da kullanabilirsiniz.

W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
       regularizer=None, trainable=True, collections=None)

Aşağıdaki gibi başlatıcıları kullanmak çok yararlıdır xavier_initializer:

W = tf.get_variable("W", shape=[784, 256],
       initializer=tf.contrib.layers.xavier_initializer())

Daha fazla bilgi burada .


Evet, Variableaslında kullanmaktan bahsediyorum __init__. Bu get_variablekadar kullanışlı olduğu için, TensorFlow kodlarının çoğunun Variableyerine neden kullanıldığını merak ediyorum get_variable. Aralarında seçim yaparken dikkate alınması gereken herhangi bir kural veya faktör var mı? Teşekkür ederim!
Lifu Huang

Belirli bir değere sahip olmak istiyorsanız, Değişken kullanmak basittir: x = tf.Variable (3).
Sung Kim

@SungKim normalde kullandığımızda tf.Variable()onu kesilmiş bir normal dağılımdan rastgele bir değer olarak başlatabiliriz. İşte benim örneğim w1 = tf.Variable(tf.truncated_normal([5, 50], stddev = 0.01), name = 'w1'). Bunun eşdeğeri ne olabilir? Kesilmiş bir normal istediğimi nasıl söylerim? Sadece yapmalı mıyım w1 = tf.get_variable(name = 'w1', shape = [5,50], initializer = tf.truncated_normal, regularizer = tf.nn.l2_loss)?
Euler_Salter

@Euler_Salter: tf.truncated_normal_initializer()İstenilen sonucu almak için kullanabilirsiniz .
Beta

46

Biri ve diğeri arasında iki temel fark bulabilirim:

  1. Birincisi, bu tf.Variableher zaman yeni bir değişken oluşturacaktır, oysa grafikten belirtilen parametrelere sahip mevcut bir değişkeni tf.get_variablealır ve yoksa yeni bir tane oluşturur.

  2. tf.Variable bir başlangıç ​​değerinin belirtilmesini gerektirir.

tf.get_variableYeniden kullanım kontrollerini gerçekleştirmek için işlevin adın önekini geçerli değişken kapsamıyla açıklığa kavuşturmak önemlidir . Örneğin:

with tf.variable_scope("one"):
    a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
    c = tf.get_variable("v", [1]) #c.name == "one/v:0"

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1]) #d.name == "two/v:0"
    e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"

assert(a is c)  #Assertion is true, they refer to the same object.
assert(a is d)  #AssertionError: they are different objects
assert(d is e)  #AssertionError: they are different objects

Son iddia hatası ilginçtir: Aynı kapsamda aynı ada sahip iki değişkenin aynı değişken olduğu varsayılır. Ancak değişkenlerin adlarını test ederseniz dve eTensorflow'un değişkenin adını değiştirdiğini fark edeceksiniz e:

d.name   #d.name == "two/v:0"
e.name   #e.name == "two/v_1:0"

Harika örnek! İlgili d.nameve e.nameben rastlamak sadece ettik tensör grafik adlandırma işlemi bu TensorFlow doc it açıklıyor:If the default graph already contained an operation named "answer", the TensorFlow would append "_1", "_2", and so on to the name, in order to make it unique.
Atlas7

2

Diğer bir fark, birinin ('variable_store',)koleksiyonda olması, ancak diğerinin olmamasıdır.

Lütfen kaynak koduna bakın :

def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store

Bunu göstermeme izin verin:

import tensorflow as tf
from tensorflow.python.framework import ops

embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32) 
embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024])

graph = tf.get_default_graph()
collections = graph.collections

for c in collections:
    stores = ops.get_collection(c)
    print('collection %s: ' % str(c))
    for k, store in enumerate(stores):
        try:
            print('\t%d: %s' % (k, str(store._vars)))
        except:
            print('\t%d: %s' % (k, str(store)))
    print('')

Çıktı:

collection ('__variable_store',): 0: {'word_embeddings_2': <tf.Variable 'word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}

Sitemizi kullandığınızda şunları okuyup anladığınızı kabul etmiş olursunuz: Çerez Politikası ve Gizlilik Politikası.
Licensed under cc by-sa 3.0 with attribution required.