Scikit-Learn Label Yayılımını grafik yapılandırılmış verilerde nasıl kullanabilirim?


11

Araştırmamın bir parçası olarak, grafik üzerinde etiket yayılımı yapmakla ilgileniyorum. Özellikle bu iki yöntemle ilgileniyorum:

Scikit-learn'un bunu yapmak için bir model sunduğunu gördüm . Bununla birlikte, bu modelin vektör yapılandırılmış verilere ( yani veri noktaları) uygulanması gerekmektedir.

Model, bir çekirdek kullanarak veri noktalarından bir yakınlık matrisi oluşturur ve ardından algoritmayı yapılandırılmış matris üzerinde çalıştırır. Benzerlik matrisi yerine grafiğimin bitişiklik matrisini doğrudan girebilmek istiyorum.

Bunu nasıl başaracağınız hakkında bir fikrin var mı? Ya da yukarıda bahsedilen iki yöntem için doğrudan grafik yapılandırılmış veriler üzerinde etiket yayılımını çalıştırmaya izin verecek Python kütüphanesi biliyor musunuz?

Yardımınız için şimdiden teşekkürler!


Yakınlık matrisini hesapladıktan sonra ne yaptığını görmek için Scikit-learn kaynak kodunu kontrol ettiniz mi? Belki de doğrudan bitişiklik matrisinize uygulamak için bu bölümden sonraki kodu "kopyalayabilir".
Tasos

Yorumun için teşekkürler! Yani, aslında, şu anda yapıyorum bu, ama benim ihtiyaçlarına uygun değiştirmek için kodun bazı bölümleri biraz şifreli. Korkarım bu parçaların yeniden yazılması hatalara yol açacaktır. Daha basit bir yöntem olduğunu umuyordum.
Thibaud Martinez

1
Github.com/scikit-learn/scikit-learn/blob/7389dba/sklearn/… 'daki kaynak kodu, uygulamaların _build_graph yöntemini geçersiz kılması gerektiğini söylüyor. Dolayısıyla, doğal olarak önceden hesaplanmış matrisi kabul eden türetilmiş bir sınıf oluşturmayı denemelisiniz.
mikalai

Yanıtlar:


2

Burada kendi sorumu cevaplıyorum, umarım bazı okuyucular için faydalı olacaktır.

Scikit-learn öncelikle vektör yapılandırılmış verilerle başa çıkmak için tasarlanmıştır. Bu nedenle, grafik yapılı veriler üzerinde etiket yayılımı / etiket dağıtımı yapmak istiyorsanız, Scikit arabirimini kullanmak yerine yöntemi kendiniz yeniden uygulamaktan daha iyi bir seçim yapabilirsiniz.

İşte PyTorch'ta Etiket Yayılımı ve Etiket Yayma uygulaması.

Genel olarak iki yöntem, bitişiklik matrisinin nasıl normalleştirildiği ve etiketlerin her adımda nasıl yayıldığı konusunda varyasyonlarla aynı algoritmik adımları takip eder. Bu nedenle, iki modelimiz için bir temel sınıf oluşturalım.

from abc import abstractmethod
import torch

