Trả lời câu hỏi của riêng tôi ở đây, vì tôi hy vọng nó sẽ hữu ích cho một số độc giả.
Scikit-learn được thiết kế chủ yếu để đối phó với dữ liệu có cấu trúc vector. Do đó, nếu bạn muốn thực hiện việc truyền bá nhãn / lan truyền nhãn trên dữ liệu có cấu trúc biểu đồ, có lẽ bạn nên tự mình thực hiện lại phương pháp thay vì sử dụng giao diện Scikit.
Dưới đây là một triển khai Tuyên truyền Nhãn và Truyền bá Nhãn trong PyTorch.
Hai phương pháp tổng thể tuân theo các bước thuật toán giống nhau, với các biến thể về cách ma trận kề được chuẩn hóa và cách các nhãn được truyền bá ở mỗi bước. Do đó, hãy tạo một lớp cơ sở cho hai mô hình của chúng tôi.
from abc import abstractmethod
import torch
class BaseLabelPropagation:
"""Base class for label propagation models.
Parameters
----------
adj_matrix: torch.FloatTensor
Adjacency matrix of the graph.
"""
def __init__(self, adj_matrix):
self.norm_adj_matrix = self._normalize(adj_matrix)
self.n_nodes = adj_matrix.size(0)
self.one_hot_labels = None
self.n_classes = None
self.labeled_mask = None
self.predictions = None
@staticmethod
@abstractmethod
def _normalize(adj_matrix):
raise NotImplementedError("_normalize must be implemented")
@abstractmethod
def _propagate(self):
raise NotImplementedError("_propagate must be implemented")
def _one_hot_encode(self, labels):
# Get the number of classes
classes = torch.unique(labels)
classes = classes[classes != -1]
self.n_classes = classes.size(0)
# One-hot encode labeled data instances and zero rows corresponding to unlabeled instances
unlabeled_mask = (labels == -1)
labels = labels.clone() # defensive copying
labels[unlabeled_mask] = 0
self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
self.one_hot_labels[unlabeled_mask, 0] = 0
self.labeled_mask = ~unlabeled_mask
def fit(self, labels, max_iter, tol):
"""Fits a semi-supervised learning label propagation model.
labels: torch.LongTensor
Tensor of size n_nodes indicating the class number of each node.
Unlabeled nodes are denoted with -1.
max_iter: int
Maximum number of iterations allowed.
tol: float
Convergence tolerance: threshold to consider the system at steady state.
"""
self._one_hot_encode(labels)
self.predictions = self.one_hot_labels.clone()
prev_predictions = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
for i in range(max_iter):
# Stop iterations if the system is considered at a steady state
variation = torch.abs(self.predictions - prev_predictions).sum().item()
if variation < tol:
print(f"The method stopped after {i} iterations, variation={variation:.4f}.")
break
prev_predictions = self.predictions
self._propagate()
def predict(self):
return self.predictions
def predict_classes(self):
return self.predictions.max(dim=1).indices
Mô hình lấy đầu vào là ma trận kề của đồ thị cũng như nhãn của các nút. Các nhãn ở dạng một vectơ của một số nguyên cho biết số lớp của mỗi nút có -1 ở vị trí của các nút không được gắn nhãn.
Thuật toán tuyên truyền nhãn được trình bày dưới đây.
W : ma trận kề của đồ thị Tính ma trận độ chéo D theo Di i← ΣjWtôi j Khởi tạo Y^( 0 )← ( y1, ... , ytôi, 0 , 0 , ĐI , 0 ) Lặp đi lặp lại 1. Y^( t + 1 )← D- 1W Y^( t ) 2. Y^( t + 1 )tôi← Ytôi cho đến khi hội tụ đến Y^( ∞ ) Nhãn điểm xTôi bởi dấu hiệu của y^( ∞ )Tôi
Từ Xiaojin Zhu và Zoubin Ghahramani. Học hỏi từ dữ liệu được dán nhãn và không nhãn với tuyên truyền nhãn. Báo cáo kỹ thuật CMU-CALD-02-107, Đại học Carnegie Mellon, 2002
Chúng tôi nhận được thực hiện sau đây.
class LabelPropagation(BaseLabelPropagation):
def __init__(self, adj_matrix):
super().__init__(adj_matrix)
@staticmethod
def _normalize(adj_matrix):
"""Computes D^-1 * W"""
degs = adj_matrix.sum(dim=1)
degs[degs == 0] = 1 # avoid division by 0 error
return adj_matrix / degs[:, None]
def _propagate(self):
self.predictions = torch.matmul(self.norm_adj_matrix, self.predictions)
# Put back already known labels
self.predictions[self.labeled_mask] = self.one_hot_labels[self.labeled_mask]
def fit(self, labels, max_iter=1000, tol=1e-3):
super().fit(labels, max_iter, tol)
Thuật toán trải rộng nhãn là:
W : ma trận kề của đồ thị Tính ma trận độ chéo D theo Di i← ΣjWtôi j Tính toán đồ thị chuẩn hóa Laplacian L ← D- 1 / 2W D- 1 / 2 Khởi tạo Y^( 0 )← ( y1, ... , ytôi, 0 , 0 , ĐI , 0 ) Chọn một tham số alpha ∈ [ 0 , 1 ) Lặp lại Y^( t + 1 ) ← α L Y^( t )+ ( 1 - a ) Y^( 0 ) cho đến khi hội tụ đến Y^( ∞ ) Nhãn điểm xTôi bởi dấu hiệu của y^( ∞ )Tôi
Từ Dengyong Zhou, Olivier Bousquet, Thomas Navin Lal, Jason Weston, Bernhard Schoelkopf. Học tập với sự thống nhất địa phương và toàn cầu (2004)
Do đó, việc thực hiện là như sau.
class LabelSpreading(BaseLabelPropagation):
def __init__(self, adj_matrix):
super().__init__(adj_matrix)
self.alpha = None
@staticmethod
def _normalize(adj_matrix):
"""Computes D^-1/2 * W * D^-1/2"""
degs = adj_matrix.sum(dim=1)
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 1
return adj_matrix * norm[:, None] * norm[None, :]
def _propagate(self):
self.predictions = (
self.alpha * torch.matmul(self.norm_adj_matrix, self.predictions)
+ (1 - self.alpha) * self.one_hot_labels
)
def fit(self, labels, max_iter=1000, tol=1e-3, alpha=0.5):
"""
Parameters
----------
alpha: float
Clamping factor.
"""
self.alpha = alpha
super().fit(labels, max_iter, tol)
Bây giờ chúng ta hãy kiểm tra các mô hình lan truyền của chúng tôi trên dữ liệu tổng hợp. Để làm như vậy, chúng tôi chọn sử dụng biểu đồ thượng cổ .
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
# Create caveman graph
n_cliques = 4
size_cliques = 10
caveman_graph = nx.connected_caveman_graph(n_cliques, size_cliques)
adj_matrix = nx.adjacency_matrix(caveman_graph).toarray()
# Create labels
labels = np.full(n_cliques * size_cliques, -1.)
# Only one node per clique is labeled. Each clique belongs to a different class.
labels[0] = 0
labels[size_cliques] = 1
labels[size_cliques * 2] = 2
labels[size_cliques * 3] = 3
# Create input tensors
adj_matrix_t = torch.FloatTensor(adj_matrix)
labels_t = torch.LongTensor(labels)
# Learn with Label Propagation
label_propagation = LabelPropagation(adj_matrix_t)
label_propagation.fit(labels_t)
label_propagation_output_labels = label_propagation.predict_classes()
# Learn with Label Spreading
label_spreading = LabelSpreading(adj_matrix_t)
label_spreading.fit(labels_t, alpha=0.8)
label_spreading_output_labels = label_spreading.predict_classes()
# Plot graphs
color_map = {-1: "orange", 0: "blue", 1: "green", 2: "red", 3: "cyan"}
input_labels_colors = [color_map[l] for l in labels]
lprop_labels_colors = [color_map[l] for l in label_propagation_output_labels.numpy()]
lspread_labels_colors = [color_map[l] for l in label_spreading_output_labels.numpy()]
plt.figure(figsize=(14, 6))
ax1 = plt.subplot(1, 4, 1)
ax2 = plt.subplot(1, 4, 2)
ax3 = plt.subplot(1, 4, 3)
ax1.title.set_text("Raw data (4 classes)")
ax2.title.set_text("Label Propagation")
ax3.title.set_text("Label Spreading")
pos = nx.spring_layout(caveman_graph)
nx.draw(caveman_graph, ax=ax1, pos=pos, node_color=input_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax2, pos=pos, node_color=lprop_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax3, pos=pos, node_color=lspread_labels_colors, node_size=50)
# Legend
ax4 = plt.subplot(1, 4, 4)
ax4.axis("off")
legend_colors = ["orange", "blue", "green", "red", "cyan"]
legend_labels = ["unlabeled", "class 0", "class 1", "class 2", "class 3"]
dummy_legend = [ax4.plot([], [], ls='-', c=c)[0] for c in legend_colors]
plt.legend(dummy_legend, legend_labels)
plt.show()
Các mô hình được thực hiện hoạt động chính xác và cho phép phát hiện các cộng đồng trong biểu đồ.
Lưu ý: Các phương pháp lan truyền được trình bày có nghĩa là được sử dụng trên các đồ thị vô hướng.
Mã này có sẵn như là một máy tính xách tay Jupyter tương tác ở đây .