From 008e02eb58237354a9bfdaafb6e98365ab185415 Mon Sep 17 00:00:00 2001 From: clingingsai <15010713603@qq.com> Date: Thu, 8 Jun 2023 18:24:48 +0800 Subject: [PATCH 1/3] add lightGCN model and modify KGCN model --- openhgnn/config.ini | 15 +- openhgnn/config.py | 16 +- openhgnn/dataset/RecommendationDataset.py | 145 ++++++++++ openhgnn/dataset/__init__.py | 2 + openhgnn/experiment.py | 1 + openhgnn/models/__init__.py | 5 +- openhgnn/models/lightGCN.py | 100 +++++++ openhgnn/trainerflow/__init__.py | 3 + openhgnn/trainerflow/kgcn_trainer.py | 5 +- openhgnn/trainerflow/lightGCN_trainer.py | 325 ++++++++++++++++++++++ 10 files changed, 610 insertions(+), 7 deletions(-) create mode 100644 openhgnn/models/lightGCN.py create mode 100644 openhgnn/trainerflow/lightGCN_trainer.py diff --git a/openhgnn/config.ini b/openhgnn/config.ini index 98998407..3664c3d5 100644 --- a/openhgnn/config.ini +++ b/openhgnn/config.ini @@ -381,7 +381,8 @@ n_neighbor = 8 aggregate = SUM n_relation = 60 n_user = 1872 -epoch_iter = 100 +# epoch_iter = 100 +max_epoch = 100 mini_batch_flag = True [HeGAN] @@ -752,4 +753,14 @@ train_epochs = 25 batch_num = 512 num_workers = 8 val_frequency = 1 -save_frequency = 2 \ No newline at end of file +save_frequency = 2 + +[lightGCN] +lr = 0.001 +weight_decay = 0.0001 +max_epoch = 1000 +batch_size = 1024 +embedding_size = 64 +num_layers = 3 +test_u_batch_size = 100 +topks = 20 diff --git a/openhgnn/config.py b/openhgnn/config.py index 1e4c1f2c..e9b4caa3 100644 --- a/openhgnn/config.py +++ b/openhgnn/config.py @@ -438,7 +438,8 @@ def __init__(self, file_path, model, dataset, task, gpu): self.aggregate = conf.get("KGCN", "aggregate") self.n_item = conf.getint("KGCN", "n_relation") self.n_user = conf.getint("KGCN", "n_user") - self.epoch_iter = conf.getint("KGCN", "epoch_iter") + # self.epoch_iter = conf.getint("KGCN", "epoch_iter") + self.max_epoch = conf.getint("KGCN", "max_epoch") elif self.model_name == 'general_HGNN': self.lr = conf.getfloat("general_HGNN", "lr") @@ -787,13 +788,24 @@ def __init__(self, file_path, model, dataset, task, gpu): # self.use_norm = conf.get("DiffMG", "use_norm") # self.out_nl = conf.get("DiffMG", "out_nl") - elif model == 'MeiREC': + elif self.model_name == 'MeiREC': self.lr = conf.getfloat("MeiREC", "lr") self.weight_decay = conf.getfloat("MeiREC", "weight_decay") self.vocab = conf.getint("MeiREC", "vocab_size") self.max_epoch = conf.getint("MeiREC", "train_epochs") self.batch_num = conf.getint("MeiREC", "batch_num") + elif self.model_name == 'lightGCN': + self.lr = conf.getfloat("lightGCN", "lr") + self.weight_decay = conf.getfloat("lightGCN", "weight_decay") + self.max_epoch = conf.getint("lightGCN", "max_epoch") + self.batch_size = conf.getint("lightGCN", "batch_size") + self.embedding_size = conf.getint("lightGCN", "embedding_size") + self.num_layers = conf.getint("lightGCN", "num_layers") + self.test_u_batch_size = conf.getint("lightGCN", "test_u_batch_size") + self.topks = conf.getint("lightGCN", "topks") + # self.alpha = conf.getfloat("lightGCN", "alpha") + if gpu == -1: self.device = th.device('cpu') elif gpu >= 0: diff --git a/openhgnn/dataset/RecommendationDataset.py b/openhgnn/dataset/RecommendationDataset.py index 92bd3974..34f92ce3 100644 --- a/openhgnn/dataset/RecommendationDataset.py +++ b/openhgnn/dataset/RecommendationDataset.py @@ -1,8 +1,11 @@ import os import dgl import torch as th +import numpy as np from . 
import BaseDataset, register_dataset from dgl.data.utils import load_graphs +from scipy.sparse import csr_matrix +import scipy.sparse as sp from .multigraph import MultiGraphDataset from ..sampler.negative_sampler import Uniform_exclusive from . import AcademicDataset @@ -46,6 +49,148 @@ def get_labels(self): return self.label +@register_dataset('lightGCN_recommendation') +class lightGCN_Recommendation(RecommendationDataset): + + def __init__(self, dataset_name, *args, **kwargs): + super(RecommendationDataset, self).__init__(*args, **kwargs) + + # train and test data + self.mode_dict = {'train': 0, "test": 1} + self.mode = self.mode_dict['train'] + self.n_user = 0 + self.m_item = 0 + path = './openhgnn/dataset/' + dataset_name + train_file = path + '/train.txt' + test_file = path + '/test.txt' + self.path = path + trainUniqueUsers, trainItem, trainUser = [], [], [] + testUniqueUsers, testItem, testUser = [], [], [] + self.traindataSize = 0 + self.testDataSize = 0 + + with open(train_file) as f: + for l in f.readlines(): + if len(l) > 0: + l = l.strip('\n').split(' ') + items = [int(i) for i in l[1:]] + uid = int(l[0]) + trainUniqueUsers.append(uid) + trainUser.extend([uid] * len(items)) + trainItem.extend(items) + + self.m_item = max(self.m_item, max(items)) + self.n_user = max(self.n_user, uid) + self.traindataSize += len(items) + self.trainUniqueUsers = np.array(trainUniqueUsers) + self.trainUser = np.array(trainUser) + self.trainItem = np.array(trainItem) + + with open(test_file) as f: + for l in f.readlines(): + if len(l) > 0: + l = l.strip('\n').split(' ') + items = [int(i) for i in l[1:]] + uid = int(l[0]) + testUniqueUsers.append(uid) + testUser.extend([uid] * len(items)) + testItem.extend(items) + self.m_item = max(self.m_item, max(items)) + self.n_user = max(self.n_user, uid) + self.testDataSize += len(items) + self.m_item += 1 + self.n_user += 1 + self.testUniqueUsers = np.array(testUniqueUsers) + self.testUser = np.array(testUser) + self.testItem = np.array(testItem) + + self.Graph = None + + # (users,items), bipartite graph + self.UserItemNet = csr_matrix((np.ones(len(self.trainUser)), (self.trainUser, self.trainItem)), + shape=(self.n_user, self.m_item)) + self.users_D = np.array(self.UserItemNet.sum(axis=1)).squeeze() + self.users_D[self.users_D == 0.] = 1 + self.items_D = np.array(self.UserItemNet.sum(axis=0)).squeeze() + self.items_D[self.items_D == 0.] = 1. 
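+        # users_D / items_D are per-user and per-item interaction counts
+        # (node degrees); zero entries are clamped to 1 so that any later
+        # division by degree is safe.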
+ # pre-calculate + self.allPos = self.getUserPosItems(list(range(self.n_user))) + self.testDict = self.__build_test() + + self.g = self.getSparseGraph() + + def get_split(self): + return self.g, [], [] + + def __build_test(self): + """ + return: + dict: {user: [items]} + """ + test_data = {} + for i, item in enumerate(self.testItem): + user = self.testUser[i] + if test_data.get(user): + test_data[user].append(item) + else: + test_data[user] = [item] + return test_data + + def getUserPosItems(self, users): + posItems = [] + for user in users: + posItems.append(self.UserItemNet[user].nonzero()[1]) + return posItems + + def _convert_sp_mat_to_sp_tensor(self, X): + coo = X.tocoo().astype(np.float32) + row = th.Tensor(coo.row).long() + col = th.Tensor(coo.col).long() + index = th.stack([row, col]) + data = th.FloatTensor(coo.data) + return th.sparse.FloatTensor(index, data, th.Size(coo.shape)) + + def getSparseGraph(self): + print("loading adjacency matrix") + if self.Graph is None: + try: + pre_adj_mat = sp.load_npz(self.path + '/s_pre_adj_mat.npz') + print("successfully loaded...") + norm_adj = pre_adj_mat + except: + print("generating adjacency matrix") + # s = time() + adj_mat = sp.dok_matrix((self.n_user + self.m_item, self.n_user + self.m_item), dtype=np.float32) + adj_mat = adj_mat.tolil() + R = self.UserItemNet.tolil() + adj_mat[:self.n_user, self.n_user:] = R + adj_mat[self.n_user:, :self.n_user] = R.T + adj_mat = adj_mat.todok() + # adj_mat = adj_mat + sp.eye(adj_mat.shape[0]) + + rowsum = np.array(adj_mat.sum(axis=1)) + d_inv = np.power(rowsum, -0.5).flatten() + d_inv[np.isinf(d_inv)] = 0. + d_mat = sp.diags(d_inv) + + norm_adj = d_mat.dot(adj_mat) + norm_adj = norm_adj.dot(d_mat) + norm_adj = norm_adj.tocsr() + # end = time() + # print(f"costing {end - s}s, saved norm_mat...") + sp.save_npz(self.path + '/s_pre_adj_mat.npz', norm_adj) + + # if self.split == True: + # self.Graph = self._split_A_hat(norm_adj) + # print("done split matrix") + # else: + self.Graph = self._convert_sp_mat_to_sp_tensor(norm_adj) + # self.Graph = self.Graph.coalesce().to(self.device) + self.Graph = self.Graph.coalesce() + print("don't split the matrix") + return self.Graph + + @register_dataset('hin_recommendation') class HINRecommendation(RecommendationDataset): def __init__(self, dataset_name, *args, **kwargs): diff --git a/openhgnn/dataset/__init__.py b/openhgnn/dataset/__init__.py index 657a35d4..2694fd84 100644 --- a/openhgnn/dataset/__init__.py +++ b/openhgnn/dataset/__init__.py @@ -89,6 +89,8 @@ def build_dataset(dataset, task, *args, **kwargs): _dataset = 'kg_link_prediction' elif dataset in ['LastFM4KGCN']: _dataset = 'kgcn_recommendation' + elif dataset in ['gowalla', 'yelp2018', 'amazon-book']: + _dataset = 'lightGCN_recommendation' elif dataset in ['yelp4rec']: _dataset = 'hin_' + task elif dataset in ['dblp4Mg2vec_4', 'dblp4Mg2vec_5']: diff --git a/openhgnn/experiment.py b/openhgnn/experiment.py index 1a161926..c2d438a9 100644 --- a/openhgnn/experiment.py +++ b/openhgnn/experiment.py @@ -57,6 +57,7 @@ class Experiment(object): 'DHNE': 'DHNE_trainer', 'DiffMG': 'DiffMG_trainer', 'MeiREC': 'MeiREC_trainer', + 'lightGCN': 'lightGCN_trainer', } immutable_params = ['model', 'dataset', 'task'] diff --git a/openhgnn/models/__init__.py b/openhgnn/models/__init__.py index 51c4505e..afa1fa9e 100644 --- a/openhgnn/models/__init__.py +++ b/openhgnn/models/__init__.py @@ -103,7 +103,8 @@ def build_model_from_args(args, hg): 'DHNE': 'openhgnn.models.DHNE', 'DiffMG': 'openhgnn.models.DiffMG', 'MeiREC': 
'openhgnn.models.MeiREC', - 'HGNN_AC': 'openhgnn.models.HGNN_AC' + 'HGNN_AC': 'openhgnn.models.HGNN_AC', + 'lightGCN': 'openhgnn.models.lightGCN', } from .CompGCN import CompGCN @@ -141,6 +142,7 @@ def build_model_from_args(args, hg): from .DiffMG import DiffMG from .MeiREC import MeiREC from .HGNN_AC import HGNN_AC +from .lightGCN import lightGCN __all__ = [ 'BaseModel', @@ -175,5 +177,6 @@ def build_model_from_args(args, hg): 'DHNE', 'DiffMG', 'MeiREC', + 'lightGCN', ] classes = __all__ diff --git a/openhgnn/models/lightGCN.py b/openhgnn/models/lightGCN.py new file mode 100644 index 00000000..64d2d421 --- /dev/null +++ b/openhgnn/models/lightGCN.py @@ -0,0 +1,100 @@ +import torch as th +import torch.nn as nn +import dgl.function as fn +from . import BaseModel, register_model +from torch import Tensor +from torch.nn import Embedding, ModuleList +from dgl.utils import expand_as_pair + +@register_model('lightGCN') +class lightGCN(BaseModel): + + @classmethod + def build_model_from_args(cls, args, g): + return cls(g, args) + + def __init__(self, g, args, **kwargs): + super(lightGCN, self).__init__() + + self.g = g['g'] + self.num_nodes = self.g.shape[0] + self.num_user = g['user_num'] + self.num_item = g['item_num'] + self.embedding_dim = args.embedding_size + self.num_layers = args.num_layers + # if args.alpha is None: + # self.alpha = 1. / (self.num_layers + 1) + self.alpha = 1. / (self.num_layers + 1) + if isinstance(self.alpha, Tensor): + assert self.alpha.size(0) == self.num_layers + 1 + else: + self.alpha = th.tensor([self.alpha] * (self.num_layers + 1)) + + self.embedding = Embedding(self.num_nodes, self.embedding_dim) + self.embedding_user = th.nn.Embedding( + num_embeddings=self.num_user, embedding_dim=self.embedding_dim) + self.embedding_item = th.nn.Embedding( + num_embeddings=self.num_item, embedding_dim=self.embedding_dim) + + nn.init.normal_(self.embedding_user.weight, std=0.1) + nn.init.normal_(self.embedding_item.weight, std=0.1) + self.f = nn.Sigmoid() + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + # th.nn.init.xavier_uniform_(self.embedding.weight) + th.nn.init.normal_(self.embedding.weight, std=0.1) + + def computer(self): + """ + propagate methods for lightGCN + """ + all_emb = self.embedding.weight + embs = [all_emb] + + g_droped = self.g + + for layer in range(self.num_layers): + + all_emb = th.sparse.mm(g_droped, all_emb) + embs.append(all_emb) + embs = th.stack(embs, dim=1) + + # print(embs.size()) + light_out = th.mean(embs, dim=1) + users, items = th.split(light_out, [self.num_user, self.num_item]) + return users, items + + def getUsersRating(self, users): + all_users, all_items = self.computer() + users_emb = all_users[users.long()] + items_emb = all_items + rating = self.f(th.matmul(users_emb, items_emb.t())) + return rating + + def getEmbedding(self, users, pos_items, neg_items): + all_users, all_items = self.computer() + users_emb = all_users[users] + pos_emb = all_items[pos_items] + neg_emb = all_items[neg_items] + users_emb_ego = self.embedding_user(users) + pos_emb_ego = self.embedding_item(pos_items) + neg_emb_ego = self.embedding_item(neg_items) + return users_emb, pos_emb, neg_emb, users_emb_ego, pos_emb_ego, neg_emb_ego + + def bpr_loss(self, users, pos, neg): + (users_emb, pos_emb, neg_emb, + userEmb0, posEmb0, negEmb0) = self.getEmbedding(users.long(), pos.long(), neg.long()) + reg_loss = (1 / 2) * (userEmb0.norm(2).pow(2) + + posEmb0.norm(2).pow(2) + + negEmb0.norm(2).pow(2)) / 
float(len(users)) + pos_scores = th.mul(users_emb, pos_emb) + pos_scores = th.sum(pos_scores, dim=1) + neg_scores = th.mul(users_emb, neg_emb) + neg_scores = th.sum(neg_scores, dim=1) + + loss = th.mean(th.nn.functional.softplus(neg_scores - pos_scores)) + + return loss, reg_loss diff --git a/openhgnn/trainerflow/__init__.py b/openhgnn/trainerflow/__init__.py index fed66eca..e0387e3d 100644 --- a/openhgnn/trainerflow/__init__.py +++ b/openhgnn/trainerflow/__init__.py @@ -73,6 +73,7 @@ def build_flow(args, flow_name): 'DHNE_trainer': 'openhgnn.trainerflow.DHNE_trainer', 'DiffMG_trainer': 'openhgnn.trainerflow.DiffMG_trainer', 'MeiREC_trainer': 'openhgnn.trainerflow.MeiRec_trainer', + 'lightGCN_trainer': 'openhgnn.trainerflow.lightGCN_trainer', } from .node_classification import NodeClassification @@ -96,6 +97,7 @@ def build_flow(args, flow_name): from .DiffMG_trainer import DiffMG_trainer from .MeiRec_trainer import MeiRECTrainer from .node_classification_ac import NodeClassificationAC +from .lightGCN_trainer import lightGCNTrainer __all__ = [ 'BaseFlow', @@ -119,5 +121,6 @@ def build_flow(args, flow_name): 'DHNE_trainer', 'DiffMG_trainer', 'MeiRECTrainer', + 'lightGCNTrainer', ] classes = __all__ diff --git a/openhgnn/trainerflow/kgcn_trainer.py b/openhgnn/trainerflow/kgcn_trainer.py index 5f7d7d8a..838bfe37 100644 --- a/openhgnn/trainerflow/kgcn_trainer.py +++ b/openhgnn/trainerflow/kgcn_trainer.py @@ -63,8 +63,8 @@ def preprocess(self, dataIndex): return def train(self): - epoch_iter = self.args.epoch_iter - for self.epoch in range(epoch_iter): + max_epoch = self.args.max_epoch + for self.epoch in range(max_epoch): self._mini_train_step() print('train_data:') self.evaluate(self.trainIndex) @@ -82,6 +82,7 @@ def _mini_train_step(self,): L = 0 import time t0 = time.time() + # length = len(self.dataloader_it) for block, inputData in self.dataloader_it: t1 =time.time() self.labels, self.scores = self.model(block, inputData) diff --git a/openhgnn/trainerflow/lightGCN_trainer.py b/openhgnn/trainerflow/lightGCN_trainer.py new file mode 100644 index 00000000..2c9faaec --- /dev/null +++ b/openhgnn/trainerflow/lightGCN_trainer.py @@ -0,0 +1,325 @@ +import random +import dgl +from tqdm import tqdm +import numpy as np +import torch as th +from time import time +from dgl.nn.functional import edge_softmax +from openhgnn.models import build_model +from dgl.dataloading import DataLoader, NeighborSampler, as_edge_prediction_sampler +import dgl.backend as F +from . 
import BaseFlow, register_flow +from ..tasks import build_task +from sklearn.metrics import f1_score, roc_auc_score +import torch.nn as nn +from torch import Tensor +from torch.nn.modules.loss import _Loss + + +@register_flow("lightGCN_trainer") +class lightGCNTrainer(BaseFlow): + """Demo flows.""" + + def __init__(self, args): + super(lightGCNTrainer, self).__init__(args) + + self.l2_weight = args.weight_decay + self.task = build_task(args) + self.train_dataloader = None + + self.g, _, _ = self.task.get_split() + + self.user_num = self.task.dataset.n_user + self.item_num = self.task.dataset.m_item + self.g_dict = {"g": self.hg, "user_num": self.user_num, "item_num": self.item_num} + self.f = nn.Sigmoid() + + self.model = build_model(self.model).build_model_from_args(self.args, self.g_dict).to(self.device) + # self.optimizer = th.optim.Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.weight_decay) + self.optimizer = th.optim.Adam(self.model.parameters(), lr=self.args.lr) + + def train(self): + + # for epoch in tqdm.tqdm(range(self.args.max_epoch)): + for epoch in range(self.args.max_epoch): + + if epoch % 10 == 0: + self.model.eval() + u_batch_size = self.args.test_u_batch_size # the batch size of users for testing + dataset = self.task.dataset + testDict = dataset.testDict # all testdata + + max_K = self.args.topks # update + results = {'precision': np.zeros(1), + 'recall': np.zeros(1), + 'ndcg': np.zeros(1)} + + with th.no_grad(): + users = list(testDict.keys()) # get test userID + try: + assert u_batch_size <= len(users) / 10 + except AssertionError: + + print(f"test_u_batch_size is too big for this dataset, try a small one {len(users) // 10}") + users_list = [] + rating_list = [] + groundTrue_list = [] + # auc_record = [] + # ratings = [] + total_batch = len(users) // u_batch_size + 1 + + for batch_users in self.minibatch(users, batch_size=u_batch_size): + allPos = dataset.getUserPosItems(batch_users) + groundTrue = [testDict[u] for u in batch_users] + batch_users_gpu = th.Tensor(batch_users).long() + batch_users_gpu = batch_users_gpu.to(self.device) + + x = self.model.embedding.weight + + all_users, all_items = self.model.computer() + users_emb = all_users[batch_users_gpu.long()] + items_emb = all_items + rating = self.f(th.matmul(users_emb, items_emb.t())) + # rating = th.matmul(users_emb, items_emb.t()) + + rating = rating.cpu() + exclude_index = [] + exclude_items = [] + for range_i, items in enumerate(allPos): + exclude_index.extend([range_i] * len(items)) + exclude_items.extend(items) + rating[exclude_index, exclude_items] = -(1 << 10) + + _, rating_K = th.topk(rating, k=max_K) # mak_K = 20 + + rating = rating.cpu().numpy() + + del rating + users_list.append(batch_users) + rating_list.append(rating_K.cpu()) + groundTrue_list.append(groundTrue) + assert total_batch == len(users_list) + X = zip(rating_list, groundTrue_list) + + pre_results = [] + for x in X: + pre_results.append(self.test_one_batch(x)) + scale = float(u_batch_size / len(users)) + for result in pre_results: + results['recall'] += result['recall'] + results['precision'] += result['precision'] + results['ndcg'] += result['ndcg'] + results['recall'] /= float(len(users)) + results['precision'] /= float(len(users)) + results['ndcg'] /= float(len(users)) + + print('[TEST]') + print(results) + + # for it, (input_nodes, positive_graph, negative_graph, blocks) in tqdm.tqdm(enumerate(self.train_dataloader)): + self.model.train() + + S = self.UniformSample_original_python() + + users = th.Tensor(S[:, 
0]).long() + posItems = th.Tensor(S[:, 1]).long() + negItems = th.Tensor(S[:, 2]).long() + + users = users.to(self.device) + posItems = posItems.to(self.device) + negItems = negItems.to(self.device) + users, posItems, negItems = self.shuffle(users, posItems, negItems) + total_batch = len(users) // self.args.batch_size + 1 + aver_loss = 0. + + for (batch_i, + (batch_users, + batch_pos, + batch_neg)) in enumerate(self.minibatch(users, posItems, negItems,batch_size=self.args.batch_size)): + + loss, reg_loss = self.model.bpr_loss(batch_users, batch_pos, batch_neg) + reg_loss = reg_loss * self.l2_weight + loss = loss + reg_loss + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + cri = loss.cpu().item() + aver_loss += cri + + # tqdm.set_postfix(f"Batch {batch_i}/{total_batch}") + + aver_loss = aver_loss / total_batch + + # print(epoch) + # print(aver_loss) + print(f'EPOCH[{epoch + 1}/{self.args.max_epoch}] loss:{aver_loss}') + + class BPRLoss(_Loss): + r"""The Bayesian Personalized Ranking (BPR) loss. + + The BPR loss is a pairwise loss that encourages the prediction of an + observed entry to be higher than its unobserved counterparts + (see `here `__). + + .. math:: + L_{\text{BPR}} = - \sum_{u=1}^{M} \sum_{i \in \mathcal{N}_u} + \sum_{j \not\in \mathcal{N}_u} \ln \sigma(\hat{y}_{ui} - \hat{y}_{uj}) + + \lambda \vert\vert \textbf{x}^{(0)} \vert\vert^2 + + where :math:`lambda` controls the :math:`L_2` regularization strength. + We compute the mean BPR loss for simplicity. + + Args: + lambda_reg (float, optional): The :math:`L_2` regularization strength + (default: 0). + **kwargs (optional): Additional arguments of the underlying + :class:`torch.nn.modules.loss._Loss` class. + """ + __constants__ = ['lambda_reg'] + lambda_reg: float + + def __init__(self, lambda_reg: float = 0, **kwargs): + super().__init__(None, None, "sum", **kwargs) + self.lambda_reg = lambda_reg + + def forward(self, positives: Tensor, negatives: Tensor, + parameters: Tensor = None) -> Tensor: + log_prob = nn.functional.logsigmoid(positives - negatives).mean() + # log_prob = - th.mean(th.nn.functional.softplus(negatives - positives)) + regularization = 0 + + if self.lambda_reg != 0: + regularization = self.lambda_reg * (1/2) * parameters.norm(2).pow(2) / float(parameters.shape[0]) + # print(-log_prob) + # + # print(regularization) + return -log_prob + regularization + + def test_one_batch(self, X): + sorted_items = X[0].numpy() + groundTrue = X[1] + r = self.getLabel(groundTrue, sorted_items) + pre, recall, ndcg = [], [], [] + # for k in self.args.topks: + k = self.args.topks + ret = self.recall(groundTrue, r, k) + pre.append(ret['precision']) + recall.append(ret['recall']) + ndcg.append(self.ndcg(groundTrue, r, k)) + return {'recall': np.array(recall), + 'precision': np.array(pre), + 'ndcg': np.array(ndcg)} + + def minibatch(self, *tensors, **kwargs): + + batch_size = kwargs.get('batch_size', self.args.batch_size) + + if len(tensors) == 1: + tensor = tensors[0] + for i in range(0, len(tensor), batch_size): + yield tensor[i:i + batch_size] + else: + for i in range(0, len(tensors[0]), batch_size): + yield tuple(x[i:i + batch_size] for x in tensors) + + def getLabel(self, test_data, pred_data): + r = [] + for i in range(len(test_data)): + groundTrue = test_data[i] + predictTopK = pred_data[i] + pred = list(map(lambda x: x in groundTrue, predictTopK)) + pred = np.array(pred).astype("float") + r.append(pred) + + return np.array(r).astype('float') + + def recall(self, test_data, r, k): + """ + test_data 
: a list of per-user ground-truth item lists, since users may have different numbers of positive items
+        pred_data : shape (test_batch, k) NOTE: pred_data should be pre-sorted
+        k : top-k
+        """
+        right_pred = r[:, :k].sum(1)
+        precis_n = k
+        recall_n = np.array([len(test_data[i]) for i in range(len(test_data))])
+        recall = np.sum(right_pred / recall_n)
+        precis = np.sum(right_pred) / precis_n
+        return {'recall': recall, 'precision': precis}
+
+    def ndcg(self, test_data, r, k):
+        """
+        Normalized Discounted Cumulative Gain
+        rel_i = 1 or 0, so 2^{rel_i} - 1 = 1 or 0
+        """
+        assert len(r) == len(test_data)
+        pred_data = r[:, :k]
+
+        test_matrix = np.zeros((len(pred_data), k))
+        for i, items in enumerate(test_data):
+            length = k if k <= len(items) else len(items)
+            test_matrix[i, :length] = 1
+        max_r = test_matrix
+        idcg = np.sum(max_r * 1. / np.log2(np.arange(2, k + 2)), axis=1)
+        dcg = pred_data * (1. / np.log2(np.arange(2, k + 2)))
+        dcg = np.sum(dcg, axis=1)
+        idcg[idcg == 0.] = 1.
+        ndcg = dcg / idcg
+        ndcg[np.isnan(ndcg)] = 0.
+        return np.sum(ndcg)
+
+    def UniformSample_original_python(self):
+        """
+        The original implementation of BPR sampling in LightGCN:
+        for each sampled user, draw one observed (positive) item and one
+        unobserved (negative) item.
+        :return:
+            np.array
+        """
+        total_start = time()
+        # dataset: BasicDataset
+        user_num = self.task.dataset.traindataSize
+        users = np.random.randint(0, self.task.dataset.n_user, user_num)
+        allPos = self.task.dataset.allPos
+        S = []
+        sample_time1 = 0.
+        sample_time2 = 0.
+        for i, user in enumerate(users):
+            start = time()
+            posForUser = allPos[user]
+            if len(posForUser) == 0:
+                continue
+            sample_time2 += time() - start
+            posindex = np.random.randint(0, len(posForUser))
+            positem = posForUser[posindex]
+            while True:
+                negitem = np.random.randint(0, self.task.dataset.m_item)
+                if negitem in posForUser:
+                    continue
+                else:
+                    break
+            S.append([user, positem, negitem])
+            end = time()
+            sample_time1 += end - start
+        total = time() - total_start
+        return np.array(S)
+
+    def shuffle(self, *arrays, **kwargs):
+
+        require_indices = kwargs.get('indices', False)
+
+        if len(set(len(x) for x in arrays)) != 1:
+            raise ValueError('All inputs to shuffle must have '
+                             'the same length.')
+
+        shuffle_indices = np.arange(len(arrays[0]))
+        np.random.shuffle(shuffle_indices)
+
+        if len(arrays) == 1:
+            result = arrays[0][shuffle_indices]
+        else:
+            result = tuple(x[shuffle_indices] for x in arrays)
+
+        if require_indices:
+            return result, shuffle_indices
+        else:
+            return result
\ No newline at end of file
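A note on the `getSparseGraph` hunk above: the construction amounts to the standard
LightGCN symmetric normalization A_hat = D^{-1/2} A D^{-1/2} over the stacked
user-item bipartite adjacency. A minimal, self-contained sketch of the same steps;
the toy matrix `R` and all names here are illustrative, not part of the patch:

```python
import numpy as np
import scipy.sparse as sp

# Toy 2-user x 3-item interaction matrix; rows are users, columns are items.
R = sp.csr_matrix(np.array([[1, 0, 1],
                            [0, 1, 0]], dtype=np.float32))
n_user, m_item = R.shape

# Stack R into the (n+m) x (n+m) bipartite adjacency A = [[0, R], [R^T, 0]].
A = sp.lil_matrix((n_user + m_item, n_user + m_item), dtype=np.float32)
A[:n_user, n_user:] = R
A[n_user:, :n_user] = R.T

# Symmetric normalization: A_hat = D^{-1/2} A D^{-1/2}.
rowsum = np.asarray(A.sum(axis=1)).flatten()
d_inv = np.power(rowsum, -0.5)
d_inv[np.isinf(d_inv)] = 0.  # guard isolated nodes, as the patch does
A_hat = sp.diags(d_inv).dot(A.tocsr()).dot(sp.diags(d_inv))
```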

From 2e686090b3a20fd30c179e3bc9b6876e404f75ee Mon Sep 17 00:00:00 2001
From: clingingsai <15010713603@qq.com>
Date: Mon, 10 Jul 2023 18:58:38 +0800
Subject: [PATCH 2/3] add lightGCN model and modify KGCN model

---
 openhgnn/models/lightGCN.py        | 32 +++++++++++++++
 openhgnn/output/lightGCN/README.md | 66 ++++++++++++++++++++++++++++++
 2 files changed, 98 insertions(+)
 create mode 100644 openhgnn/output/lightGCN/README.md

diff --git a/openhgnn/models/lightGCN.py b/openhgnn/models/lightGCN.py
index 64d2d421..6eff5b19 100644
--- a/openhgnn/models/lightGCN.py
+++ b/openhgnn/models/lightGCN.py
@@ -9,6 +9,38 @@
 @register_model('lightGCN')
 class lightGCN(BaseModel):
+    r"""
+    This module lightGCN was introduced in `LightGCN <https://dl.acm.org/doi/abs/10.1145/3397271.3401063>`__.
+
+    The difference from GCN is that lightGCN aggregates the entity representation and its neighborhood
+    representation into the entity's embedding without feature transformation or nonlinear activation.
+
+    The message function is defined as follows:
+
+    :math:`\mathbf{e}_u^{(k+1)}=\operatorname{AGG}\left(\mathbf{e}_u^{(k)},\left\{\mathbf{e}_i^{(k)}: i \in \mathcal{N}_u\right\}\right)`
+
+    AGG is an aggregation function, the core of graph convolution, that considers the k-th layer's
+    representation of the target node and its neighbor nodes.
+
+    LightGCN adopts the simple weighted sum aggregator and abandons the use of feature transformation
+    and nonlinear activation:
+
+    :math:`\mathbf{e}_u^{(k+1)}=\sum_{i \in \mathcal{N}_u} \frac{1}{\sqrt{\left|\mathcal{N}_u\right|} \sqrt{\left|\mathcal{N}_i\right|}} \mathbf{e}_i^{(k)}`
+
+    :math:`\mathbf{e}_i^{(k+1)}=\sum_{u \in \mathcal{N}_i} \frac{1}{\sqrt{\left|\mathcal{N}_i\right|} \sqrt{\left|\mathcal{N}_u\right|}} \mathbf{e}_u^{(k)}`
+
+    Unlike GCN, no transformation weight :math:`\mathrm{W}`, bias :math:`\mathrm{b}`, or nonlinearity
+    :math:`\sigma` appears in these updates; the representation of a node is bound up with its
+    neighbors purely by aggregation.
+
+    The model prediction is defined as the inner product of user and
+    item final representations:
+
+    :math:`\hat{y}_{u i}=\mathbf{e}_u^T \mathbf{e}_i`
+
+    Parameters
+    ----------
+    g : dict
+        Graph data for the model: the normalized user-item adjacency matrix
+        ('g') plus the numbers of users and items ('user_num', 'item_num')
+    args : Config
+        Model's config
+    """
     @classmethod
     def build_model_from_args(cls, args, g):
         return cls(g, args)

diff --git a/openhgnn/output/lightGCN/README.md b/openhgnn/output/lightGCN/README.md
new file mode 100644
index 00000000..f0e7a48b
--- /dev/null
+++ b/openhgnn/output/lightGCN/README.md
@@ -0,0 +1,66 @@
+# lightGCN [SIGIR 2020]
+
+- paper: [LightGCN: Simplifying and Powering Graph Convolution Network for Recommendation](https://dl.acm.org/doi/abs/10.1145/3397271.3401063)
+- Code from author: [LightGCN-PyTorch](https://github.com/gusye1234/LightGCN-PyTorch)
+
+## How to run
+
+- Clone the OpenHGNN-DGL repository, then run:
+
+  ```bash
+  python main.py -m lightGCN -d gowalla -t recommendation -g 0 --use_best_config
+  ```
+
+  If you do not have a GPU, set -g -1.
+
+  The datasets gowalla, yelp2018 and amazon-book are supported.
+
+## Performance: Recommendation
+
+- Device: GPU, **GeForce RTX 4070**
+
+| Recommendation |             recall              |              ndcg               |
+|:--------------:|:-------------------------------:|:-------------------------------:|
+|    gowalla     | paper: 0.1830  OpenHGNN: 0.1841 | paper: 0.1554  OpenHGNN: 0.1553 |
+|    yelp2018    | paper: 0.0649  OpenHGNN: 0.0648 | paper: 0.0530  OpenHGNN: 0.0532 |
+|  amazon-book   | paper: 0.0411  OpenHGNN: 0.0414 | paper: 0.0315  OpenHGNN: 0.0316 |
+
+## Dataset
+
+- We use the preprocessed datasets provided by [LightGCN-PyTorch](https://github.com/gusye1234/LightGCN-PyTorch).
+
+### Description
+
+- Dataset statistics
+
+  |              | gowalla | yelp2018 | amazon-book |
+  |:------------:|:-------:|:--------:|:-----------:|
+  |     User     |  29858  |  31668   |    52643    |
+  |     Item     |  40981  |  38048   |    91599    |
+  | Interactions | 1027370 | 1561406  |   2984108   |
+  |   Density    | 0.00084 | 0.00130  |   0.00062   |
+
+## TrainerFlow: Recommendation
+
+#### model
+
+- lightGCN
+  - lightGCN only aggregates neighbor embeddings; it does not use feature transformation or nonlinear activation.
+
+## More
+
+#### Contributor
+
+Saisai Geng [GAMMA LAB]
+
+#### If you have any questions,
+
+Submit an issue or email to [15010713603@qq.com](mailto:15010713603@qq.com).
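The layer combination described in the model docstring above is what `computer()`
implements: K rounds of sparse propagation with the normalized adjacency, followed
by a uniform average over the layer outputs. A condensed sketch under assumed
inputs (`A_hat`, `E0` and the function name are illustrative, not part of the patch):

```python
import torch as th

def light_propagate(A_hat: th.Tensor, E0: th.Tensor, num_layers: int = 3) -> th.Tensor:
    # E^{(k+1)} = A_hat @ E^{(k)}; the final embedding is the mean over all
    # layer outputs, i.e. alpha_k = 1 / (num_layers + 1) as in `computer()`.
    embs = [E0]
    for _ in range(num_layers):
        embs.append(th.sparse.mm(A_hat, embs[-1]))
    return th.stack(embs, dim=1).mean(dim=1)
```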

From 6a501a7d28f6ee992cc1de3a1e28ba5289513609 Mon Sep 17 00:00:00 2001
From: clingingsai <15010713603@qq.com>
Date: Tue, 18 Jul 2023 11:11:02 +0800
Subject: [PATCH 3/3] add lightGCN model and modify KGCN model

---
 openhgnn/dataset/RecommendationDataset.py | 34 +++++++++++++++++++++--
 1 file changed, 32 insertions(+), 2 deletions(-)

diff --git a/openhgnn/dataset/RecommendationDataset.py b/openhgnn/dataset/RecommendationDataset.py
index 34f92ce3..db48269c 100644
--- a/openhgnn/dataset/RecommendationDataset.py
+++ b/openhgnn/dataset/RecommendationDataset.py
@@ -3,7 +3,7 @@
 import torch as th
 import numpy as np
 from . import BaseDataset, register_dataset
-from dgl.data.utils import load_graphs
+from dgl.data.utils import load_graphs, download
 from scipy.sparse import csr_matrix
 import scipy.sparse as sp
 from .multigraph import MultiGraphDataset
@@ -55,7 +55,16 @@ class lightGCN_Recommendation(RecommendationDataset):
     def __init__(self, dataset_name, *args, **kwargs):
         super(RecommendationDataset, self).__init__(*args, **kwargs)

-        # train and test data
+        if dataset_name not in ['gowalla', 'yelp2018', 'amazon-book']:
+            raise KeyError('Dataset {} is not supported!'.format(dataset_name))
+        self.dataset_name = dataset_name
+
+        self.data_path = f'openhgnn/dataset/{self.dataset_name}'
+
+        if not os.path.exists(f"{self.data_path}/train.txt"):
+            self.download()
+
+        # load train and test interactions
         self.mode_dict = {'train': 0, "test": 1}
         self.mode = self.mode_dict['train']
         self.n_user = 0
@@ -189,6 +198,27 @@ def getSparseGraph(self):
         self.Graph = self.Graph.coalesce()
         print("don't split the matrix")
         return self.Graph
+
+    def download(self):
+        prefix = 'https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/master/data'
+
+        required_file = ['train.txt', 'test.txt']
+
+        for filename in required_file:
+            url = f"{prefix}/{self.dataset_name}/{filename}"
+            file_path = f"{self.data_path}/{filename}"
+            if not os.path.exists(file_path):
+                try:
+                    download(url, file_path)
+                except BaseException as e:
+                    print("\n", e)
+                    print("\nNote: downloading from raw.githubusercontent.com may require a proxy in some regions.")
+                    print("If the download keeps failing, fetch the dataset manually from "
+                          "https://github.com/gusye1234/LightGCN-PyTorch")
+                    print("and place the files at the following paths:")
+                    print(os.path.join(os.getcwd(), 'openhgnn', 'dataset', self.dataset_name, 'train.txt'))
+                    print(os.path.join(os.getcwd(), 'openhgnn', 'dataset', self.dataset_name, 'test.txt'))
+                    exit()


 @register_dataset('hin_recommendation')
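A note on the files fetched by `download()` above: the parser in
`lightGCN_Recommendation` expects one user per line, the user id followed by the
ids of the items that user interacted with, separated by spaces. For example
(hypothetical ids):

```
0 13 54 1002
1 7 29
```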