class BaseLabelPropagation:
    """Base class for label propagation models.

    Parameters
    ----------
    adj_matrix: torch.FloatTensor
        Adjacency matrix of the graph.
    """
    def __init__(self, adj_matrix):
        self.norm_adj_matrix = self._normalize(adj_matrix)
        self.n_nodes = adj_matrix.size(0)
        self.one_hot_labels = None 
        self.n_classes = None
        self.labeled_mask = None
        self.predictions = None

    @staticmethod
    @abstractmethod
    def _normalize(adj_matrix):
        raise NotImplementedError("_normalize must be implemented")

    @abstractmethod
    def _propagate(self):
        raise NotImplementedError("_propagate must be implemented")

    def _one_hot_encode(self, labels):
        # Get the number of classes
        classes = torch.unique(labels)
        classes = classes[classes != -1]
        self.n_classes = classes.size(0)

        # One-hot encode labeled data instances and zero rows corresponding to unlabeled instances
        unlabeled_mask = (labels == -1)
        labels = labels.clone()  # defensive copying
        labels[unlabeled_mask] = 0
        self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
        self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
        self.one_hot_labels[unlabeled_mask, 0] = 0

        self.labeled_mask = ~unlabeled_mask

    def fit(self, labels, max_iter, tol):
        """Fits a semi-supervised learning label propagation model.

        labels: torch.LongTensor
            Tensor of size n_nodes indicating the class number of each node.
            Unlabeled nodes are denoted with -1.
        max_iter: int
            Maximum number of iterations allowed.
        tol: float
            Convergence tolerance: threshold to consider the system at steady state.
        """
        self._one_hot_encode(labels)

        self.predictions = self.one_hot_labels.clone()
        prev_predictions = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)

        for i in range(max_iter):
            # Stop iterations if the system is considered at a steady state
            variation = torch.abs(self.predictions - prev_predictions).sum().item()

            if variation < tol:
                print(f"The method stopped after {i} iterations, variation={variation:.4f}.")
                break

            prev_predictions = self.predictions
            self._propagate()

    def predict(self):
        return self.predictions

    def predict_classes(self):
        return self.predictions.max(dim=1).indices

Model, grafiğin bitişik matrisini ve düğümlerin etiketlerini girdi olarak alır. Etiketler, etiketlenmemiş düğümlerin pozisyonunda bir -1 ile her bir düğümün sınıf numarasını gösteren bir tamsayı vektörü biçimindedir.

Etiket Yayılım algoritması aşağıda sunulmuştur.

W: grafiğin bitişiklik matrisi Çapraz derece matrisini hesaplayın D tarafından DbenbenΣjWbenj Başlatma Y^(0)(y1,...,yl,0,0,...,0) yinelemek  1. Y^(t+1)D-1WY^(t) 2. Y^l(t+1)Yl yakınsamaya kadar Y^() Etiket noktası xben işareti ile y^ben()

Gönderen Xiaojin Zhu ve Zoubin Ghahramani. Etiket yayılımı ile etiketlenmiş ve etiketlenmemiş verilerden öğrenme. Teknik Rapor CMU-CALD-02-107, Carnegie Mellon Üniversitesi, 2002

Aşağıdaki uygulamayı alıyoruz.

