Làm cách nào để sử dụng Scikit-Learn Label Propagation trên dữ liệu có cấu trúc biểu đồ?


11

Là một phần của nghiên cứu của tôi, tôi quan tâm đến việc thực hiện tuyên truyền nhãn trên biểu đồ. Tôi đặc biệt quan tâm đến hai phương pháp đó:

Tôi thấy rằng scikit-learn cung cấp một mô hình để làm điều đó. Tuy nhiên, mô hình này được cho là sẽ được áp dụng trên dữ liệu có cấu trúc vector ( tức là điểm dữ liệu).

Mô hình xây dựng ma trận ái lực từ các điểm dữ liệu bằng hạt nhân, sau đó chạy thuật toán trên ma trận được xây dựng. Tôi muốn có thể nhập trực tiếp ma trận kề của biểu đồ của tôi thay cho ma trận tương tự.

Bất kỳ ý tưởng về làm thế nào để đạt được điều đó? Hoặc bạn có biết bất kỳ thư viện Python nào sẽ cho phép chạy lan truyền nhãn trực tiếp trên dữ liệu có cấu trúc biểu đồ cho hai phương thức đã nói ở trên không?

Cảm ơn trước sự giúp đỡ của bạn!


Bạn đã kiểm tra mã nguồn Scikit-learn để xem nó làm gì sau khi tính toán ma trận ái lực chưa? Có lẽ có thể "sao chép" mã sau phần đó để áp dụng trực tiếp vào ma trận kề của bạn.
TASos

Cám ơn bạn đã góp ý! Vì vậy, trên thực tế, đây là những gì tôi hiện đang làm, nhưng một số phần mã tôi cần sửa đổi cho phù hợp với nhu cầu của tôi có phần khó hiểu. Tôi sợ viết lại những phần đó sẽ dẫn đến lỗi. Tôi đã hy vọng tồn tại một phương pháp đơn giản hơn.
Thibaud Martinez

1
Mã nguồn tại github.com/scikit-learn/scikit-learn/blob/7389dba/sklearn/ trộm - nói rằng việc triển khai nên ghi đè phương thức _build_graph. Vì vậy, tự nhiên bạn nên thử tạo một lớp dẫn xuất chấp nhận ma trận tiền mã hóa.
mikalai

Câu trả lời:


2

Trả lời câu hỏi của riêng tôi ở đây, vì tôi hy vọng nó sẽ hữu ích cho một số độc giả.

Scikit-learn được thiết kế chủ yếu để đối phó với dữ liệu có cấu trúc vector. Do đó, nếu bạn muốn thực hiện việc truyền bá nhãn / lan truyền nhãn trên dữ liệu có cấu trúc biểu đồ, có lẽ bạn nên tự mình thực hiện lại phương pháp thay vì sử dụng giao diện Scikit.

Dưới đây là một triển khai Tuyên truyền Nhãn và Truyền bá Nhãn trong PyTorch.

Hai phương pháp tổng thể tuân theo các bước thuật toán giống nhau, với các biến thể về cách ma trận kề được chuẩn hóa và cách các nhãn được truyền bá ở mỗi bước. Do đó, hãy tạo một lớp cơ sở cho hai mô hình của chúng tôi.

from abc import abstractmethod
import torch

