Açıklamak için bazı örnek tensorflow kodu koydum (tam, çalışma kodu bu özünde ). Bu kod, bağladığınız kağıttaki 2. bölümün ilk kısmından kapsül ağını uygular:
N_REC_UNITS = 10
N_GEN_UNITS = 20
N_CAPSULES = 30
# input placeholders
img_input_flat = tf.placeholder(tf.float32, shape=(None, 784))
d_xy = tf.placeholder(tf.float32, shape=(None, 2))
# translate the image according to d_xy
img_input = tf.reshape(img_input_flat, (-1, 28, 28, 1))
trans_img = image.translate(img_input, d_xy)
flat_img = tf.layers.flatten(trans_img)
capsule_img_list = []
# build several capsules and store the generated output in a list
for i in range(N_CAPSULES):
# hidden recognition layer
h_rec = tf.layers.dense(flat_img, N_REC_UNITS, activation=tf.nn.relu)
# inferred xy values
xy = tf.layers.dense(h_rec, 2) + d_xy
# inferred probability of feature
p = tf.layers.dense(h_rec, 1, activation=tf.nn.sigmoid)
# hidden generative layer
h_gen = tf.layers.dense(xy, N_GEN_UNITS, activation=tf.nn.relu)
# the flattened generated image
cap_img = p*tf.layers.dense(h_gen, 784, activation=tf.nn.relu)
capsule_img_list.append(cap_img)
# combine the generated images
gen_img_stack = tf.stack(capsule_img_list, axis=1)
gen_img = tf.reduce_sum(gen_img_stack, axis=1)
Giriş pikselleri ile kapsüller arasındaki eşlemenin nasıl çalışması gerektiğini bilen var mı?
Bu ağ yapısına bağlıdır. Bu makaledeki ilk deney için (ve yukarıdaki kod), her kapsülün tüm girdi görüntüsünü içeren alıcı bir alanı vardır. Bu en basit düzenleme. Bu durumda, giriş görüntüsü ile her bir kapsülün ilk gizli katmanı arasında tamamen bağlı bir katmandır.
Alternatif olarak, kapsül alıcı alanlar, bu makaledeki sonraki deneylerde olduğu gibi, adımlarla CNN çekirdekleri gibi daha fazla düzenlenebilir.
Tanıma birimlerinde tam olarak ne olmalı?
Tanıma birimleri, her bir kapsülün sahip olduğu dahili bir temsildir. Her kapsül, bu iç temsili p
, kapsülün özelliğinin mevcut olma olasılığını xy
ve çıkarılan çeviri değerlerini hesaplamak için kullanır . Bu makaledeki Şekil 2, ağın xy
doğru kullanmayı öğrendiğinden emin olmak için bir kontroldür (öyle).
Nasıl eğitilmelidir? Her bağlantı arasında sadece standart sırt desteği var mı?
Özellikle, üretilen çıktı ile orijinal arasındaki benzerliği zorlayan bir kayıp kullanarak bir otomatik kodlayıcı olarak eğitmelisiniz. Ortalama kare hatası burada iyi çalışıyor. Bunun yanı sıra, evet, degrade inişini backprop ile yaymanız gerekecek.
loss = tf.losses.mean_squared_error(img_input_flat, gen_img)
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)