Skip to content

Commit

Permalink
add lightGCN model and modify KGCN model
Browse files Browse the repository at this point in the history
  • Loading branch information
clingingsai committed Jul 18, 2023
1 parent 2e68609 commit 6a501a7
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions openhgnn/dataset/RecommendationDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch as th
import numpy as np
from . import BaseDataset, register_dataset
from dgl.data.utils import load_graphs
from dgl.data.utils import load_graphs, download
from scipy.sparse import csr_matrix
import scipy.sparse as sp
from .multigraph import MultiGraphDataset
Expand Down Expand Up @@ -55,7 +55,16 @@ class lightGCN_Recommendation(RecommendationDataset):
def __init__(self, dataset_name, *args, **kwargs):
super(RecommendationDataset, self).__init__(*args, **kwargs)

# train and test data
if dataset_name not in ['gowalla','yelp2018','amazon-book']:
raise KeyError('Dataset {} is not supported!'.format(dataset_name))
self.dataset_name=dataset_name

self.data_path=f'openhgnn/dataset/{self.dataset_name}'

if not os.path.exists(f"{self.data_path}/train.txt"):
self.download()

# test
self.mode_dict = {'train': 0, "test": 1}
self.mode = self.mode_dict['train']
self.n_user = 0
Expand Down Expand Up @@ -189,6 +198,27 @@ def getSparseGraph(self):
self.Graph = self.Graph.coalesce()
print("don't split the matrix")
return self.Graph

def download(self):
    """Fetch any missing split file of ``self.dataset_name`` into ``self.data_path``.

    For each of ``train.txt`` and ``test.txt`` that does not already exist
    under ``self.data_path``, the file is fetched from the LightGCN-PyTorch
    GitHub repository with the module-level ``download`` helper imported
    from ``dgl.data.utils`` (inside this method the bare name ``download``
    refers to that helper, not to this method).

    Raises
    ------
    SystemExit
        With exit status 1 when a file cannot be fetched, after printing
        instructions for placing the files manually.
    """
    prefix = 'https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/master/data'
    required_file = ['train.txt', 'test.txt']

    for filename in required_file:
        file_path = os.path.join(self.data_path, filename)
        if os.path.exists(file_path):
            # Already on disk: nothing to fetch for this split.
            continue

        url = f"{prefix}/{self.dataset_name}/{filename}"
        try:
            download(url, file_path)
        except Exception as e:
            # Only ordinary errors are handled here; KeyboardInterrupt and
            # SystemExit propagate untouched so the user can still abort.
            print("\n", e)
            print("\nNote! --- If you want to download the file, vpn is required ---")
            print("If you don't have a vpn, please download the dataset from here: https://github.com/gusye1234/LightGCN-PyTorch")
            print("\nAfter downloading the dataset, you need to store the files in the following path: ")
            # Build the hint from the directory that is actually checked,
            # with the separator of the current platform.
            for name in required_file:
                print(os.path.abspath(os.path.join(self.data_path, name)))
            # Non-zero status: the dataset is unavailable, so this is a failure.
            raise SystemExit(1) from e


@register_dataset('hin_recommendation')
Expand Down

0 comments on commit 6a501a7

Please sign in to comment.