diff --git a/openhgnn/config.ini b/openhgnn/config.ini
index 294d73d6..4af88b52 100644
--- a/openhgnn/config.ini
+++ b/openhgnn/config.ini
@@ -858,3 +858,22 @@
 mlp_inference_bool = 1
 neg_alpha = 0
 load_json = 0
+[HGCL]
+batch = 8192
+epochs = 400
+wu1 = 0.8
+wu2 = 0.2
+wi1 = 0.8
+wi2 = 0.2
+lr = 0.055
+topk = 10
+hide_dim = 32
+metareg = 0.15
+ssl_temp = 0.5
+ssl_ureg = 0.04
+ssl_ireg = 0.05
+ssl_reg = 0.01
+ssl_beta = 0.32
+rank = 3
+Layers = 2
+reg = 0.043
diff --git a/openhgnn/config.py b/openhgnn/config.py
index 2d743dc5..ff280085 100644
--- a/openhgnn/config.py
+++ b/openhgnn/config.py
@@ -889,6 +889,26 @@ def __init__(self, file_path, model, dataset, task, gpu):
             self.compress_ratio = conf.getfloat("SHGP", 'compress_ratio')
             self.cuda = conf.getint("SHGP", 'cuda')
 
+        elif model == 'HGCL':
+            self.lr = conf.getfloat("HGCL", "lr")
+            self.batch = conf.getint("HGCL", "batch")
+            self.wu1 = conf.getfloat("HGCL", "wu1")
+            self.wu2 = conf.getfloat("HGCL", "wu2")
+            self.wi1 = conf.getfloat("HGCL", "wi1")
+            self.wi2 = conf.getfloat("HGCL", "wi2")
+            self.epochs = conf.getint("HGCL", "epochs")
+            self.topk = conf.getint("HGCL", "topk")
+            self.hide_dim = conf.getint("HGCL", "hide_dim")
+            self.reg = conf.getfloat("HGCL", "reg")
+            self.metareg = conf.getfloat("HGCL", "metareg")
+            self.ssl_temp = conf.getfloat("HGCL", "ssl_temp")
+            self.ssl_ureg = conf.getfloat("HGCL", "ssl_ureg")
+            self.ssl_ireg = conf.getfloat("HGCL", "ssl_ireg")
+            self.ssl_reg = conf.getfloat("HGCL", "ssl_reg")
+            self.ssl_beta = conf.getfloat("HGCL", "ssl_beta")
+            self.rank = conf.getint("HGCL", "rank")
+            self.Layers = conf.getint("HGCL", "Layers")
+
         if hasattr(self, 'device'):
             self.device = th.device(self.device)
         elif gpu == -1:
diff --git a/openhgnn/dataset/HGCLDataset.py b/openhgnn/dataset/HGCLDataset.py
new file mode 100644
index 00000000..1ab1dd2f
--- /dev/null
+++ b/openhgnn/dataset/HGCLDataset.py
@@ -0,0 +1,114 @@
+import os
+import pickle
+
+import numpy as np
+import torch as t
+import dgl
+from dgl.data import DGLDataset
+from dgl.data.utils import download, extract_archive, load_graphs
+
+
+class HGCLDataset(DGLDataset):
+
+    _prefix = 'https://s3.cn-north-1.amazonaws.com.cn/dgl-data/'
+
+    def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
+        assert name in ['Epinions', 'CiaoDVD', 'Yelp']
+        self.data_path = './{}.zip'.format(name)
+        self.g_path = './{}/graph.bin'.format(name)
+        raw_dir = './'
+        url = self._prefix + 'dataset/{}.zip'.format(name)
+
+        super(HGCLDataset, self).__init__(name=name,
+                                          url=url,
+                                          raw_dir=raw_dir,
+                                          force_reload=force_reload,
+                                          verbose=verbose)
+
+    def create_graph(self):
+        '''
+        Helper that rebuilds graph.bin from the pickles released by the
+        authors; the packaged zip is expected to ship graph.bin already.
+        Raw dataset url: https://drive.google.com/drive/folders/1s6LGibPnal6gMld5t63aK4J7hnVkNeDs
+        '''
+        data_dir = os.path.join(self.raw_dir, self.name)
+        data_path = os.path.join(data_dir, 'data.pkl')
+        distance_path = os.path.join(data_dir, 'distanceMat_addIUUI.pkl')
+        ici_path = os.path.join(data_dir, 'ICI.pkl')
+
+        with open(data_path, 'rb') as fs:
+            data = pickle.load(fs)
+        with open(distance_path, 'rb') as fs:
+            distanceMat = pickle.load(fs)
+        with open(ici_path, 'rb') as fs:
+            itemMat = pickle.load(fs)
+
+        trainMat, testdata, _, categoryMat, _ = data
+        userNum, itemNum = trainMat.shape
+        userDistanceMat, itemDistanceMat, uiDistanceMat = distanceMat
+
+        # trainMat
+        trainMat_coo = trainMat.tocoo()
+        trainMat_i, trainMat_j, trainMat_data = trainMat_coo.row, trainMat_coo.col, trainMat_coo.data
+
+        # testdata
+        testdata = np.array(testdata)
+
+        # userDistanceMat
+        userDistanceMat_coo = userDistanceMat.tocoo()
+        userDistanceMat_i, userDistanceMat_j, userDistanceMat_data = userDistanceMat_coo.row, userDistanceMat_coo.col, userDistanceMat_coo.data
+
+        # itemMat
+        itemMat_coo = itemMat.tocoo()
+        itemMat_i, itemMat_j, itemMat_data = itemMat_coo.row, itemMat_coo.col, itemMat_coo.data
+
+        # uiDistanceMat
+        uiDistanceMat_coo = uiDistanceMat.tocoo()
+        uiDistanceMat_i, uiDistanceMat_j, uiDistanceMat_data = uiDistanceMat_coo.row, uiDistanceMat_coo.col, uiDistanceMat_coo.data
+
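+        # The heterograph below packs five relations into one DGL graph
+        # (schema as read from the tensors prepared above):
+        #   ('user', 'interact_train', 'item')      observed training ratings
+        #   ('user', 'distance', 'user')            user-user proximity
+        #   ('item', 'distance', 'item')            item-item (ICI) proximity
+        #   ('user+item', 'distance', 'user+item')  joint user-item proximity
+        #   ('user', 'interact_test', 'item')       held-out test interactions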
+        graph_data = {
+            ('user', 'interact_train', 'item'): (t.tensor(trainMat_i), t.tensor(trainMat_j)),
+            ('user', 'distance', 'user'): (t.tensor(userDistanceMat_i), t.tensor(userDistanceMat_j)),
+            ('item', 'distance', 'item'): (t.tensor(itemMat_i), t.tensor(itemMat_j)),
+            ('user+item', 'distance', 'user+item'): (t.tensor(uiDistanceMat_i), t.tensor(uiDistanceMat_j)),
+            ('user', 'interact_test', 'item'): (t.tensor(testdata[:, 0]), t.tensor(testdata[:, 1]))
+        }
+        g = dgl.heterograph(graph_data)
+        dgl.save_graphs(os.path.join(data_dir, 'graph.bin'), g)
+        self.g_path = os.path.join(data_dir, 'graph.bin')
+
+    def download(self):
+        # download the archive to local disk (skipped if it is already
+        # present), then unpack it under the raw directory
+        if not os.path.exists(self.data_path):  # pragma: no cover
+            download(self.url, path=self.raw_dir)
+            extract_archive(self.data_path, os.path.join(self.raw_dir, self.name))
+
+    def process(self):
+        # load the prebuilt heterograph; load_graphs returns a list of graphs
+        g, _ = load_graphs(self.g_path)
+        self._g = g
+
+    def __getitem__(self, idx):
+        # get one example by index
+        return self._g[idx]
+
+    def __len__(self):
+        # number of data examples
+        return 1
+
+    def save(self):
+        # save processed data to directory `self.save_path`
+        pass
+
+    def load(self):
+        # load processed data from directory `self.save_path`
+        pass
+
+    def has_cache(self):
+        # check whether there is processed data in `self.save_path`
+        pass
diff --git a/openhgnn/dataset/RecommendationDataset.py b/openhgnn/dataset/RecommendationDataset.py
index 6a27a4c9..19dcd00b 100644
--- a/openhgnn/dataset/RecommendationDataset.py
+++ b/openhgnn/dataset/RecommendationDataset.py
@@ -6,6 +6,7 @@
 from .multigraph import MultiGraphDataset
 from ..sampler.negative_sampler import Uniform_exclusive
 from . import AcademicDataset
+from .HGCLDataset import HGCLDataset
 
 #add more lib for KGAT
 import time
@@ -51,7 +52,29 @@ def get_train_data(self):
 
     def get_labels(self):
         return self.label
 
+@register_dataset('hgcl_recommendation')
+class HGCLRecommendation(RecommendationDataset):
+    def __init__(self, dataset_name, *args, **kwargs):
+        super(RecommendationDataset, self).__init__(*args, **kwargs)
+        dataset = HGCLDataset(name=dataset_name, raw_dir='')
+        self.g = dataset[0].long()
+
+    def get_split(self, validation=True):
+        # random 60/20/20 split over the rating edges
+        ratingsGraph = self.g
+        n_edges = ratingsGraph.num_edges()
+        random_int = th.randperm(n_edges)
+        train_idx = random_int[:int(n_edges * 0.6)]
+        val_idx = random_int[int(n_edges * 0.6):int(n_edges * 0.8)]
+        test_idx = random_int[int(n_edges * 0.8):]
+
+        return train_idx, val_idx, test_idx
+
+    def get_train_data(self):
+        pass
+
+    def get_labels(self):
+        return self.label
+
 @register_dataset('hin_recommendation')
 class HINRecommendation(RecommendationDataset):
     def __init__(self, dataset_name, *args, **kwargs):
diff --git a/openhgnn/dataset/__init__.py b/openhgnn/dataset/__init__.py
index 10e17947..c56f7412 100644
--- a/openhgnn/dataset/__init__.py
+++ b/openhgnn/dataset/__init__.py
@@ -92,6 +92,8 @@ def build_dataset(dataset, task, *args, **kwargs):
         _dataset = 'kgcn_recommendation'
     elif dataset in ['yelp4rec']:
         _dataset = 'hin_' + task
+    elif dataset in ['Epinions', 'CiaoDVD', 'Yelp']:
+        _dataset = 'hgcl_recommendation'
     elif dataset in ['dblp4Mg2vec_4', 'dblp4Mg2vec_5']:
         _dataset = 'hin_' + task
     elif dataset == 'demo':
diff --git a/openhgnn/experiment.py b/openhgnn/experiment.py
index ffd7d3bf..ba6b4f37 100644
--- a/openhgnn/experiment.py
+++ b/openhgnn/experiment.py
@@ -60,6 +60,7 @@ class Experiment(object):
         'MeiREC': 'MeiREC_trainer',
-        'KGAT': 'KGAT_trainer'
-        'SHGP': 'SHGP_trainer'
+        'KGAT': 'KGAT_trainer',
+        'SHGP': 'SHGP_trainer',
+        'HGCL': 'hgcltrainer',
     }
 
     immutable_params = ['model', 'dataset', 'task']
diff --git a/openhgnn/models/HGCL.py b/openhgnn/models/HGCL.py
new file mode 100644
index 00000000..c7e5d153
--- /dev/null
+++ b/openhgnn/models/HGCL.py
@@ -0,0 +1,317 @@
+import numpy as np
+import scipy.sparse as sp
+import torch
+import torch as t
+import torch.nn as nn
+import torch.nn.functional as F
+
+from openhgnn.models import register_model
+from openhgnn.models.base_model import BaseModel
+
+
+### HGCL
+@register_model('HGCL')
+class HGCL(BaseModel):
+    @classmethod
+    def build_model_from_args(cls, args, hg):
+        userNum = hg.number_of_nodes('user')
+        itemNum = hg.number_of_nodes('item')
+
+        # the three auxiliary-view adjacency matrices stored in the heterograph
+        userMat = hg.adj_external(etype=('user', 'distance', 'user'), scipy_fmt='csr')
+        itemMat = hg.adj_external(etype=('item', 'distance', 'item'), scipy_fmt='csr')
+        uiMat = hg.adj_external(etype=('user+item', 'distance', 'user+item'), scipy_fmt='csr')
+
+        return cls(userNum=userNum, itemNum=itemNum, userMat=userMat, itemMat=itemMat, uiMat=uiMat,
+                   hide_dim=args.hide_dim, Layers=args.Layers, rank=args.rank, wu1=args.wu1,
+                   wu2=args.wu2, wi1=args.wi1, wi2=args.wi2)
+
+    def __init__(self, userNum, itemNum, userMat, itemMat, uiMat, hide_dim, Layers, rank, wu1, wu2, wi1, wi2):
+        super(HGCL, self).__init__()
+        self.userNum = userNum
+        self.itemNum = itemNum
+        self.uuMat = userMat
+        self.iiMat = itemMat
+        self.uiMat = uiMat
+        self.hide_dim = hide_dim
+        self.LayerNums = Layers
+        self.wu1 = wu1
+        self.wu2 = wu2
+        self.wi1 = wi1
+        self.wi2 = wi2
+
+        # user-item block of the joint adjacency, kept as a sparse tensor
+        uimat = self.uiMat[: self.userNum, self.userNum:]
+        uimat_coo = uimat.tocoo()
+        i = torch.LongTensor(np.vstack((uimat_coo.row, uimat_coo.col)))
+        v = torch.FloatTensor(uimat_coo.data)
+        uimat1 = torch.sparse.FloatTensor(i, v, torch.Size(uimat_coo.shape))
+        self.uiadj = uimat1
+        self.iuadj = uimat1.transpose(0, 1)
+
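+        # Self-gating: each auxiliary view modulates the shared ID embeddings
+        # as E * sigmoid(E @ W + b) before encoding (see self_gatingu /
+        # self_gatingi below); one weight/bias pair per view.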
+        self.gating_weightub = nn.Parameter(torch.FloatTensor(1, hide_dim))
+        nn.init.xavier_normal_(self.gating_weightub.data)
+        self.gating_weightu = nn.Parameter(torch.FloatTensor(hide_dim, hide_dim))
+        nn.init.xavier_normal_(self.gating_weightu.data)
+        self.gating_weightib = nn.Parameter(torch.FloatTensor(1, hide_dim))
+        nn.init.xavier_normal_(self.gating_weightib.data)
+        self.gating_weighti = nn.Parameter(torch.FloatTensor(hide_dim, hide_dim))
+        nn.init.xavier_normal_(self.gating_weighti.data)
+
+        self.encoder = nn.ModuleList()
+        for i in range(0, self.LayerNums):
+            self.encoder.append(GCN_layer())
+        self.k = rank
+        k = self.k
+        # four MLPs produce the low-rank factors of the personalized
+        # transformation matrices (see metafortransform)
+        self.mlp = MLP(hide_dim, hide_dim * k, hide_dim // 2, hide_dim * k)
+        self.mlp1 = MLP(hide_dim, hide_dim * k, hide_dim // 2, hide_dim * k)
+        self.mlp2 = MLP(hide_dim, hide_dim * k, hide_dim // 2, hide_dim * k)
+        self.mlp3 = MLP(hide_dim, hide_dim * k, hide_dim // 2, hide_dim * k)
+        self.meta_netu = nn.Linear(hide_dim * 3, hide_dim, bias=True)
+        self.meta_neti = nn.Linear(hide_dim * 3, hide_dim, bias=True)
+
+        self.embedding_dict = nn.ModuleDict({
+            'uu_emb': torch.nn.Embedding(userNum, hide_dim).cuda(),
+            'ii_emb': torch.nn.Embedding(itemNum, hide_dim).cuda(),
+            'user_emb': torch.nn.Embedding(userNum, hide_dim).cuda(),
+            'item_emb': torch.nn.Embedding(itemNum, hide_dim).cuda(),
+        })
+
+    def init_weight(self, userNum, itemNum, hide_dim):
+        initializer = nn.init.xavier_uniform_
+        embedding_dict = nn.ParameterDict({
+            'user_emb': nn.Parameter(initializer(t.empty(userNum, hide_dim))),
+            'item_emb': nn.Parameter(initializer(t.empty(itemNum, hide_dim))),
+        })
+        return embedding_dict
+
+    def sparse_mx_to_torch_sparse_tensor(self, sparse_mx):
+        """Convert a scipy sparse matrix to a torch sparse tensor."""
+        if not isinstance(sparse_mx, sp.coo_matrix):
+            sparse_mx = sparse_mx.tocoo().astype(np.float32)
+        indices = torch.from_numpy(
+            np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
+        values = torch.from_numpy(sparse_mx.data).float()
+        shape = torch.Size(sparse_mx.shape)
+        return torch.sparse.FloatTensor(indices, values, shape)
+
+    def metaregular(self, em0, em, adj):
+        # mutual-information style regularizer: a node's embedding should
+        # agree with its local subgraph summary more than a shuffled
+        # (corrupted) embedding does
+        def row_column_shuffle(embedding):
+            corrupted_embedding = embedding[:, torch.randperm(embedding.shape[1])]
+            corrupted_embedding = corrupted_embedding[torch.randperm(embedding.shape[0])]
+            return corrupted_embedding
+
+        def score(x1, x2):
+            x1 = F.normalize(x1, p=2, dim=-1)
+            x2 = F.normalize(x2, p=2, dim=-1)
+            return torch.sum(torch.multiply(x1, x2), 1)
+
+        user_embeddings = em
+        Adj_Norm = t.from_numpy(np.sum(adj, axis=1)).float().cuda()
+        adj = self.sparse_mx_to_torch_sparse_tensor(adj)
+        edge_embeddings = torch.spmm(adj.cuda(), user_embeddings) / Adj_Norm
+        user_embeddings = em0
+        graph = torch.mean(edge_embeddings, 0)
+        pos = score(user_embeddings, graph)
+        neg1 = score(row_column_shuffle(user_embeddings), graph)
+        global_loss = torch.mean(-torch.log(torch.sigmoid(pos - neg1)))
+        return global_loss
+
+    def self_gatingu(self, em):
+        return torch.multiply(em, torch.sigmoid(torch.matmul(em, self.gating_weightu) + self.gating_weightub))
+
+    def self_gatingi(self, em):
+        return torch.multiply(em, torch.sigmoid(torch.matmul(em, self.gating_weighti) + self.gating_weightib))
+
+    def metafortransform(self, auxiembedu, targetembedu, auxiembedi, targetembedi):
+        # Neighbor information of the target node
+        uneighbor = t.matmul(self.uiadj.cuda(), self.ui_itemEmbedding)
+        ineighbor = t.matmul(self.iuadj.cuda(), self.ui_userEmbedding)
+
+        # Meta-knowledge extraction
+        tembedu = self.meta_netu(t.cat((auxiembedu, targetembedu, uneighbor), dim=1).detach())
+        tembedi = self.meta_neti(t.cat((auxiembedi, targetembedi, ineighbor), dim=1).detach())
+
+        """ Personalized transformation parameter matrix """
+        # Low-rank matrix decomposition
+        metau1 = self.mlp(tembedu).reshape(-1, self.hide_dim, self.k)  # d*k
+        metau2 = self.mlp1(tembedu).reshape(-1, self.k, self.hide_dim)  # k*d
+        metai1 = self.mlp2(tembedi).reshape(-1, self.hide_dim, self.k)  # d*k
+        metai2 = self.mlp3(tembedi).reshape(-1, self.k, self.hide_dim)  # k*d
+        meta_biasu = torch.mean(metau1, dim=0)
+        meta_biasu1 = torch.mean(metau2, dim=0)
+        meta_biasi = torch.mean(metai1, dim=0)
+        meta_biasi1 = torch.mean(metai2, dim=0)
+        low_weightu1 = F.softmax(metau1 + meta_biasu, dim=1)
+        low_weightu2 = F.softmax(metau2 + meta_biasu1, dim=1)
+        low_weighti1 = F.softmax(metai1 + meta_biasi, dim=1)
+        low_weighti2 = F.softmax(metai2 + meta_biasi1, dim=1)
+
+        # The learned matrices act as the weights of the transformation
+        # network: each pair of sums below is equal to a two-layer linear
+        # network (on the CiaoDVD and Yelp datasets a GELU activation is
+        # inserted between the two layers)
+        tembedus = t.sum(t.multiply(auxiembedu.unsqueeze(-1), low_weightu1), dim=1)
+        tembedus = t.sum(t.multiply(tembedus.unsqueeze(-1), low_weightu2), dim=1)
+        tembedis = t.sum(t.multiply(auxiembedi.unsqueeze(-1), low_weighti1), dim=1)
+        tembedis = t.sum(t.multiply(tembedis.unsqueeze(-1), low_weighti2), dim=1)
+        transfuEmbed = tembedus
+        transfiEmbed = tembedis
+        return transfuEmbed, transfiEmbed
+
+    def forward(self, iftraining, uid, iid, norm=1):
+        item_index = np.arange(0, self.itemNum)
+        user_index = np.arange(0, self.userNum)
+        ui_index = np.array(user_index.tolist() + [i + self.userNum for i in item_index])
+
+        # Initialize embeddings
+        userembed0 = self.embedding_dict['user_emb'].weight
+        itemembed0 = self.embedding_dict['item_emb'].weight
+        uu_embed0 = self.self_gatingu(userembed0)  # e0uu
+        ii_embed0 = self.self_gatingi(itemembed0)  # e0ii
+        self.ui_embeddings = t.cat([userembed0, itemembed0], 0)  # e0ui
+        self.all_user_embeddings = [uu_embed0]
+        self.all_item_embeddings = [ii_embed0]
+        self.all_ui_embeddings = [self.ui_embeddings]
+
+        # Encoder
+        for i in range(len(self.encoder)):
+            layer = self.encoder[i]
+            if i == 0:
+                userEmbeddings0 = layer(uu_embed0, self.uuMat, user_index)
+                itemEmbeddings0 = layer(ii_embed0, self.iiMat, item_index)
+                uiEmbeddings0 = layer(self.ui_embeddings, self.uiMat, ui_index)
+            else:
+                userEmbeddings0 = layer(userEmbeddings, self.uuMat, user_index)
+                itemEmbeddings0 = layer(itemEmbeddings, self.iiMat, item_index)
+                uiEmbeddings0 = layer(uiEmbeddings, self.uiMat, ui_index)
+
+            # Aggregate message features across the two related views in the
+            # middle layer, then feed them into the next layer
+            self.ui_userEmbedding0, self.ui_itemEmbedding0 = t.split(uiEmbeddings0, [self.userNum, self.itemNum])
+            userEd = (userEmbeddings0 + self.ui_userEmbedding0) / 2.0
+            itemEd = (itemEmbeddings0 + self.ui_itemEmbedding0) / 2.0
+            userEmbeddings = userEd
+            itemEmbeddings = itemEd
+            uiEmbeddings = torch.cat([userEd, itemEd], 0)
+            if norm == 1:
+                norm_embeddings = F.normalize(userEmbeddings0, p=2, dim=1)
+                self.all_user_embeddings += [norm_embeddings]
+                norm_embeddings = F.normalize(itemEmbeddings0, p=2, dim=1)
+                self.all_item_embeddings += [norm_embeddings]
+                norm_embeddings = F.normalize(uiEmbeddings0, p=2, dim=1)
+                self.all_ui_embeddings += [norm_embeddings]
+            else:
+                self.all_user_embeddings += [userEmbeddings]
+                self.all_item_embeddings += [itemEmbeddings]
+                self.all_ui_embeddings += [uiEmbeddings]
+
+        # Layer aggregation: mean over the embeddings collected at every layer
+        self.userEmbedding = t.stack(self.all_user_embeddings, dim=1)
+        self.userEmbedding = t.mean(self.userEmbedding, dim=1)
+        self.itemEmbedding = t.stack(self.all_item_embeddings, dim=1)
+        self.itemEmbedding = t.mean(self.itemEmbedding, dim=1)
+        self.uiEmbedding = t.stack(self.all_ui_embeddings, dim=1)
+        self.uiEmbedding = t.mean(self.uiEmbedding, dim=1)
+        self.ui_userEmbedding, self.ui_itemEmbedding = t.split(self.uiEmbedding, [self.userNum, self.itemNum])
+
+        # Personalized transformation of the auxiliary-domain features
+        metatsuembed, metatsiembed = self.metafortransform(self.userEmbedding, self.ui_userEmbedding,
+                                                           self.itemEmbedding, self.ui_itemEmbedding)
+        self.userEmbedding = self.userEmbedding + metatsuembed
+        self.itemEmbedding = self.itemEmbedding + metatsiembed
+
+        # Regularization: the constraint of transformed reasonableness
+        metaregloss = 0
+        if iftraining:
+            self.reg_lossu = self.metaregular(self.ui_userEmbedding[uid.cpu().numpy()], self.userEmbedding,
+                                              self.uuMat[uid.cpu().numpy()])
+            self.reg_lossi = self.metaregular(self.ui_itemEmbedding[iid.cpu().numpy()], self.itemEmbedding,
+                                              self.iiMat[iid.cpu().numpy()])
+            metaregloss = (self.reg_lossu + self.reg_lossi) / 2.0
+        return (self.userEmbedding, self.itemEmbedding,
+                self.wu1 * self.ui_userEmbedding + self.wu2 * self.userEmbedding,
+                self.wi1 * self.ui_itemEmbedding + self.wi2 * self.itemEmbedding,
+                self.ui_userEmbedding, self.ui_itemEmbedding, metaregloss)
+
+
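+# GCN_layer below implements parameter-free propagation with a symmetrically
+# normalized adjacency, H' = D^{-1/2} A D^{-1/2} H; only the rows selected by
+# `index` are updated, all other rows pass through unchanged.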
+class GCN_layer(nn.Module):
+    def __init__(self):
+        super(GCN_layer, self).__init__()
+
+    def sparse_mx_to_torch_sparse_tensor(self, sparse_mx):
+        """Convert a scipy sparse matrix to a torch sparse tensor."""
+        if not isinstance(sparse_mx, sp.coo_matrix):
+            sparse_mx = sparse_mx.tocoo().astype(np.float32)
+        indices = torch.from_numpy(
+            np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
+        values = torch.from_numpy(sparse_mx.data).float()
+        shape = torch.Size(sparse_mx.shape)
+        return torch.sparse.FloatTensor(indices, values, shape)
+
+    def normalize_adj(self, adj):
+        """Symmetrically normalize an adjacency matrix."""
+        adj = sp.coo_matrix(adj)
+        rowsum = np.array(adj.sum(1))
+        d_inv_sqrt = np.power(rowsum, -0.5).flatten()
+        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
+        d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
+        return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).tocoo()
+
+    def forward(self, features, Mat, index):
+        subset_Mat = self.normalize_adj(Mat)
+        subset_features = features
+        subset_sparse_tensor = self.sparse_mx_to_torch_sparse_tensor(subset_Mat).cuda()
+        out_features = torch.spmm(subset_sparse_tensor, subset_features)
+        new_features = torch.empty(features.shape).cuda()
+        new_features[index] = out_features
+        dif_index = np.setdiff1d(torch.arange(features.shape[0]), index)
+        new_features[dif_index] = features[dif_index]
+        return new_features
+
+
+class MLP(torch.nn.Module):
+    def __init__(self, input_dim, feature_dim, hidden_dim, output_dim,
+                 feature_pre=True, layer_num=2, dropout=True, **kwargs):
+        super(MLP, self).__init__()
+        self.feature_pre = feature_pre
+        self.layer_num = layer_num
+        self.dropout = dropout
+        if feature_pre:
+            self.linear_pre = nn.Linear(input_dim, feature_dim, bias=True)
+        else:
+            self.linear_first = nn.Linear(input_dim, hidden_dim)
+        self.linear_hidden = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for i in range(layer_num - 2)])
+        self.linear_out = nn.Linear(feature_dim, output_dim, bias=True)
+        # registered once here instead of being re-created on every forward
+        # pass, so the PReLU slope is actually trained
+        self.prelu = nn.PReLU()
+
+    def forward(self, data):
+        x = data
+        if self.feature_pre:
+            x = self.linear_pre(x)
+        x = self.prelu(x)
+        for i in range(self.layer_num - 2):
+            x = self.linear_hidden[i](x)
+            x = torch.tanh(x)
+            if self.dropout:
+                x = F.dropout(x, training=self.training)
+        x = self.linear_out(x)
+        x = F.normalize(x, p=2, dim=-1)
+        return x
diff --git a/openhgnn/models/__init__.py b/openhgnn/models/__init__.py
index 5ede897c..1c058237 100644
--- a/openhgnn/models/__init__.py
+++ b/openhgnn/models/__init__.py
@@ -109,8 +109,10 @@ def build_model_from_args(args, hg):
     'KGAT': 'openhgnn.models.KGAT',
     'SHGP': 'openhgnn.models.ATT_HGCN',
-    'DSSL': 'openhgnn.models.DSSL'
+    'DSSL': 'openhgnn.models.DSSL',
+    'HGCL': 'openhgnn.models.HGCL'
 }
 
+from .HGCL import HGCL
 from .CompGCN import CompGCN
 from .HetGNN import HetGNN
 from .RGCN import RGCN
diff --git a/openhgnn/output/HGCL/README.md b/openhgnn/output/HGCL/README.md
new file mode 100644
index 00000000..f632678f
--- /dev/null
+++ b/openhgnn/output/HGCL/README.md
@@ -0,0 +1,39 @@
+# HGCL
+
+- paper: [Heterogeneous Graph Contrastive Learning for Recommendation](https://arxiv.org/pdf/2303.00995.pdf)
+- code from the authors: [HGCL](https://github.com/HKUDS/HGCL)
+
+## How to run
+
+- Clone OpenHGNN, then run:
+
+  ```bash
+  python main.py -m HGCL -t recommendation -d Epinions -g 0
+  ```
+
+  For efficiency the trainer flow runs on GPU only.
+
+## Performance: Recommendation
+
+- Device: GPU, **GeForce GTX 1080Ti**
+- Datasets: Epinions, CiaoDVD, Yelp
+
+| Recommendation | HR                              | NDCG                            |
+|:--------------:|:-------------------------------:|:-------------------------------:|
+| Epinions       | paper: 83.67%  OpenHGNN: 82.15% | paper: 64.13%  OpenHGNN: 62.45% |
+| CiaoDVD        | paper: 73.76%  OpenHGNN: 72.93% | paper: 52.61%  OpenHGNN: 50.72% |
+| Yelp           | paper: 87.12%  OpenHGNN: 86.26% | paper: 63.10%  OpenHGNN: 60.58% |
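+
+The same run can be launched from Python through OpenHGNN's `Experiment` API
+(a minimal sketch; hyperparameters come from the `[HGCL]` section of
+`config.ini` registered by this PR):
+
+```python
+from openhgnn import Experiment
+
+# 'HGCL' with Epinions / CiaoDVD / Yelp is wired up by this PR;
+# gpu=0 picks the first CUDA device (the trainer flow is GPU-only).
+experiment = Experiment(model='HGCL', dataset='Epinions',
+                        task='recommendation', gpu=0)
+experiment.run()
+```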
+
+## More
+
+#### Contributor
+
+Siyuan Wen [GAMMA LAB]
+
+#### If you have any questions,
+
+Submit an issue or email [wsy0718@bupt.edu.cn](mailto:wsy0718@bupt.edu.cn).
diff --git a/openhgnn/trainerflow/__init__.py b/openhgnn/trainerflow/__init__.py
index a3616416..de5610bb 100644
--- a/openhgnn/trainerflow/__init__.py
+++ b/openhgnn/trainerflow/__init__.py
@@ -78,9 +78,10 @@ def build_flow(args, flow_name):
     'SHGP_trainer': 'openhgnn.trainerflow.SHGP_trainer',
     'KGAT_trainer': 'openhgnn.trainerflow.KGAT_trainer',
     'DSSL_trainer': 'openhgnn.trainerflow.DSSL_trainer',
-
+    'hgcltrainer': 'openhgnn.trainerflow.hgcl_trainer',
 }
 
+from .hgcl_trainer import HGCLtrainer
 from .node_classification import NodeClassification
 from .link_prediction import LinkPrediction
 from .recommendation import Recommendation
@@ -128,7 +129,8 @@
     'DHNE_trainer',
     'DiffMG_trainer',
     'MeiRECTrainer',
-    'KGAT_Trainer',
-    'DSSL_trainer'
+    'KGAT_Trainer',
+    'DSSL_trainer',
+    'HGCLtrainer'
 ]
 
 classes = __all__
diff --git a/openhgnn/trainerflow/hgcl_trainer.py b/openhgnn/trainerflow/hgcl_trainer.py
new file mode 100644
index 00000000..0d4274cb
--- /dev/null
+++ b/openhgnn/trainerflow/hgcl_trainer.py
@@ -0,0 +1,232 @@
+import datetime
+
+import numpy as np
+import torch as t
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torch.utils.data as data
+import torch.utils.data as dataloader
+
+from openhgnn.trainerflow.base_flow import BaseFlow
+from openhgnn.trainerflow import register_flow
+from openhgnn.models import build_model
+from ..tasks import build_task
+
+saveDefault = False
+logmsg = ''
+
+
+@register_flow('hgcltrainer')
+class HGCLtrainer(BaseFlow):
+    def __init__(self, args):
+        super(HGCLtrainer, self).__init__(args)
+        self.args = args
+
+        self.task = build_task(args)
+        self.hg = self.task.dataset.g
+        self.userNum = self.hg.number_of_nodes('user')
+        self.itemNum = self.hg.number_of_nodes('item')
+
+        self.model = build_model(self.model).build_model_from_args(args=self.args, hg=self.hg).to(self.device)
+        self.opt = optim.Adam(self.model.parameters(), lr=self.args.lr)
+
+        trainMat = self.hg.adj_external(etype=('user', 'interact_train', 'item'), scipy_fmt='coo')
+        testMat = self.hg.adj_external(etype=('user', 'interact_test', 'item'), scipy_fmt='coo')
+        train_u, train_v, train_r = trainMat.row, trainMat.col, trainMat.data
+        assert np.sum(train_r == 0) == 0
+        test_u, test_v = testMat.row, testMat.col
+        train_data = np.hstack((train_u.reshape(-1, 1), train_v.reshape(-1, 1))).tolist()
+        test_data = np.hstack((test_u.reshape(-1, 1), test_v.reshape(-1, 1))).tolist()
+        train_dataset = BPRData(train_data, self.itemNum, trainMat, 1, True)
+        test_dataset = BPRData(test_data, self.itemNum, trainMat, 0, False)
+        self.train_loader = dataloader.DataLoader(train_dataset, batch_size=self.args.batch, shuffle=True,
+                                                  num_workers=0)
+        self.test_loader = dataloader.DataLoader(test_dataset, batch_size=1024 * 1000, shuffle=False, num_workers=0)
+        self.train_losses = []
+        self.test_hr = []
+        self.test_ndcg = []
+
+    def predictModel(self, user, pos_i, neg_j, isTest=False):
+        # dot-product scores between user and item embeddings
+        if isTest:
+            pred_pos = t.sum(user * pos_i, dim=1)
+            return pred_pos
+        else:
+            pred_pos = t.sum(user * pos_i, dim=1)
+            pred_neg = t.sum(user * neg_j, dim=1)
+            return pred_pos, pred_neg
+
+    # Contrastive learning between the two views (an InfoNCE objective)
+    def ssl_loss(self, data1, data2, index):
+        index = t.unique(index)
+        embeddings1 = data1[index]
+        embeddings2 = data2[index]
+        norm_embeddings1 = F.normalize(embeddings1, p=2, dim=1)
+        norm_embeddings2 = F.normalize(embeddings2, p=2, dim=1)
+        pos_score = t.sum(t.mul(norm_embeddings1, norm_embeddings2), dim=1)
+        all_score = t.mm(norm_embeddings1, norm_embeddings2.T)
+        pos_score = t.exp(pos_score / self.args.ssl_temp)
+        all_score = t.sum(t.exp(all_score / self.args.ssl_temp), dim=1)
+        ssl_loss = -t.sum(t.log(pos_score / all_score)) / len(index)
+        return ssl_loss
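+
+    # The batch objective assembled below combines four terms (as read from
+    # the code, with coefficients from the [HGCL] section of config.ini):
+    #     loss = (bpr_loss + reg * ||E||^2) / batch
+    #            + ssl_beta * (ssl_ureg * ssl_user + ssl_ireg * ssl_item)
+    #            + metareg * metaregloss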
+    # Model training for one epoch over the BPR loader
+    def _mini_train_step(self):
+        epoch_loss = 0
+        self.train_loader.dataset.ng_sample()
+        step_num = 0  # count batch num
+        for user, item_i, item_j in self.train_loader:
+            user = user.long().cuda()
+            item_i = item_i.long().cuda()
+            item_j = item_j.long().cuda()
+            step_num += 1
+            self.istrain = True
+            itemindex = t.unique(t.cat((item_i, item_j)))
+            userindex = t.unique(user)
+            self.userEmbed, self.itemEmbed, self.ui_userEmbedall, self.ui_itemEmbedall, \
+                self.ui_userEmbed, self.ui_itemEmbed, metaregloss = self.model(
+                    self.istrain, userindex, itemindex, norm=1)
+
+            # contrastive learning of collaborative relations
+            ssl_loss_user = self.ssl_loss(self.ui_userEmbed, self.userEmbed, user)
+            ssl_loss_item = self.ssl_loss(self.ui_itemEmbed, self.itemEmbed, item_i)
+            ssl_loss = self.args.ssl_ureg * ssl_loss_user + self.args.ssl_ireg * ssl_loss_item
+
+            # prediction
+            pred_pos, pred_neg = self.predictModel(self.ui_userEmbedall[user], self.ui_itemEmbedall[item_i],
+                                                   self.ui_itemEmbedall[item_j])
+            bpr_loss = -nn.LogSigmoid()(pred_pos - pred_neg).sum()
+            epoch_loss += bpr_loss.item()
+            regLoss = (t.norm(self.ui_userEmbedall[user]) ** 2 + t.norm(self.ui_itemEmbedall[item_i]) ** 2 +
+                       t.norm(self.ui_itemEmbedall[item_j]) ** 2)
+            loss = ((bpr_loss + regLoss * self.args.reg) / self.args.batch) + \
+                   ssl_loss * self.args.ssl_beta + metaregloss * self.args.metareg
+
+            self.opt.zero_grad()
+            loss.backward()
+            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=20, norm_type=2)
+            self.opt.step()
+        return epoch_loss
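+
+    # Evaluation note (inferred from the slicing below, not stated in the PR):
+    # the test loader is expected to yield candidate items in blocks of 100
+    # per user, the ground-truth item first followed by 99 sampled negatives;
+    # HR@K and NDCG@K are computed per block.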
+    def test(self):
+        HR = []
+        NDCG = []
+
+        with t.no_grad():
+            uid = np.arange(0, self.userNum)
+            iid = np.arange(0, self.itemNum)
+            self.istrain = False
+            _, _, self.ui_userEmbed, self.ui_itemEmbed, _, _, _ = self.model(self.istrain, uid, iid, norm=1)
+            for test_u, test_i in self.test_loader:
+                test_u = test_u.long().cuda()
+                test_i = test_i.long().cuda()
+                pred = self.predictModel(self.ui_userEmbed[test_u], self.ui_itemEmbed[test_i], None, isTest=True)
+                batch = int(test_u.cpu().numpy().size / 100)
+                for i in range(batch):
+                    batch_scores = pred[i * 100:(i + 1) * 100].view(-1)
+                    _, indices = t.topk(batch_scores, self.args.topk)
+                    tmp_item_i = test_i[i * 100:(i + 1) * 100]
+                    recommends = t.take(tmp_item_i, indices).cpu().numpy().tolist()
+                    gt_item = tmp_item_i[0].item()
+                    HR.append(self.hit(gt_item, recommends))
+                    NDCG.append(self.ndcg(gt_item, recommends))
+        return np.mean(HR), np.mean(NDCG)
+
+    def hit(self, gt_item, pred_items):
+        if gt_item in pred_items:
+            return 1
+        return 0
+
+    def ndcg(self, gt_item, pred_items):
+        if gt_item in pred_items:
+            index = pred_items.index(gt_item)
+            return np.reciprocal(np.log2(index + 2))
+        return 0
+
+    def log(self, msg, save=None, oneline=False):
+        global logmsg
+        global saveDefault
+        time = datetime.datetime.now()
+        tem = '%s: %s' % (time, msg)
+        if save is not None:
+            if save:
+                logmsg += tem + '\n'
+        elif saveDefault:
+            logmsg += tem + '\n'
+        if oneline:
+            print(tem, end='\r')
+        else:
+            print(tem)
+
+    def _full_train_setp(self):
+        pass
+
+    def _test_step(self, split=None, logits=None):
+        pass
+
+    def train(self):
+        self.curEpoch = 0
+        best_hr = -1
+        best_ndcg = -1
+        best_epoch = -1
+        HR_lis = []
+        for e in range(self.args.epochs + 1):
+            self.curEpoch = e
+            # train
+            self.log("**************************************************************")
+            epoch_loss = self._mini_train_step()
+            self.train_losses.append(epoch_loss)
+            self.log("epoch %d/%d, epoch_loss=%.2f" % (e, self.args.epochs, epoch_loss))
+
+            # test
+            HR, NDCG = self.test()
+            self.test_hr.append(HR)
+            self.test_ndcg.append(NDCG)
+            self.log("epoch %d/%d, HR@%d=%.4f, NDCG@%d=%.4f" % (e, self.args.epochs,
+                                                                self.args.topk, HR, self.args.topk, NDCG))
+            if HR > best_hr:
+                best_hr, best_ndcg, best_epoch = HR, NDCG, e
+            HR_lis.append(HR)
+
+        print("*****************************")
+        self.log("best epoch = %d, HR=%.4f, NDCG=%.4f" % (best_epoch, best_hr, best_ndcg))
+        print("*****************************")
+        print(self.args)
+
+
+class BPRData(data.Dataset):
+    """BPR training data; negative labels matter only during training, so
+    they are drawn in ng_sample() before each epoch."""
+
+    def __init__(self, data, num_item, train_mat=None, num_ng=0, is_training=None):
+        super(BPRData, self).__init__()
+        self.data = np.array(data)
+        self.num_item = num_item
+        self.train_mat = train_mat
+        self.num_ng = num_ng
+        self.is_training = is_training
+
+    def ng_sample(self):
+        # draw one negative item per (user, item) pair, resampling any draw
+        # that collides with an observed training interaction
+        assert self.is_training, 'no need to sample when testing'
+        tmp_trainMat = self.train_mat.todok()
+        length = self.data.shape[0]
+        self.neg_data = np.random.randint(low=0, high=self.num_item, size=length)
+
+        for i in range(length):
+            uid = self.data[i][0]
+            iid = self.neg_data[i]
+            while (uid, iid) in tmp_trainMat:
+                iid = np.random.randint(low=0, high=self.num_item)
+            self.neg_data[i] = iid
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        user = self.data[idx][0]
+        item_i = self.data[idx][1]
+        if self.is_training:
+            item_j = self.neg_data[idx]
+            return user, item_i, item_j
+        else:
+            return user, item_i