class BaseLabelPropagation:
    """Base class for label propagation models.

    Parameters
    ----------
    adj_matrix: torch.FloatTensor
        Adjacency matrix of the graph.
    """
    def __init__(self, adj_matrix):
        self.norm_adj_matrix = self._normalize(adj_matrix)
        self.n_nodes = adj_matrix.size(0)
        self.one_hot_labels = None 
        self.n_classes = None
        self.labeled_mask = None
        self.predictions = None

    @staticmethod
    @abstractmethod
    def _normalize(adj_matrix):
        raise NotImplementedError("_normalize must be implemented")

    @abstractmethod
    def _propagate(self):
        raise NotImplementedError("_propagate must be implemented")

    def _one_hot_encode(self, labels):
        # Get the number of classes
        classes = torch.unique(labels)
        classes = classes[classes != -1]
        self.n_classes = classes.size(0)

        # One-hot encode labeled data instances and zero rows corresponding to unlabeled instances
        unlabeled_mask = (labels == -1)
        labels = labels.clone()  # defensive copying
        labels[unlabeled_mask] = 0
        self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
        self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
        self.one_hot_labels[unlabeled_mask, 0] = 0

        self.labeled_mask = ~unlabeled_mask

    def fit(self, labels, max_iter, tol):
        """Fits a semi-supervised learning label propagation model.

        labels: torch.LongTensor
            Tensor of size n_nodes indicating the class number of each node.
            Unlabeled nodes are denoted with -1.
        max_iter: int
            Maximum number of iterations allowed.
        tol: float
            Convergence tolerance: threshold to consider the system at steady state.
        """
        self._one_hot_encode(labels)

        self.predictions = self.one_hot_labels.clone()
        prev_predictions = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)

        for i in range(max_iter):
            # Stop iterations if the system is considered at a steady state
            variation = torch.abs(self.predictions - prev_predictions).sum().item()

            if variation < tol:
                print(f"The method stopped after {i} iterations, variation={variation:.4f}.")
                break

            prev_predictions = self.predictions
            self._propagate()

    def predict(self):
        return self.predictions

    def predict_classes(self):
        return self.predictions.max(dim=1).indices

Mô hình lấy đầu vào là ma trận kề của đồ thị cũng như nhãn của các nút. Các nhãn ở dạng một vectơ của một số nguyên cho biết số lớp của mỗi nút có -1 ở vị trí của các nút không được gắn nhãn.

Thuật toán tuyên truyền nhãn được trình bày dưới đây.

W: ma trận kề của đồ thị Tính ma trận độ chéo D bởi DTôiTôiΣjWTôij Khởi tạo Y^(0)(y1,Giáo dục,ytôi,0,0,Giáo dục,0) Lặp đi lặp lại  1. Y^(t+1)D-1WY^(t) 2. Y^tôi(t+1)Ytôi cho đến khi hội tụ đến Y^() Điểm nhãn xTôi bởi dấu hiệu của y^Tôi()

Từ Xiaojin Zhu và Zoubin Ghahramani. Học hỏi từ dữ liệu được dán nhãn và không nhãn với tuyên truyền nhãn. Báo cáo kỹ thuật CMU-CALD-02-107, Đại học Carnegie Mellon, 2002

Chúng tôi nhận được thực hiện sau đây.

