From 1862ac8c7f364ecc6a27bd6ce3c00efb0f035122 Mon Sep 17 00:00:00 2001
From: Ying-1106 <1007105680@qq.com>
Date: Wed, 24 Jan 2024 19:41:40 +0800
Subject: [PATCH 1/7] add dataset from Neo4j

---
 openhgnn/dataset/NodeClassificationDataset.py |   5 +-
 openhgnn/dataset/__init__.py                  |   2 +-
 openhgnn/dataset/academic_graph.py            | 108 ++++++++++++++++++
 3 files changed, 112 insertions(+), 3 deletions(-)

diff --git a/openhgnn/dataset/NodeClassificationDataset.py b/openhgnn/dataset/NodeClassificationDataset.py
index d1db91fb..a812af0e 100644
--- a/openhgnn/dataset/NodeClassificationDataset.py
+++ b/openhgnn/dataset/NodeClassificationDataset.py
@@ -10,7 +10,7 @@ from ogb.nodeproppred import DglNodePropPredDataset
 from . import load_acm_raw
 from . import BaseDataset, register_dataset
-from . import AcademicDataset, HGBDataset, OHGBDataset
+from . import AcademicDataset, HGBDataset, OHGBDataset, IMDB4MAGNN_Dataset
 from .utils import sparse_mx_to_torch_sparse_tensor
 from ..utils import add_reverse_edges

@@ -211,7 +211,8 @@ def load_HIN(self, name_dataset):
             self.in_dim = g.ndata['h'][category].shape[1]
         elif name_dataset == 'imdb4MAGNN':
-            dataset = AcademicDataset(name='imdb4MAGNN', raw_dir='')
+            # dataset = AcademicDataset(name='imdb4MAGNN', raw_dir='')
+            dataset = IMDB4MAGNN_Dataset(name='imdb4MAGNN')
             category = 'M'
             g = dataset[0].long()
             num_classes = 3
diff --git a/openhgnn/dataset/__init__.py b/openhgnn/dataset/__init__.py
index 8b6ed583..187be319 100644
--- a/openhgnn/dataset/__init__.py
+++ b/openhgnn/dataset/__init__.py
@@ -2,7 +2,7 @@
 from dgl.data import DGLDataset
 from .base_dataset import BaseDataset
 from .utils import load_acm, load_acm_raw, generate_random_hg
-from .academic_graph import AcademicDataset
+from .academic_graph import AcademicDataset, IMDB4MAGNN_Dataset
 from .hgb_dataset import HGBDataset
 from .ohgb_dataset import OHGBDataset
 from .gtn_dataset import *
diff --git a/openhgnn/dataset/academic_graph.py b/openhgnn/dataset/academic_graph.py
index 48c31245..1521c0db 100644
--- a/openhgnn/dataset/academic_graph.py
+++ b/openhgnn/dataset/academic_graph.py
@@ -9,6 +9,114 @@
 import torch as th


+class IMDB4MAGNN_Dataset(DGLDataset):
+
+    _prefix = 'https://s3.cn-north-1.amazonaws.com.cn/dgl-data/'
+
+    def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
+        assert name in ['imdb4MAGNN']
+
+        self._urls = {
+            'imdb4MAGNN': 'dataset/openhgnn/{}_std.zip'.format(name),
+        }
+
+        raw_dir = './openhgnn/dataset'
+        self.data_path = './openhgnn/dataset/' + name + '_std.zip'
+        self.g_path = './openhgnn/dataset/' + name + '/{}_std.pkl'.format(name)
+        # e.g. https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/imdb4MAGNN_std.zip
+        url = self._prefix + self._urls[name]
+        super(IMDB4MAGNN_Dataset, self).__init__(name=name,
+                                                 url=url,
+                                                 raw_dir=raw_dir,
+                                                 force_reload=force_reload,
+                                                 verbose=verbose)
+
+    def download(self):
+        # download the raw archive to local disk, unless it is already there
+        if not os.path.exists(self.data_path):
+            download(self.url, path=os.path.join(self.raw_dir))
+        extract_archive(self.data_path, os.path.join(self.raw_dir, self.name))
+
+    def process(self):
+        import dgl
+        import pickle
+        with open(self.g_path, 'rb') as file:
+            graph = pickle.load(file)
+
+        # convert every edge type such as 'A_M' into the canonical
+        # DGL triple ('A', 'A-M', 'M')
+        cano_edges = {}
+        for edge_type in graph['edge_index_dict'].keys():
+            src_type = edge_type[0]
+            dst_type = edge_type[-1]
+            edge_type_2 = src_type + '-' + dst_type
+            cano_edge_type = (src_type, edge_type_2, dst_type)
+            u, v = graph['edge_index_dict'][edge_type][0], graph['edge_index_dict'][edge_type][1]
+            cano_edges[cano_edge_type] = (u, v)
+
+        hg = dgl.heterograph(cano_edges)
+        for node_type in graph['X_dict'].keys():
+            hg.nodes[node_type].data['h'] = graph['X_dict'][node_type]
+            if node_type == 'M':
+                hg.nodes[node_type].data['labels'] = graph['Y_dict'][node_type]
+
+        # random 400/400/3478 train/val/test split over the 4278 movie nodes
+        import torch
+        num_nodes = 4278
+        random_indices = torch.randperm(num_nodes)
+        num_train = 400
+        num_val = 400
+        num_test = 3478
+        train_mask = torch.zeros(num_nodes, dtype=torch.int)
+        train_mask[random_indices[:num_train]] = 1
+        val_mask = torch.zeros(num_nodes, dtype=torch.int)
+        val_mask[random_indices[num_train:num_train + num_val]] = 1
+        test_mask = torch.zeros(num_nodes, dtype=torch.int)
+        test_mask[random_indices[num_train + num_val:]] = 1
+
+        # the three masks must be pairwise disjoint
+        assert torch.sum(train_mask * val_mask) == 0
+        assert torch.sum(train_mask * test_mask) == 0
+        assert torch.sum(val_mask * test_mask) == 0
+
+        hg.nodes['M'].data['train_mask'] = train_mask
+        hg.nodes['M'].data['val_mask'] = val_mask
+        hg.nodes['M'].data['test_mask'] = test_mask
+
+        self._g = hg
+
+    def __getitem__(self, idx):
+        # get one example by index
+        assert idx == 0, "This dataset has only one graph"
+        return self._g
+
+    def __len__(self):
+        # number of data examples
+        return 1
+
+    def save(self):
+        # save processed data to directory `self.save_path`
+        pass
+
+    def load(self):
+        # load processed data from directory `self.save_path`
+        pass
+
+    def has_cache(self):
+        # check whether there is processed data in `self.save_path`
+        pass
+
+
 class AcademicDataset(DGLDataset):

     _prefix = 'https://s3.cn-north-1.amazonaws.com.cn/dgl-data/'
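The `process()` method above hinges on converting edge names such as `'A_M'` from the pickle into DGL's canonical `(src_type, edge_type, dst_type)` triples. A minimal sketch of that conversion on toy indices; the node counts and edges here are invented for illustration, the real tensors come from `imdb4MAGNN_std.pkl`:

```python
# Toy illustration of the edge-type conversion in IMDB4MAGNN_Dataset.process().
import dgl
import torch

edge_index_dict = {
    'A_M': (torch.tensor([0, 1, 2]), torch.tensor([0, 0, 1])),  # actor -> movie
    'M_A': (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 2])),  # movie -> actor
}

cano_edges = {}
for edge_type, (u, v) in edge_index_dict.items():
    src_type, dst_type = edge_type[0], edge_type[-1]            # 'A', 'M'
    cano_edges[(src_type, src_type + '-' + dst_type, dst_type)] = (u, v)

hg = dgl.heterograph(cano_edges)
print(hg.canonical_etypes)  # [('A', 'A-M', 'M'), ('M', 'M-A', 'A')]
```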
From 17ea33ab822334dffc23a769c06cfe64d48319c3 Mon Sep 17 00:00:00 2001
From: Zhao Zihao <1007105680@qq.com>
Date: Mon, 5 Feb 2024 01:47:01 +0800
Subject: [PATCH 2/7] add database API

---
 main.py                                       |   1 +
 openhgnn/config.py                            |   5 +
 openhgnn/dataset/NodeClassificationDataset.py |  16 ++-
 openhgnn/dataset/__init__.py                  |   6 +-
 openhgnn/dataset/academic_graph.py            | 124 ++++++++++++++++++
 openhgnn/experiment.py                        |   2 +
 openhgnn/tasks/node_classification.py         |   3 +-
 7 files changed, 153 insertions(+), 4 deletions(-)

diff --git a/main.py b/main.py
index 17c07321..733a3282 100644
--- a/main.py
+++ b/main.py
@@ -16,6 +16,7 @@
     parser.add_argument('--gpu', '-g', default='-1', type=int, help='-1 means cpu')
     parser.add_argument('--use_best_config', action='store_true', help='will load utils.best_config')
     parser.add_argument('--load_from_pretrained', action='store_true', help='load model from the checkpoint')
+    parser.add_argument('--use_database', action='store_true', help='load the dataset from a graph database')
     args = parser.parse_args()

     experiment = Experiment(model=args.model, dataset=args.dataset, task=args.task, gpu=args.gpu,
diff --git a/openhgnn/config.py b/openhgnn/config.py
index c52fec9e..23e808bb 100644
--- a/openhgnn/config.py
+++ b/openhgnn/config.py
@@ -241,6 +241,11 @@ def __init__(self, file_path, model, dataset, task, gpu):
             self.ff_layer = conf.getint('NARS', 'ff_layer')

         elif self.model_name == 'MAGNN':
+
+            # graph-database credentials, to be filled in by the user
+            self.graph_address = ''
+            self.user_name = ''
+            self.password = ''
+
             self.lr = conf.getfloat("MAGNN", "learning_rate")
             self.weight_decay = conf.getfloat("MAGNN", "weight_decay")
             self.seed = conf.getint("MAGNN", "seed")
diff --git a/openhgnn/dataset/NodeClassificationDataset.py b/openhgnn/dataset/NodeClassificationDataset.py
index d1db91fb..be2e473a 100644
--- a/openhgnn/dataset/NodeClassificationDataset.py
+++ b/openhgnn/dataset/NodeClassificationDataset.py
@@ -10,7 +10,7 @@ from ogb.nodeproppred import DglNodePropPredDataset
 from . import load_acm_raw
 from . import BaseDataset, register_dataset
-from . import AcademicDataset, HGBDataset, OHGBDataset
+from . import AcademicDataset, HGBDataset, OHGBDataset, IMDB4MAGNN_Dataset
 from .utils import sparse_mx_to_torch_sparse_tensor
 from ..utils import add_reverse_edges

@@ -181,8 +181,15 @@ class HIN_NodeClassification(NodeClassificationDataset):

     def __init__(self, dataset_name, *args, **kwargs):
         super(HIN_NodeClassification, self).__init__(*args, **kwargs)
+
+        # keep a handle on the CLI arguments so load_HIN can check use_database
+        self.args = kwargs.get('args', None)
+
         self.g, self.category, self.num_classes = self.load_HIN(dataset_name)

     def load_HIN(self, name_dataset):
         if name_dataset == 'demo_graph':
             data_path = './openhgnn/dataset/demo_graph.bin'
@@ -211,7 +218,12 @@ def load_HIN(self, name_dataset):
             self.in_dim = g.ndata['h'][category].shape[1]

         elif name_dataset == 'imdb4MAGNN':
-            dataset = AcademicDataset(name='imdb4MAGNN', raw_dir='')
+            if self.args is not None and self.args.use_database:
+                dataset = IMDB4MAGNN_Dataset(name='imdb4MAGNN', args=self.args)
+            else:
+                dataset = AcademicDataset(name='imdb4MAGNN', raw_dir='')
+
             category = 'M'
             g = dataset[0].long()
             num_classes = 3
diff --git a/openhgnn/dataset/__init__.py b/openhgnn/dataset/__init__.py
index 8b6ed583..4e043207 100644
--- a/openhgnn/dataset/__init__.py
+++ b/openhgnn/dataset/__init__.py
@@ -74,9 +74,13 @@ def build_dataset(dataset, task, *args, **kwargs):
     if dataset in ['aifb', 'mutag', 'bgs', 'am']:
         _dataset = 'rdf_' + task
     elif dataset in ['acm4NSHE', 'acm4GTN', 'academic4HetGNN', 'acm_han', 'acm_han_raw', 'acm4HeCo', 'dblp',
-                     'dblp4MAGNN', 'imdb4MAGNN', 'imdb4GTN', 'acm4NARS', 'demo_graph', 'yelp4HeGAN', 'DoubanMovie',
+                     'dblp4MAGNN', 'imdb4GTN', 'acm4NARS', 'demo_graph', 'yelp4HeGAN', 'DoubanMovie',
                      'Book-Crossing', 'amazon4SLICE', 'MTWM', 'HNE-PubMed', 'HGBl-ACM', 'HGBl-DBLP', 'HGBl-IMDB',
                      'amazon', 'yelp4HGSL']:
         _dataset = 'hin_' + task
+    elif dataset in ['imdb4MAGNN']:
+        _dataset = 'hin_' + task
+        return DATASET_REGISTRY[_dataset](dataset, logger=kwargs['logger'], args=kwargs['args'])
     elif dataset in ohgbn_datasets + ohgbl_datasets:
         _dataset = 'ohgb_' + task
     elif dataset in ['ogbn-mag']:
diff --git a/openhgnn/dataset/academic_graph.py b/openhgnn/dataset/academic_graph.py
index 48c31245..a4b149a0 100644
--- a/openhgnn/dataset/academic_graph.py
+++ b/openhgnn/dataset/academic_graph.py
@@ -9,6 +9,130 @@
 import torch as th


+# fetch the dataset from a graph database through gdbi
+class IMDB4MAGNN_Dataset(DGLDataset):
+
+    def __init__(self, name, args, raw_dir=None, force_reload=False, verbose=True):
+        assert name in ['imdb4MAGNN']
+
+        self.args = args
+        super(IMDB4MAGNN_Dataset, self).__init__(name=name,
+                                                 url=None,
+                                                 raw_dir=None,
+                                                 force_reload=force_reload,
+                                                 verbose=verbose)
+
+    def download(self):
+        from gdbi import NodeExportConfig, EdgeExportConfig, Neo4jInterface, NebulaInterface
+
+        node_export_config = [
+            NodeExportConfig('A', ['attribute']),
+            NodeExportConfig('M', ['attribute'], ['label']),
+            NodeExportConfig('D', ['attribute'])
+        ]
+
+        edge_export_config = [
+            EdgeExportConfig('A_M', ('A', 'M')),
+            EdgeExportConfig('M_A', ('M', 'A')),
+            EdgeExportConfig('M_D', ('M', 'D')),
+            EdgeExportConfig('D_M', ('D', 'M'))
+        ]
+
+        # Neo4j backend
+        graph_database = Neo4jInterface()
+        # # NebulaGraph backend
+        # graph_database = NebulaInterface()
+
+        graph_address = self.args.graph_address
+        user_name = self.args.user_name
+        password = self.args.password
+
+        conn = graph_database.GraphDBConnection(graph_address, user_name, password)
+        self.graph = graph_database.get_graph(conn, 'imdb4MAGNN',
+                                              node_export_config, edge_export_config)

+    def process(self):
+        graph = self.graph
+
+        # convert every edge type such as 'A_M' into the canonical
+        # DGL triple ('A', 'A-M', 'M')
+        cano_edges = {}
+        for edge_type in graph['edge_index_dict'].keys():
+            src_type = edge_type[0]
+            dst_type = edge_type[-1]
+            edge_type_2 = src_type + '-' + dst_type
+            cano_edge_type = (src_type, edge_type_2, dst_type)
+            u, v = graph['edge_index_dict'][edge_type][0], graph['edge_index_dict'][edge_type][1]
+            cano_edges[cano_edge_type] = (u, v)
+
+        hg = dgl.heterograph(cano_edges)
+
+        for node_type in graph['X_dict'].keys():
+            hg.nodes[node_type].data['h'] = graph['X_dict'][node_type]
+            if node_type == 'M':
+                hg.nodes[node_type].data['labels'] = graph['Y_dict'][node_type]
+
+        # random 400/400/3478 train/val/test split over the 4278 movie nodes
+        import torch
+        num_nodes = 4278
+        random_indices = torch.randperm(num_nodes)
+        num_train = 400
+        num_val = 400
+        num_test = 3478
+        train_mask = torch.zeros(num_nodes, dtype=torch.int)
+        train_mask[random_indices[:num_train]] = 1
+        val_mask = torch.zeros(num_nodes, dtype=torch.int)
+        val_mask[random_indices[num_train:num_train + num_val]] = 1
+        test_mask = torch.zeros(num_nodes, dtype=torch.int)
+        test_mask[random_indices[num_train + num_val:]] = 1
+
+        # the three masks must be pairwise disjoint
+        assert torch.sum(train_mask * val_mask) == 0
+        assert torch.sum(train_mask * test_mask) == 0
+        assert torch.sum(val_mask * test_mask) == 0
+
+        hg.nodes['M'].data['train_mask'] = train_mask
+        hg.nodes['M'].data['val_mask'] = val_mask
+        hg.nodes['M'].data['test_mask'] = test_mask
+
+        self._g = hg
+
+    def __getitem__(self, idx):
+        # get one example by index
+        assert idx == 0, "This dataset has only one graph"
+        return self._g
+
+    def __len__(self):
+        return 1
+
+    def save(self):
+        pass
+
+    def load(self):
+        pass
+
+    def has_cache(self):
+        pass
+
+
 class AcademicDataset(DGLDataset):

     _prefix = 'https://s3.cn-north-1.amazonaws.com.cn/dgl-data/'
diff --git a/openhgnn/experiment.py b/openhgnn/experiment.py
index 60364ece..00643470 100644
--- a/openhgnn/experiment.py
+++ b/openhgnn/experiment.py
@@ -73,6 +73,7 @@ def __init__(self, model, dataset, task,
                  hpo_trials: int = 100,
                  output_dir: str = "./openhgnn/output",
                  conf_path: str = default_conf_path,
+                 use_database: bool = False,
                  **kwargs):
         self.config = Config(file_path=conf_path, model=model, dataset=dataset, task=task, gpu=gpu)
         self.config.model = model
@@ -80,6 +81,7 @@
         self.config.task = task
         self.config.gpu = gpu
         self.config.use_best_config = use_best_config
+        self.config.use_database = use_database
         # self.config.use_hpo = use_hpo
         self.config.load_from_pretrained = load_from_pretrained
         self.config.output_dir = os.path.join(output_dir, self.config.model_name)
diff --git a/openhgnn/tasks/node_classification.py b/openhgnn/tasks/node_classification.py
index 3565aae0..4352edb4 100644
--- a/openhgnn/tasks/node_classification.py
+++ b/openhgnn/tasks/node_classification.py
@@ -31,7 +31,8 @@ class NodeClassification(BaseTask):
     def __init__(self, args):
         super(NodeClassification, self).__init__()
         self.logger = args.logger
-        self.dataset = build_dataset(args.dataset, 'node_classification', logger=self.logger)
+        self.dataset = build_dataset(args.dataset, 'node_classification',
+                                     logger=self.logger, args=args)
         # self.evaluator = Evaluator()
         self.logger = args.logger
         if hasattr(args, 'validation'):
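The gdbi export in `download()` can also be exercised on its own. The sketch below mirrors the configuration from this patch; it assumes gdbi is installed and a Neo4j server already holds the `imdb4MAGNN` graph, and the URI and credentials are placeholders, not real defaults:

```python
# Standalone sketch of the gdbi export used in IMDB4MAGNN_Dataset.download().
from gdbi import NodeExportConfig, EdgeExportConfig, Neo4jInterface

node_export_config = [
    NodeExportConfig('A', ['attribute']),
    NodeExportConfig('M', ['attribute'], ['label']),
    NodeExportConfig('D', ['attribute']),
]
edge_export_config = [
    EdgeExportConfig('A_M', ('A', 'M')),
    EdgeExportConfig('M_A', ('M', 'A')),
    EdgeExportConfig('M_D', ('M', 'D')),
    EdgeExportConfig('D_M', ('D', 'M')),
]

graph_database = Neo4jInterface()
conn = graph_database.GraphDBConnection('neo4j://localhost:7687', 'neo4j', 'password')
graph = graph_database.get_graph(conn, 'imdb4MAGNN',
                                 node_export_config, edge_export_config)
# expected keys, per process() above: 'X_dict', 'Y_dict', 'edge_index_dict'
print(graph.keys())
```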
From b62dae734c7b26ff0129d53d0fe5fe869979542b Mon Sep 17 00:00:00 2001
From: Zhao Zihao <1007105680@qq.com>
Date: Mon, 5 Feb 2024 17:34:53 +0800
Subject: [PATCH 3/7] add gdbi install

---
 README.md | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/README.md b/README.md
index da59c176..79b90f85 100644
--- a/README.md
+++ b/README.md
@@ -169,6 +169,14 @@ cd OpenHGNN
 pip install .
 ```

+
+**5. Install gdbi:**
+
+```bash
+pip install git+https://github.com/xy-Ji/gdbi.git
+```
+
 #### Run an existing baseline model on an existing benchmark [dataset](./openhgnn/dataset/#Dataset)

 ```bash
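Once the install above succeeds, a quick import check confirms the package exposes the interfaces that PATCH 2/7 relies on; a sketch, with the class names taken from that patch rather than from separate gdbi documentation:

```python
# Sanity check after `pip install git+https://github.com/xy-Ji/gdbi.git`.
# These are exactly the names imported in IMDB4MAGNN_Dataset.download().
from gdbi import NodeExportConfig, EdgeExportConfig, Neo4jInterface, NebulaInterface

print(NodeExportConfig, EdgeExportConfig)
print(Neo4jInterface, NebulaInterface)
```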
From d08f851184995d1a297ff8a147819b73b11c48b3 Mon Sep 17 00:00:00 2001
From: Zhao Zihao <1007105680@qq.com>
Date: Mon, 5 Feb 2024 17:41:03 +0800
Subject: [PATCH 4/7] add database_use

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index 79b90f85..fc114987 100644
--- a/README.md
+++ b/README.md
@@ -177,6 +177,8 @@ pip install .
 pip install git+https://github.com/xy-Ji/gdbi.git
 ```

+- Users need to install third-party libraries such as neo4j and nebula, set graph_address, user_name and password in config.py so that the database can be accessed, and call gdbi.get_graph to obtain the corresponding graph dataset
+
 #### Run an existing baseline model on an existing benchmark [dataset](./openhgnn/dataset/#Dataset)

 ```bash
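Before editing config.py, it can be worth verifying that the database is reachable at all. A minimal sketch using the official neo4j Python driver (the package pinned later in PATCH 7/7); the URI and credentials are placeholders for your own deployment:

```python
# Optional connectivity check before running OpenHGNN with --use_database.
from neo4j import GraphDatabase

driver = GraphDatabase.driver('neo4j://localhost:7687',
                              auth=('neo4j', 'your-password'))  # placeholders
driver.verify_connectivity()  # raises an exception if the server is unreachable
print('Neo4j is reachable')
driver.close()
```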
From 6e165fa15c574fd782e7194ec6a4b8e22475b760 Mon Sep 17 00:00:00 2001
From: Zhao Zihao <1007105680@qq.com>
Date: Mon, 5 Feb 2024 17:44:30 +0800
Subject: [PATCH 5/7] add database

---
 README.md | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index fc114987..9890522e 100644
--- a/README.md
+++ b/README.md
@@ -177,16 +177,17 @@ pip install .
 pip install git+https://github.com/xy-Ji/gdbi.git
 ```

-- Users need to install third-party libraries such as neo4j and nebula, set graph_address, user_name and password in config.py so that the database can be accessed, and call gdbi.get_graph to obtain the corresponding graph dataset
+- Users need to install the third-party packages for neo4j, nebula, atlas and gstore, set graph_address, user_name and password in config.py so that the database can be accessed, and call gdbi.get_graph to obtain the corresponding graph dataset.
+

 #### Run an existing baseline model on an existing benchmark [dataset](./openhgnn/dataset/#Dataset)

 ```bash
-python main.py -m model_name -d dataset_name -t task_name -g 0 --use_best_config --load_from_pretrained
+python main.py -m model_name -d dataset_name -t task_name -g 0 --use_best_config --load_from_pretrained --use_database
 ```

 Usage: main.py [-h] [--model MODEL] [--task TASK] [--dataset DATASET]
-[--gpu GPU] [--use_best_config]
+[--gpu GPU] [--use_best_config] [--use_database]

 *Optional arguments*:
@@ -204,6 +205,8 @@

 ``--load_from_pretrained`` load the model from the default checkpoint.

+``--use_database`` load the dataset from the database
+
 Example:

 ```bash
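The `--use_database` flag documented here maps onto the `use_database` keyword that PATCH 2/7 added to `Experiment`, so the same run can be launched programmatically. A sketch mirroring the CLI example; note that this starts a real training run and needs a reachable database:

```python
# Programmatic equivalent of:
#   python main.py -m MAGNN -d imdb4MAGNN -t node_classification -g -1 \
#       --use_best_config --use_database
from openhgnn import Experiment

experiment = Experiment(model='MAGNN', dataset='imdb4MAGNN',
                        task='node_classification', gpu=-1,
                        use_best_config=True, use_database=True)
experiment.run()
```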
From 3668651ca6385c7f8adb6c77a212117a290f37fd Mon Sep 17 00:00:00 2001
From: Zhao Zihao <1007105680@qq.com>
Date: Mon, 5 Feb 2024 17:52:20 +0800
Subject: [PATCH 6/7] .

---
 README.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/README.md b/README.md
index 9890522e..21dc8964 100644
--- a/README.md
+++ b/README.md
@@ -178,6 +178,11 @@ pip install git+https://github.com/xy-Ji/gdbi.git
 ```

 - Users need to install the third-party packages for neo4j, nebula, atlas and gstore, set graph_address, user_name and password in config.py so that the database can be accessed, and call gdbi.get_graph to obtain the corresponding graph dataset.
+- Example:
+
+```bash
+python main.py -m MAGNN -d imdb4MAGNN -t node_classification -g 0 --use_best_config --use_database
+```

 #### Run an existing baseline model on an existing benchmark [dataset](./openhgnn/dataset/#Dataset)
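PATCH 7/7 below documents the CSV import as a raw Cypher statement. For completeness, the same statement can be issued from Python with the pinned neo4j driver; in this sketch the file name, the `imdb4MAGNN_M` label and the `ID` property are illustrative placeholders following the patch's `graphname_labelname` convention:

```python
# Issue the LOAD CSV import from PATCH 7/7 through the neo4j Python driver.
# data.csv must already be in the Neo4j server's import directory.
from neo4j import GraphDatabase

query = (
    'LOAD CSV WITH HEADERS FROM "file:///data.csv" AS row '
    'CREATE (:imdb4MAGNN_M {ID: row.ID})'
)

driver = GraphDatabase.driver('neo4j://localhost:7687',
                              auth=('neo4j', 'your-password'))  # placeholders
with driver.session() as session:
    session.run(query)
driver.close()
```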
From b384af7e48e9f03588e3612866a7e92f0c9049f2 Mon Sep 17 00:00:00 2001
From: Zhao Zihao <1007105680@qq.com>
Date: Tue, 6 Feb 2024 13:49:34 +0800
Subject: [PATCH 7/7] add readme about database

---
 README.md    | 37 +++++++++++++++++++++++++++++------
 README_EN.md | 47 ++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 77 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 21dc8964..404011de 100644
--- a/README.md
+++ b/README.md
@@ -171,24 +171,24 @@ pip install .
 ```

-**5. Install gdbi:**
+**5. Install gdbi (optional):**

+- Install gdbi
 ```bash
 pip install git+https://github.com/xy-Ji/gdbi.git
 ```

-- Users need to install the third-party packages for neo4j, nebula, atlas and gstore, set graph_address, user_name and password in config.py so that the database can be accessed, and call gdbi.get_graph to obtain the corresponding graph dataset.
-- Example:
-
+- Install the graph database drivers
 ```bash
-python main.py -m MAGNN -d imdb4MAGNN -t node_classification -g 0 --use_best_config --use_database
+pip install neo4j==5.16.0
+pip install nebula3-python==3.4.0
 ```

 #### Run an existing baseline model on an existing benchmark [dataset](./openhgnn/dataset/#Dataset)

 ```bash
-python main.py -m model_name -d dataset_name -t task_name -g 0 --use_best_config --load_from_pretrained --use_database
+python main.py -m model_name -d dataset_name -t task_name -g 0 --use_best_config --load_from_pretrained
 ```

 Usage: main.py [-h] [--model MODEL] [--task TASK] [--dataset DATASET]
 [--gpu GPU] [--use_best_config] [--use_database]
@@ -222,6 +222,7 @@ python main.py -m GTN -d imdb4GTN -t node_classification -g 0 --use_best_config

 Please refer to the [documentation](https://openhgnn.readthedocs.io/en/latest/index.html) for more basic and advanced usage.

+
 #### Visualize training results with TensorBoard
 ```bash
 tensorboard --logdir=./openhgnn/output/{model_name}/
@@ -230,8 +231,32 @@
 ```bash
 tensorboard --logdir=./openhgnn/output/RGCN/
 ```
+
 **Note**: You need to run the model you want to visualize once before the command above can show its results.

+#### Use gdbi to access standard graph data in a database
+Take the neo4j database and the imdb dataset as an example:
+- Construct the CSV files of the graph dataset (node-level: A.csv, edge-level: A_P.csv)
+- Import the CSV files into the graph database
+```bash
+LOAD CSV WITH HEADERS FROM "file:///data.csv" AS row
+CREATE (:graphname_labelname {ID: row.ID, ... });
+```
+- Add the user information needed to access the graph database in config.py
+```python
+self.graph_address = [graph_address]
+self.user_name = [user_name]
+self.password = [password]
+```
+
+- Example:
+
+```bash
+python main.py -m MAGNN -d imdb4MAGNN -t node_classification -g 0 --use_best_config --use_database
+```
+
+
 ## [Models](./openhgnn/models/#Model)

 ### Supported models for specific tasks
diff --git a/README_EN.md b/README_EN.md
index 728053a2..c23bc783 100644
--- a/README_EN.md
+++ b/README_EN.md
@@ -171,6 +171,24 @@ cd OpenHGNN
 pip install .
 ```

+
+**5. Install gdbi (optional):**
+
+- Install gdbi from git
+```bash
+pip install git+https://github.com/xy-Ji/gdbi.git
+```
+
+- Install the graph database drivers from PyPI
+```bash
+pip install neo4j==5.16.0
+pip install nebula3-python==3.4.0
+```
+
 #### Running an existing baseline model on an existing benchmark [dataset](../openhgnn/dataset/#Dataset)

 ```bash
@@ -178,7 +196,7 @@ python main.py -m model_name -d dataset_name -t task_name -g 0 --use_best_config
 ```

 usage: main.py [-h] [--model MODEL] [--task TASK] [--dataset DATASET]
-[--gpu GPU] [--use_best_config]
+[--gpu GPU] [--use_best_config] [--use_database]

 *optional arguments*:
@@ -198,6 +216,8 @@ will override the parameter in config.ini.

 ``--load_from_pretrained`` will load the model from a default checkpoint.

+``--use_database`` will load the dataset from a graph database.
+
 e.g.:

 ```bash
@@ -218,6 +238,31 @@ tensorboard --logdir=./openhgnn/output/RGCN/
 ```
 **Note**: To visualize results, you need to train the model first.

+
+#### Use gdbi to get a graph dataset
+Take neo4j and the imdb dataset as an example:
+- Construct the CSV files for the dataset (node-level: A.csv, edge-level: A_P.csv)
+- Import the CSV files into the database
+```bash
+LOAD CSV WITH HEADERS FROM "file:///data.csv" AS row
+CREATE (:graphname_labelname {ID: row.ID, ... });
+```
+- Add the user information needed to access the database in config.py
+```python
+self.graph_address = [graph_address]
+self.user_name = [user_name]
+self.password = [password]
+```
+
+- e.g.:
+
+```bash
+python main.py -m MAGNN -d imdb4MAGNN -t node_classification -g 0 --use_best_config --use_database
+```
+
 ## [Models](../openhgnn/models/#Model)

 ### Supported Models with specific task