class LabelPropagation(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1 * W"""
        degs = adj_matrix.sum(dim=1)
        degs[degs == 0] = 1  # avoid division by 0 error
        return adj_matrix / degs[:, None]

    def _propagate(self):
        self.predictions = torch.matmul(self.norm_adj_matrix, self.predictions)

        # Put back already known labels
        self.predictions[self.labeled_mask] = self.one_hot_labels[self.labeled_mask]

    def fit(self, labels, max_iter=1000, tol=1e-3):
        super().fit(labels, max_iter, tol)

Etiket Yayma algoritması:

W: grafiğin bitişiklik matrisi Çapraz derece matrisini hesaplayın D tarafından DbenbenΣjWbenj Laplacian normalleştirilmiş grafiğini hesapla LD-1/2WD-1/2 Başlatma Y^(0)(y1,...,yl,0,0,...,0) Bir parametre seçin α[0,1) yinelemek Y^(t+1)αLY^(t)+(1-α)Y^(0) yakınsamaya kadar Y^() Etiket noktası xben işareti ile y^ben()

Gönderen Dengyong Zhou, Olivier Bousquet'nin Thomas Navin Lal Jason Weston, Bernhard Schoelkopf. Yerel ve küresel tutarlılıkla öğrenme (2004)

Dolayısıyla uygulama aşağıdaki gibidir.

class LabelSpreading(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)
        self.alpha = None

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1/2 * W * D^-1/2"""
        degs = adj_matrix.sum(dim=1)
        norm = torch.pow(degs, -0.5)
        norm[torch.isinf(norm)] = 1
        return adj_matrix * norm[:, None] * norm[None, :]

    def _propagate(self):
        self.predictions = (
            self.alpha * torch.matmul(self.norm_adj_matrix, self.predictions)
            + (1 - self.alpha) * self.one_hot_labels
        )

    def fit(self, labels, max_iter=1000, tol=1e-3, alpha=0.5):
        """
        Parameters
        ----------
        alpha: float
            Clamping factor.
        """
        self.alpha = alpha
        super().fit(labels, max_iter, tol)

Şimdi yayılma modellerimizi sentetik veriler üzerinde test edelim. Bunu yapmak için bir mağara adamı grafiği seçiyoruz .

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# Create caveman graph
n_cliques = 4
size_cliques = 10
caveman_graph = nx.connected_caveman_graph(n_cliques, size_cliques)
adj_matrix = nx.adjacency_matrix(caveman_graph).toarray()

# Create labels
labels = np.full(n_cliques * size_cliques, -1.)

# Only one node per clique is labeled. Each clique belongs to a different class.
labels[0] = 0
labels[size_cliques] = 1
labels[size_cliques * 2] = 2
labels[size_cliques * 3] = 3

# Create input tensors
adj_matrix_t = torch.FloatTensor(adj_matrix)
labels_t = torch.LongTensor(labels)

# Learn with Label Propagation
label_propagation = LabelPropagation(adj_matrix_t)
label_propagation.fit(labels_t)
label_propagation_output_labels = label_propagation.predict_classes()

# Learn with Label Spreading
label_spreading = LabelSpreading(adj_matrix_t)
label_spreading.fit(labels_t, alpha=0.8)
label_spreading_output_labels = label_spreading.predict_classes()

# Plot graphs
color_map = {-1: "orange", 0: "blue", 1: "green", 2: "red", 3: "cyan"}
input_labels_colors = [color_map[l] for l in labels]
lprop_labels_colors = [color_map[l] for l in label_propagation_output_labels.numpy()]
lspread_labels_colors = [color_map[l] for l in label_spreading_output_labels.numpy()]

plt.figure(figsize=(14, 6))
ax1 = plt.subplot(1, 4, 1)
ax2 = plt.subplot(1, 4, 2)
ax3 = plt.subplot(1, 4, 3)

ax1.title.set_text("Raw data (4 classes)")
ax2.title.set_text("Label Propagation")
ax3.title.set_text("Label Spreading")

pos = nx.spring_layout(caveman_graph)
nx.draw(caveman_graph, ax=ax1, pos=pos, node_color=input_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax2, pos=pos, node_color=lprop_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax3, pos=pos, node_color=lspread_labels_colors, node_size=50)

# Legend
ax4 = plt.subplot(1, 4, 4)
ax4.axis("off")
legend_colors = ["orange", "blue", "green", "red", "cyan"]
legend_labels = ["unlabeled", "class 0", "class 1", "class 2", "class 3"]
dummy_legend = [ax4.plot([], [], ls='-', c=c)[0] for c in legend_colors]
plt.legend(dummy_legend, legend_labels)

plt.show()

Uygulanan modeller düzgün çalışır ve grafikteki toplulukları algılamaya izin verir.

Bir mağara adamı grafiğinde test edilen etiket yayılımı ve etiket yayma uygulamaları

Not: Sunulan yayılma yöntemleri, yönlendirilmemiş grafiklerde kullanılmak içindir.

Kod burada etkileşimli bir Jupyter not defteri olarak mevcuttur .

cilt

Sitemizi kullandığınızda şunları okuyup anladığınızı kabul etmiş olursunuz: Çerez Politikası ve Gizlilik Politikası.
Licensed under cc by-sa 3.0 with attribution required.