class LabelPropagation(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1 * W"""
        degs = adj_matrix.sum(dim=1)
        degs[degs == 0] = 1  # avoid division by 0 error
        return adj_matrix / degs[:, None]

    def _propagate(self):
        self.predictions = torch.matmul(self.norm_adj_matrix, self.predictions)

        # Put back already known labels
        self.predictions[self.labeled_mask] = self.one_hot_labels[self.labeled_mask]

    def fit(self, labels, max_iter=1000, tol=1e-3):
        super().fit(labels, max_iter, tol)

Thuật toán trải rộng nhãn là:

W: ma trận kề của đồ thị Tính ma trận độ chéo D bởi DTôiTôiΣjWTôij Tính toán đồ thị chuẩn hóa Laplacian LD-1/2WD-1/2 Khởi tạo Y^(0)(y1,Giáo dục,ytôi,0,0,Giáo dục,0) Chọn một tham số α[0,1) Lặp đi lặp lại Y^(t+1)αLY^(t)+(1-α)Y^(0) cho đến khi hội tụ đến Y^() Điểm nhãn xTôi bởi dấu hiệu của y^Tôi()

Từ Dengyong Zhou, Olivier Bousquet, Thomas Navin Lal, Jason Weston, Bernhard Schoelkopf. Học tập với sự thống nhất địa phương và toàn cầu (2004)

Do đó, việc thực hiện là như sau.

class LabelSpreading(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)
        self.alpha = None

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1/2 * W * D^-1/2"""
        degs = adj_matrix.sum(dim=1)
        norm = torch.pow(degs, -0.5)
        norm[torch.isinf(norm)] = 1
        return adj_matrix * norm[:, None] * norm[None, :]

    def _propagate(self):
        self.predictions = (
            self.alpha * torch.matmul(self.norm_adj_matrix, self.predictions)
            + (1 - self.alpha) * self.one_hot_labels
        )

    def fit(self, labels, max_iter=1000, tol=1e-3, alpha=0.5):
        """
        Parameters
        ----------
        alpha: float
            Clamping factor.
        """
        self.alpha = alpha
        super().fit(labels, max_iter, tol)

Bây giờ chúng ta hãy kiểm tra các mô hình lan truyền của chúng tôi trên dữ liệu tổng hợp. Để làm như vậy, chúng tôi chọn sử dụng biểu đồ thượng cổ .

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# Create caveman graph
n_cliques = 4
size_cliques = 10
caveman_graph = nx.connected_caveman_graph(n_cliques, size_cliques)
adj_matrix = nx.adjacency_matrix(caveman_graph).toarray()

# Create labels
labels = np.full(n_cliques * size_cliques, -1.)

# Only one node per clique is labeled. Each clique belongs to a different class.
labels[0] = 0
labels[size_cliques] = 1
labels[size_cliques * 2] = 2
labels[size_cliques * 3] = 3

# Create input tensors
adj_matrix_t = torch.FloatTensor(adj_matrix)
labels_t = torch.LongTensor(labels)

# Learn with Label Propagation
label_propagation = LabelPropagation(adj_matrix_t)
label_propagation.fit(labels_t)
label_propagation_output_labels = label_propagation.predict_classes()

# Learn with Label Spreading
label_spreading = LabelSpreading(adj_matrix_t)
label_spreading.fit(labels_t, alpha=0.8)
label_spreading_output_labels = label_spreading.predict_classes()

# Plot graphs
color_map = {-1: "orange", 0: "blue", 1: "green", 2: "red", 3: "cyan"}
input_labels_colors = [color_map[l] for l in labels]
lprop_labels_colors = [color_map[l] for l in label_propagation_output_labels.numpy()]
lspread_labels_colors = [color_map[l] for l in label_spreading_output_labels.numpy()]

plt.figure(figsize=(14, 6))
ax1 = plt.subplot(1, 4, 1)
ax2 = plt.subplot(1, 4, 2)
ax3 = plt.subplot(1, 4, 3)

ax1.title.set_text("Raw data (4 classes)")
ax2.title.set_text("Label Propagation")
ax3.title.set_text("Label Spreading")

pos = nx.spring_layout(caveman_graph)
nx.draw(caveman_graph, ax=ax1, pos=pos, node_color=input_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax2, pos=pos, node_color=lprop_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax3, pos=pos, node_color=lspread_labels_colors, node_size=50)

# Legend
ax4 = plt.subplot(1, 4, 4)
ax4.axis("off")
legend_colors = ["orange", "blue", "green", "red", "cyan"]
legend_labels = ["unlabeled", "class 0", "class 1", "class 2", "class 3"]
dummy_legend = [ax4.plot([], [], ls='-', c=c)[0] for c in legend_colors]
plt.legend(dummy_legend, legend_labels)

plt.show()

Các mô hình được thực hiện hoạt động chính xác và cho phép phát hiện các cộng đồng trong biểu đồ.

Việc truyền nhãn và triển khai trải nhãn được thử nghiệm trên biểu đồ thượng cổ

Lưu ý: Các phương pháp lan truyền được trình bày có nghĩa là được sử dụng trên các đồ thị vô hướng.

Mã này có sẵn như là một máy tính xách tay Jupyter tương tác ở đây .

Chất kết dính

Khi sử dụng trang web của chúng tôi, bạn xác nhận rằng bạn đã đọc và hiểu Chính sách cookieChính sách bảo mật của chúng tôi.
Licensed under cc by-sa 3.0 with attribution required.