diff --git a/openhgnn/config.ini b/openhgnn/config.ini index bf0a9469..ced176fd 100644 --- a/openhgnn/config.ini +++ b/openhgnn/config.ini @@ -1,3 +1,96 @@ +[DisenKGAT] +# str +name = Disen_Model +# data = DisenKGAT_WN18RR +# model = DisenKGAT +score_func = interacte +opn = cross +# gpu = 2 +logdir = ./log/ +config = ./config/ +strategy = one_to_n +form = plain +mi_method = club_b +att_mode = dot_weight +score_method = dot_rel +score_order = after +gamma_method = norm + + +# int +k_w = 10 +batch = 2048 +test_batch = 2048 +epoch = 1500 +num_workers = 10 +seed = 41504 +init_dim = 100 +gcn_dim = 200 +embed_dim = 200 +gcn_layer = 1 +k_h = 20 +num_filt = 200 +ker_sz = 7 +num_bases = -1 +neg_num = 1000 +ik_w = 10 +ik_h = 20 +inum_filt = 200 +iker_sz = 9 +iperm = 1 +head_num = 1 +num_factors = 3 +early_stop = 200 +mi_epoch = 1 + +# float +feat_drop = 0.3 +hid_drop2 = 0.3 +hid_drop = 0.3 +gcn_drop = 0.4 +gamma = 9.0 +l2 = 0.0 +lr = 0.001 +lbl_smooth = 0.1 +iinp_drop = 0.3 +ifeat_drop = 0.4 +ihid_drop = 0.3 +alpha = 1e-1 +max_gamma = 5.0 +init_gamma = 9.0 + +# boolean +restore = False +bias = False +no_act = False +mi_train = True +no_enc = False +mi_drop = True +fix_gamma = False + + + + + +[NBF] +input_dim = 32 +hidden_dims = [32, 32, 32, 32, 32, 32] +message_func = distmult +aggregate_func = pna +short_cut = True +layer_norm = True +dependent = False +num_negative = 32 +strict_negative = True +adversarial_temperature = 1 +metric = ['mr', 'mrr', 'hits@1', 'hits@3', 'hits@10', 'hits@10_50'] +lr = 0.005 +gpus = [0] +batch_size = 64 +num_epoch = 20 +log_interval = 100 + + [General] learning_rate = 0.01 weight_decay = 0.0001 @@ -887,3 +980,254 @@ embedding_size = 64 num_layers = 3 test_u_batch_size = 100 topks = 20 + +[Grail] +num_epochs: 100 +eval_every: 3 +eval_every_iter: 455 +save_every: 10 +early_stop: 100 +optimizer: Adam +lr: 0.01 +clip:1000 +l2: 5e-4 +margin: 10 +max_links:1000000 +hop: 3 +max_nodes_per_hop: 0 +use_kge_embeddings: False +kge_model: TransE +model_type: dgl +constrained_neg_prob: 0.0 +batch_size: 16 +num_neg_samples_per_link: 1 +num_workers: 8 +add_traspose_rels: False +enclosing_sub_graph: True +rel_emb_dim: 32 +attn_rel_emb_dim: 32 +emb_dim: 32 +num_gcn_layers: 3 +num_bases: 4 +dropout: 0 +edge_dropout: 0.5 +gnn_agg_type: sum +add_ht_emb: True +has_attn: True +mode: sample + +[ComPILE] +num_epochs: 100 +eval_every: 3 +eval_every_iter: 455 +save_every: 10 +early_stop: 100 +optimizer: Adam +lr: 0.01 +clip:1000 +l2: 5e-4 +margin: 10 +max_links:1000000 +hop: 3 +max_nodes_per_hop: 0 +use_kge_embeddings: False +kge_model: TransE +model_type: dgl +constrained_neg_prob: 0.0 +batch_size: 16 +num_neg_samples_per_link: 1 +num_workers: 8 +add_traspose_rels: False +enclosing_sub_graph: True +rel_emb_dim: 32 +attn_rel_emb_dim: 32 +emb_dim: 32 +num_gcn_layers: 3 +num_bases: 4 +dropout: 0 +edge_dropout: 0.5 +gnn_agg_type: sum +add_ht_emb: True +has_attn: True +mode: sample + + +[AdapropT] +data_path = data/family/ +layers=8 +sampling=incremental +act=relu +weight=None +tau=1.0 +train=True +remove_1hop_edges=False +scheduler=exp +fact_ratio=0.9 +epoch=300 +eval_interval=1 +topk = 100 +lr = 0.0036 +decay_rate = 0.999 +lamb = 0.000017 +hidden_dim = 48 +attn_dim = 5 +dropout = 0.29 +n_edge_topk = -1 +n_layer = 8 +n_batch = 20 +n_node_topk = 800 +seed = 1234 +n_tbatch=20 +eval=False +[AdapropI] +data_path=./data/fb237_v1 +seed=1234 + +[LTE] +model_name_GCN=LTE_Transe +model_name=LTE +name=lte +data=FB15k-237 +score_func=transe +opn=mult +hid_drop=0.2 +gpu=0 +x_ops=p +n_layer=0 +init_dim=200 
+batch_size=64 +epoch=300 +l2=0.0 +lr=0.001 +lbl_smooth=0.1 +num_workers=8 +seed=12345 +restore=False +bias=False +num_bases=-1 +gcn_dim=200 +gcn_drop=0.1 +conve_hid_drop=0.3 +feat_drop=0.2 +input_drop=0.2 +k_w=20 +k_h=10 +num_filt=200 +ker_sz=7 +gamma=9.0 +rat=False +wni=False +wsi=False +ss=False +nobn=False +noltr=False +encoder=compgcn +max_epochs=500 + +[SACN] +seed=12345 +init_emb_size=200 +gc1_emb_size=150 +embedding_dim=200 +input_dropout=0 +dropout_rate=0.2 +channels=200 +kernel_size=5 +gpu=5 +lr=0.002 +n_epochs=300 +num_workers=2 +eval_every=1 +dataset_data=FB15k-237 +batch_size=64 +patience=100 +decoder=transe +gamma=9.0 +name=repro +n_layer=1 +rat=False +wsi=False +wni=False +ss=-1 +final_act=True +final_bn=False +final_drop=False + +[ExpressGNN] +seed=10 +embedding_size = 128 +gcn_free_size = 127 +slice_dim = 16 +no_train = 0 +filtered = filtered +hidden_dim = 64 +num_epochs = 100 +batchsize = 16 +trans = 0 +num_hops = 2 +num_mlp_layers = 2 +num_batches = 100 +learning_rate = 0.0005 +lr_decay_factor = 0.5 +lr_decay_patience = 100 +lr_decay_min = 0.00001 +patience = 20 +l2_coef = 0.0 +observed_prob = 0.9 +entropy_temp = 1 +no_entropy = 0 +rule_weights_learning = 1 +learning_rate_rule_weights = 0.001 +epoch_mode = 0 +shuffle_sampling = 1 +load_method = 1 +load_s = 1 +use_gcn = 1 +filter_latent = 0 +closed_world = 0 + +[Ingram] +margin = 2 +lr = 5e-4 +nle = 2 +nlr = 2 +d_e = 32 +d_r = 32 +hdr_e = 8 +hdr_r = 4 +num_bin = 10 +num_epoch = 10000 +validation_epoch = 200 +num_head = 8 +num_neg = 10 + +[RedGNN] +seed = 0 +patience = 3 +batch_size = 100 +hidden_dim = 64 +optimizer = Adam +lr = 0.005 +weight_decay = 0.0002 +max_epoch = 50 +decay_rate = 0.991 +attn_dim = 5 +dropout = 0.21 +act = idd +n_layer = 5 + +[RedGNNT] +seed = 0 +patience = 3 +batch_size = 20 +hidden_dim = 48 +optimizer = Adam +lr = 0.0036 +weight_decay = 0.000017 +max_epoch = 50 +decay_rate = 0.999 +attn_dim = 5 +dropout = 0.21 +act = relu +n_layer = 3 +n_tbatch = 50 diff --git a/openhgnn/config.py b/openhgnn/config.py index 23e808bb..eb56d319 100644 --- a/openhgnn/config.py +++ b/openhgnn/config.py @@ -9,7 +9,7 @@ class Config(object): def __init__(self, file_path, model, dataset, task, gpu): - conf = configparser.ConfigParser() + conf = configparser.ConfigParser( ) try: conf.read(file_path) except: @@ -40,6 +40,92 @@ def __init__(self, file_path, model, dataset, task, gpu): self.seed = conf.getint("General", "seed") self.patience = conf.getint("General", "patience") self.mini_batch_flag = conf.getboolean("General", "mini_batch_flag") + + elif self.model_name == "DisenKGAT": + + self.name = conf.get("DisenKGAT","name") + self.score_func = conf.get("DisenKGAT","score_func") + self.opn = conf.get("DisenKGAT","opn") + self.logdir = conf.get("DisenKGAT","logdir") + self.config = conf.get("DisenKGAT","config") + self.strategy = conf.get("DisenKGAT","strategy") + self.form = conf.get("DisenKGAT","form") + self.mi_method = conf.get("DisenKGAT","mi_method") + self.att_mode = conf.get("DisenKGAT","att_mode") + self.score_method = conf.get("DisenKGAT","score_method") + self.score_order = conf.get("DisenKGAT","score_order") + self.gamma_method = conf.get("DisenKGAT","gamma_method") + + self.k_w= conf.getint("DisenKGAT", "k_w") + self.batch = conf.getint("DisenKGAT", "batch") + self.test_batch = conf.getint("DisenKGAT", "test_batch") + self.epoch = conf.getint("DisenKGAT", "epoch") + self.num_workers = conf.getint("DisenKGAT", "num_workers") + self.seed = conf.getint("DisenKGAT", "seed") + self.init_dim = 
conf.getint("DisenKGAT", "init_dim") + self.gcn_dim = conf.getint("DisenKGAT", "gcn_dim") + self.embed_dim = conf.getint("DisenKGAT", "embed_dim") + self.gcn_layer = conf.getint("DisenKGAT", "gcn_layer") + self.k_h = conf.getint("DisenKGAT", "k_h") + self.num_filt = conf.getint("DisenKGAT", "num_filt") + self.ker_sz = conf.getint("DisenKGAT", "ker_sz") + self.num_bases = conf.getint("DisenKGAT", "num_bases") + self.neg_num = conf.getint("DisenKGAT", "neg_num") + self.ik_w = conf.getint("DisenKGAT", "ik_w") + self.ik_h = conf.getint("DisenKGAT", "ik_h") + self.inum_filt = conf.getint("DisenKGAT", "inum_filt") + self.iker_sz = conf.getint("DisenKGAT", "iker_sz") + self.iperm = conf.getint("DisenKGAT", "iperm") + self.head_num = conf.getint("DisenKGAT", "head_num") + self.num_factors = conf.getint("DisenKGAT", "num_factors") + self.early_stop = conf.getint("DisenKGAT", "early_stop") + self.mi_epoch = conf.getint("DisenKGAT", "mi_epoch") + self.feat_drop = conf.getfloat("DisenKGAT", "feat_drop") + self.hid_drop2 = conf.getfloat("DisenKGAT", "hid_drop2") + self.hid_drop = conf.getfloat("DisenKGAT", "hid_drop") + self.gcn_drop = conf.getfloat("DisenKGAT", "gcn_drop") + self.gamma = conf.getfloat("DisenKGAT", "gamma") + self.l2 = conf.getfloat("DisenKGAT", "l2") + self.lr = conf.getfloat("DisenKGAT", "lr") + self.lbl_smooth = conf.getfloat("DisenKGAT", "lbl_smooth") + self.iinp_drop = conf.getfloat("DisenKGAT", "iinp_drop") + self.ifeat_drop = conf.getfloat("DisenKGAT", "ifeat_drop") + self.ihid_drop = conf.getfloat("DisenKGAT", "ihid_drop") + self.alpha = conf.getfloat("DisenKGAT", "alpha") + self.max_gamma = conf.getfloat("DisenKGAT", "max_gamma") + self.init_gamma = conf.getfloat("DisenKGAT", "init_gamma") + self.restore = conf.getboolean("DisenKGAT", "restore") + self.bias = conf.getboolean("DisenKGAT", "bias") + self.no_act = conf.getboolean("DisenKGAT", "no_act") + self.mi_train = conf.getboolean("DisenKGAT", "mi_train") + self.no_enc = conf.getboolean("DisenKGAT", "no_enc") + self.mi_drop = conf.getboolean("DisenKGAT", "mi_drop") + self.fix_gamma = conf.getboolean("DisenKGAT", "fix_gamma") + + + elif self.model_name == "NBF" and self.dataset_name =="NBF_WN18RR": + + self.input_dim = conf.getint("NBF", "input_dim") + self.hidden_dims = [32, 32, 32, 32, 32, 32] + self.message_func = conf.get("NBF", "message_func") + self.aggregate_func = conf.get("NBF", "aggregate_func") + self.short_cut = conf.getboolean("NBF","short_cut") + self.layer_norm = conf.getboolean("NBF","layer_norm") + self.dependent = conf.getboolean("NBF","dependent") + + self.num_negative = conf.getint("NBF","num_negative") + self.strict_negative = conf.getboolean("NBF","strict_negative") + self.adversarial_temperature = conf.getint("NBF", "adversarial_temperature") + self.metric = ['mr', 'mrr', 'hits@1', 'hits@3', 'hits@10', 'hits@10_50'] + + self.lr = conf.getfloat("NBF","lr") + self.gpus = [0] + self.batch_size = conf.getint("NBF","batch_size") + self.num_epoch = conf.getint("NBF","num_epoch") + self.log_interval = conf.getint("NBF","log_interval") + +############################################################################################################### + elif self.model_name == "NSHE": self.dim_size = {} self.dim_size['emd'] = conf.getint("NSHE", "emd_dim") @@ -926,12 +1012,268 @@ def __init__(self, file_path, model, dataset, task, gpu): self.topks = conf.getint("lightGCN", "topks") # self.alpha = conf.getfloat("lightGCN", "alpha") + elif self.model_name == "AdapropT": + self.lr = conf.getfloat("AdapropT", "lr") + 
self.decay_rate = conf.getfloat("AdapropT", "decay_rate") + self.lamb = conf.getfloat("AdapropT", "lamb") + self.hidden_dim = conf.getint("AdapropT", "hidden_dim") + self.attn_dim = conf.getint("AdapropT", "attn_dim") + self.dropout = conf.getfloat("AdapropT", "dropout") + self.n_edge_topk = conf.getint("AdapropT", "n_edge_topk") + self.n_layer = conf.getint("AdapropT", "n_layer") + self.n_batch = conf.getint("AdapropT", "n_batch") + self.n_node_topk = conf.getint("AdapropT", "n_node_topk") + self.seed = conf.getint("AdapropT", "seed") + self.topk = conf.getint("AdapropT", "topk") + self.data_path = conf.get("AdapropT", "data_path") + self.layers = conf.getint("AdapropT", "layers") + self.sampling = conf.get("AdapropT", "sampling") + self.train = conf.getboolean("AdapropT", "train") + self.scheduler = conf.get("AdapropT", "scheduler") + self.fact_ratio = conf.getfloat("AdapropT", "fact_ratio") + self.epoch = conf.getint("AdapropT", "epoch") + self.eval_interval = conf.getint("AdapropT", "eval_interval") + self.remove_1hop_edges = conf.getboolean("AdapropT", "remove_1hop_edges") + self.act = conf.get("AdapropT", "act") + self.tau = conf.getfloat("AdapropT", "tau") + self.weight = conf.get("AdapropT", "weight") + self.n_tbatch = conf.getint("AdapropT", "n_tbatch") + self.eval = conf.getboolean("AdapropT", 'eval') + + elif self.model_name == "AdapropI": + self.data_path = conf.get("AdapropI", "data_path") + self.seed = conf.getint("AdapropI", "seed") + + elif self.model_name == 'LTE': + self.model_name_GCN=conf.get("LTE", "model_name_GCN") + self.name=conf.get("LTE", "name") + self.data=conf.get("LTE", "data") + self.score_func=conf.get("LTE", "score_func") + self.opn=conf.get("LTE", "opn") + self.hid_drop=conf.getfloat("LTE", "hid_drop") + self.gpu=conf.getint("LTE", "gpu") + self.x_ops=conf.get("LTE", "x_ops") + self.n_layer=conf.getint("LTE", "n_layer") + self.init_dim=conf.getint("LTE", "init_dim") + self.batch_size=conf.getint("LTE", "batch_size") + self.epoch=conf.getint("LTE", "epoch") + self.l2=conf.getfloat("LTE", "l2") + self.lr=conf.getfloat("LTE", "lr") + self.lbl_smooth=conf.getfloat("LTE", "lbl_smooth") + self.num_workers=conf.getint("LTE", "num_workers") + self.seed=conf.getint("LTE", "seed") + self.restore=conf.getboolean("LTE", "restore") + self.bias=conf.getboolean("LTE", "bias") + self.num_bases=conf.getint("LTE", "num_bases") + self.gcn_dim=conf.getint("LTE", "gcn_dim") + self.gcn_drop=conf.getfloat("LTE", "gcn_drop") + self.conve_hid_drop=conf.getfloat("LTE", "conve_hid_drop") + self.feat_drop=conf.getfloat("LTE", "feat_drop") + self.input_drop=conf.getfloat("LTE", "input_drop") + self.k_w=conf.getint("LTE", "k_w") + self.k_h=conf.getint("LTE", "k_h") + self.num_filt=conf.getint("LTE", "num_filt") + self.ker_sz=conf.getint("LTE", "ker_sz") + self.gamma=conf.getfloat("LTE", "gamma") + self.rat=conf.getboolean("LTE", "rat") + self.wni=conf.getboolean("LTE", "wni") + self.wsi=conf.getboolean("LTE", "wsi") + self.ss=conf.getboolean("LTE", "ss") + self.nobn=conf.getboolean("LTE", "nobn") + self.noltr=conf.getboolean("LTE", "noltr") + self.encoder=conf.get("LTE", "encoder") + self.max_epochs=conf.getint("LTE", "max_epochs") + + elif self.model_name == 'SACN': + self.seed=conf.getint("SACN","seed") + self.init_emb_size=conf.getint("SACN","init_emb_size") + self.gc1_emb_size=conf.getint("SACN","gc1_emb_size") + self.embedding_dim=conf.getint("SACN","embedding_dim") + self.input_dropout=conf.getint("SACN","input_dropout") + self.dropout_rate=conf.getfloat("SACN","dropout_rate") + 
self.channels=conf.getint("SACN","channels") + self.kernel_size=conf.getint("SACN","kernel_size") + self.gpu=conf.getint("SACN","gpu") + self.lr=conf.getfloat("SACN","lr") + self.n_epochs=conf.getint("SACN","n_epochs") + self.num_workers=conf.getint("SACN","num_workers") + self.eval_every=conf.getint("SACN","eval_every") + self.dataset_data=conf.get("SACN","dataset_data") + self.batch_size=conf.getint("SACN","batch_size") + self.patience=conf.getint("SACN","patience") + self.decoder=conf.get("SACN","decoder") + self.gamma=conf.getfloat("SACN","gamma") + self.name=conf.get("SACN","name") + self.n_layer=conf.getint("SACN","n_layer") + self.rat=conf.getboolean("SACN","rat") + self.wsi=conf.getboolean("SACN","wsi") + self.wni=conf.getboolean("SACN","wni") + self.ss=conf.getint("SACN","ss") + self.final_act=conf.getboolean("SACN","final_act") + self.final_bn=conf.getboolean("SACN","final_bn") + self.final_drop=conf.getboolean("SACN","final_drop") + + elif self.model_name == 'Ingram': + self.margin = conf.getint("Ingram", "margin") + self.lr = conf.getfloat("Ingram", "lr") + self.nle = conf.getint("Ingram", "nle") + self.nlr = conf.getint("Ingram", "nlr") + self.d_e = conf.getint("Ingram", "d_e") + self.d_r = conf.getint("Ingram", "d_r") + self.hdr_e = conf.getint("Ingram", "hdr_e") + self.hdr_r = conf.getint("Ingram", "hdr_r") + self.num_bin = conf.getint("Ingram", "num_bin") + self.num_epoch = conf.getint("Ingram", "num_epoch") + self.validation_epoch = conf.getint("Ingram", "validation_epoch") + self.num_head = conf.getint("Ingram", "num_head") + self.num_neg = conf.getint("Ingram", "num_neg") + elif self.model_name == 'RedGNN': + self.seed = conf.getint("RedGNN", "seed") + self.patience = conf.getint("RedGNN", "patience") + self.batch_size = conf.getint("RedGNN", "batch_size") + self.optimizer = conf.get("RedGNN", "optimizer") + self.lr = conf.getfloat("RedGNN", "lr") + self.weight_decay = conf.getfloat("RedGNN", "weight_decay") + self.max_epoch = conf.getint("RedGNN", "max_epoch") + self.decay_rate = conf.getfloat("RedGNN", "decay_rate") + self.hidden_dim = conf.getint("RedGNN", "hidden_dim") + self.attn_dim = conf.getint("RedGNN", "attn_dim") + self.dropout = conf.getfloat("RedGNN", "dropout") + self.act = conf.get("RedGNN", "act") + self.n_layer = conf.getint("RedGNN", "n_layer") + + elif self.model_name == 'RedGNNT': + self.seed = conf.getint("RedGNNT", "seed") + self.patience = conf.getint("RedGNNT", "patience") + self.batch_size = conf.getint("RedGNNT", "batch_size") + self.n_tbatch = conf.getint("RedGNNT", "n_tbatch") + self.optimizer = conf.get("RedGNNT", "optimizer") + self.lr = conf.getfloat("RedGNNT", "lr") + self.weight_decay = conf.getfloat("RedGNNT", "weight_decay") + self.max_epoch = conf.getint("RedGNNT", "max_epoch") + self.decay_rate = conf.getfloat("RedGNNT", "decay_rate") + self.hidden_dim = conf.getint("RedGNNT", "hidden_dim") + self.attn_dim = conf.getint("RedGNNT", "attn_dim") + self.dropout = conf.getfloat("RedGNNT", "dropout") + self.act = conf.get("RedGNNT", "act") + self.n_layer = conf.getint("RedGNNT", "n_layer") + + elif self.model_name == 'ExpressGNN': + self.embedding_size = conf.getint('ExpressGNN', 'embedding_size') + self.gcn_free_size = conf.getint("ExpressGNN", "gcn_free_size") + self.filtered = conf.get("ExpressGNN", "filtered") + self.hidden_dim = conf.getint("ExpressGNN", "hidden_dim") + self.rule_weights_learning = conf.getint("ExpressGNN", "rule_weights_learning") + self.load_method = conf.getint("ExpressGNN", "load_method") + self.num_epochs = 
conf.getint("ExpressGNN", "num_epochs") + + self.slice_dim = conf.getint("ExpressGNN", "slice_dim") + self.no_train = conf.getint("ExpressGNN", "no_train") + self.hidden_dim = conf.getint("ExpressGNN", "hidden_dim") + self.num_epochs = conf.getint("ExpressGNN", "num_epochs") + self.batchsize = conf.getint("ExpressGNN", "batchsize") + self.trans = conf.getint("ExpressGNN", "trans") + self.num_hops = conf.getint("ExpressGNN", "num_hops") + self.num_mlp_layers = conf.getint("ExpressGNN", "num_mlp_layers") + self.num_epochs = conf.getint("ExpressGNN", "num_epochs") + + self.num_batches = conf.getint("ExpressGNN", "num_batches") + self.learning_rate = conf.getfloat("ExpressGNN", "learning_rate") + self.lr_decay_factor = conf.getfloat("ExpressGNN", "lr_decay_factor") + self.lr_decay_patience = conf.getint("ExpressGNN", "lr_decay_patience") + self.lr_decay_min = conf.getfloat("ExpressGNN", "lr_decay_min") + self.patience = conf.getint("ExpressGNN", "patience") + self.l2_coef = conf.getfloat("ExpressGNN", "l2_coef") + self.observed_prob = conf.getfloat("ExpressGNN", "observed_prob") + self.entropy_temp = conf.getint("ExpressGNN", "entropy_temp") + self.no_entropy = conf.getint("ExpressGNN", "no_entropy") + self.learning_rate_rule_weights = conf.getfloat("ExpressGNN", "learning_rate_rule_weights") + self.epoch_mode = conf.getint("ExpressGNN", "epoch_mode") + self.shuffle_sampling = conf.getint("ExpressGNN", "shuffle_sampling") + + self.load_method = conf.getint("ExpressGNN", "load_method") + self.load_s = conf.getint("ExpressGNN", "load_s") + self.use_gcn = conf.getint("ExpressGNN", "use_gcn") + self.filter_latent = conf.getint("ExpressGNN", "filter_latent") + self.closed_world = conf.getint("ExpressGNN", "closed_world") + self.seed = conf.getint("ExpressGNN", "seed") + + elif self.model_name =='Grail': + self.num_epochs = conf.getint("Grail", "num_epochs") + self.eval_every = conf.getint("Grail","eval_every") + self.eval_every_iter = conf.getint("Grail","eval_every_iter") + self.save_every = conf.getint("Grail","save_every") + self.early_stop = conf.getint("Grail","early_stop") + self.optimizer = conf.get("Grail","optimizer") + self.lr = conf.getfloat("Grail","lr") + self.clip = conf.getint("Grail","clip") + self.l2 = conf.getfloat("Grail","l2") + self.margin = conf.getint("Grail","margin") + self.max_links = conf.getint("Grail","max_links") + self.hop = conf.getint("Grail","hop") + self.max_nodes_per_hop= conf.getint("Grail","max_nodes_per_hop") + self.use_kge_embeddings = conf.getboolean("Grail","use_kge_embeddings") + self.kge_model = conf.get("Grail","kge_model") + self.model_type =conf.get("Grail","model_type") + self.constrained_neg_prob= conf.getfloat("Grail","constrained_neg_prob") + self.batch_size = conf.getint("Grail","batch_size") + self.num_neg_samples_per_link = conf.getint("Grail","num_neg_samples_per_link") + self.num_workers = conf.getint("Grail","num_workers") + self.add_traspose_rels = conf.getboolean("Grail","add_traspose_rels") + self.enclosing_sub_graph = conf.getboolean("Grail","enclosing_sub_graph") + self.rel_emb_dim = conf.getint("Grail","rel_emb_dim") + self.attn_rel_emb_dim = conf.getint("Grail","attn_rel_emb_dim") + self.emb_dim = conf.getint("Grail","emb_dim") + self.num_gcn_layers = conf.getint("Grail","num_gcn_layers") + self.num_bases = conf.getint("Grail","num_bases") + self.dropout = conf.getfloat("Grail","dropout") + self.edge_dropout = conf.getfloat("Grail", "edge_dropout") + self.gnn_agg_type = conf.get("Grail","gnn_agg_type") + self.add_ht_emb = 
conf.getboolean("Grail","add_ht_emb") + self.has_attn = conf.getboolean("Grail", "has_attn") + self.mode = conf.get("Grail","mode") + + elif self.model_name =='ComPILE': + self.num_epochs = conf.getint("ComPILE", "num_epochs") + self.eval_every = conf.getint("ComPILE","eval_every") + self.eval_every_iter = conf.getint("ComPILE","eval_every_iter") + self.save_every = conf.getint("ComPILE","save_every") + self.early_stop = conf.getint("ComPILE","early_stop") + self.optimizer = conf.get("ComPILE","optimizer") + self.lr = conf.getfloat("ComPILE","lr") + self.clip = conf.getint("ComPILE","clip") + self.l2 = conf.getfloat("ComPILE","l2") + self.margin = conf.getint("ComPILE","margin") + self.max_links = conf.getint("ComPILE","max_links") + self.hop = conf.getint("ComPILE","hop") + self.max_nodes_per_hop= conf.getint("ComPILE","max_nodes_per_hop") + self.use_kge_embeddings = conf.getboolean("ComPILE","use_kge_embeddings") + self.kge_model = conf.get("ComPILE","kge_model") + self.model_type =conf.get("ComPILE","model_type") + self.constrained_neg_prob= conf.getfloat("ComPILE","constrained_neg_prob") + self.batch_size = conf.getint("ComPILE","batch_size") + self.num_neg_samples_per_link = conf.getint("Grail","num_neg_samples_per_link") + self.num_workers = conf.getint("ComPILE","num_workers") + self.add_traspose_rels = conf.getboolean("ComPILE","add_traspose_rels") + self.enclosing_sub_graph = conf.getboolean("ComPILE","enclosing_sub_graph") + self.rel_emb_dim = conf.getint("ComPILE","rel_emb_dim") + self.attn_rel_emb_dim = conf.getint("ComPILE","attn_rel_emb_dim") + self.emb_dim = conf.getint("ComPILE","emb_dim") + self.num_gcn_layers = conf.getint("ComPILE","num_gcn_layers") + self.num_bases = conf.getint("ComPILE","num_bases") + self.dropout = conf.getfloat("ComPILE","dropout") + self.edge_dropout = conf.getfloat("ComPILE", "edge_dropout") + self.gnn_agg_type = conf.get("ComPILE","gnn_agg_type") + self.add_ht_emb = conf.getboolean("ComPILE","add_ht_emb") + self.has_attn = conf.getboolean("ComPILE", "has_attn") + self.mode = conf.get("ComPILE","mode") + if hasattr(self, 'device'): self.device = th.device(self.device) elif gpu == -1: self.device = th.device('cpu') elif gpu >= 0: - if not th.cuda.is_available(): + if not th.cuda.is_available( ): self.device = th.device('cpu') warnings.warn("cuda is unavailable, the program will use cpu instead. 
please set 'gpu' to -1.") else: diff --git a/openhgnn/dataset/AdapropI_dataset.py b/openhgnn/dataset/AdapropI_dataset.py new file mode 100644 index 00000000..91b6a23e --- /dev/null +++ b/openhgnn/dataset/AdapropI_dataset.py @@ -0,0 +1,245 @@ +import os +import torch +import numpy as np +from scipy.sparse import csr_matrix +from collections import defaultdict +import requests +import zipfile +import io + +class AdapropIDataLoader: + def __init__(self, args): + self.args = args + self.dir = './data' + name1=self.args.dataset_name + name2=name1+'_ind' + path_ckp1 = os.path.join(self.dir, name1) + path_ckp2 = os.path.join(self.dir, name2) + self.dir = os.path.join(self.dir, name1) + task_dir=self.dir + print(path_ckp1) + folder = os.path.exists(path_ckp1) + if not folder: # 判断是否存在文件夹如果不存在则创建为文件夹 + os.makedirs(path_ckp1) # makedirs 创建文件时如果路径不存在会创建这个路径 + # 下载数据 + if name1=='fb237_v1': + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/fb237_v1.zip" + elif name1=='fb237_v2': + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/fb237_v2.zip" + elif name1=='fb237_v3': + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/fb237_v3.zip" + elif name1=='fb237_v4': + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/fb237_v4.zip" + response = requests.get(url) + with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: + myzip.extractall(path_ckp1) + print("--- download data ---") + + else: + print("--- There is data! ---") + + print(path_ckp2) + folder = os.path.exists(path_ckp2) + if not folder: # 判断是否存在文件夹如果不存在则创建为文件夹 + os.makedirs(path_ckp2) # makedirs 创建文件时如果路径不存在会创建这个路径 + # 下载数据 + if name1=='fb237_v1': + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/fb237_v1_ind.zip" + elif name1=='fb237_v2': + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/fb237_v2_ind.zip" + elif name1=='fb237_v3': + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/fb237_v3_ind.zip" + elif name1=='fb237_v4': + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/fb237_v4_ind.zip" + response = requests.get(url) + with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: + myzip.extractall(path_ckp2) + print("--- download data ---") + + else: + print("--- There is data! 
---") + self.task_dir=self.dir + task_dir=self.task_dir + n_batch=args.n_batch + self.trans_dir = task_dir + self.n_batch = n_batch + self.ind_dir = task_dir + '_ind' + + with open(os.path.join(task_dir, 'entities.txt')) as f: + self.entity2id = dict() + for line in f: + entity, eid = line.strip().split() + self.entity2id[entity] = int(eid) + + with open(os.path.join(task_dir, 'relations.txt')) as f: + self.relation2id = dict() + id2relation = [] + for line in f: + relation, rid = line.strip().split() + self.relation2id[relation] = int(rid) + id2relation.append(relation) + + with open(os.path.join(self.ind_dir, 'entities.txt')) as f: + self.entity2id_ind = dict() + for line in f: + entity, eid = line.strip().split() + self.entity2id_ind[entity] = int(eid) + + for i in range(len(self.relation2id)): + id2relation.append(id2relation[i] + '_inv') + id2relation.append('idd') + self.id2relation = id2relation + + self.n_ent = len(self.entity2id) + self.n_rel = len(self.relation2id) + self.n_ent_ind = len(self.entity2id_ind) + + self.tra_train = self.read_triples(self.trans_dir, 'train.txt') + self.tra_valid = self.read_triples(self.trans_dir, 'valid.txt') + self.tra_test = self.read_triples(self.trans_dir, 'test.txt') + self.ind_train = self.read_triples(self.ind_dir, 'train.txt', 'inductive') + self.ind_valid = self.read_triples(self.ind_dir, 'valid.txt', 'inductive') + self.ind_test = self.read_triples(self.ind_dir, 'test.txt', 'inductive') + + self.val_filters = self.get_filter('valid') + self.tst_filters = self.get_filter('test') + + for filt in self.val_filters: + self.val_filters[filt] = list(self.val_filters[filt]) + for filt in self.tst_filters: + self.tst_filters[filt] = list(self.tst_filters[filt]) + + self.tra_KG, self.tra_sub = self.load_graph(self.tra_train) + self.ind_KG, self.ind_sub = self.load_graph(self.ind_train, 'inductive') + + self.tra_train = np.array(self.tra_valid) + self.tra_val_qry, self.tra_val_ans = self.load_query(self.tra_test) + self.ind_val_qry, self.ind_val_ans = self.load_query(self.ind_valid) + self.ind_tst_qry, self.ind_tst_ans = self.load_query(self.ind_test) + self.valid_q, self.valid_a = self.tra_val_qry, self.tra_val_ans + self.test_q, self.test_a = self.ind_val_qry + self.ind_tst_qry, self.ind_val_ans + self.ind_tst_ans + + self.n_train = len(self.tra_train) + self.n_valid = len(self.valid_q) + self.n_test = len(self.test_q) + + print('n_train:', self.n_train, 'n_valid:', self.n_valid, 'n_test:', self.n_test) + + def read_triples(self, directory, filename, mode='transductive'): + triples = [] + with open(os.path.join(directory, filename)) as f: + for line in f: + h, r, t = line.strip().split() + if mode == 'transductive': + h, r, t = self.entity2id[h], self.relation2id[r], self.entity2id[t] + else: + h, r, t = self.entity2id_ind[h], self.relation2id[r], self.entity2id_ind[t] + + triples.append([h, r, t]) + triples.append([t, r + self.n_rel, h]) + return triples + + def load_graph(self, triples, mode='transductive'): + n_ent = self.n_ent if mode == 'transductive' else self.n_ent_ind + + KG = np.array(triples) + idd = np.concatenate([np.expand_dims(np.arange(n_ent), 1), 2 * self.n_rel * np.ones((n_ent, 1)), + np.expand_dims(np.arange(n_ent), 1)], 1) + KG = np.concatenate([KG, idd], 0) + + n_fact = KG.shape[0] + + M_sub = csr_matrix((np.ones((n_fact,)), (np.arange(n_fact), KG[:, 0])), shape=(n_fact, n_ent)) + return KG, M_sub + + def load_query(self, triples): + triples.sort(key=lambda x: (x[0], x[1])) + trip_hr = defaultdict(lambda: list()) + + for trip in 
triples: + h, r, t = trip + trip_hr[(h, r)].append(t) + + queries = [] + answers = [] + for key in trip_hr: + queries.append(key) + answers.append(np.array(trip_hr[key])) + return queries, answers + + def get_neighbors(self, nodes, mode='transductive'): + # nodes: n_node x 2 with (batch_idx, node_idx) + if mode == 'transductive': + KG = self.tra_KG + M_sub = self.tra_sub + n_ent = self.n_ent + else: + KG = self.ind_KG + M_sub = self.ind_sub + n_ent = self.n_ent_ind + + node_1hot = csr_matrix((np.ones(len(nodes)), (nodes[:, 1], nodes[:, 0])), shape=(n_ent, nodes.shape[0])) + edge_1hot = M_sub.dot(node_1hot) + edges = np.nonzero(edge_1hot) + selected_edges = np.concatenate([np.expand_dims(edges[1], 1), KG[edges[0]]], + axis=1) # (batch_idx, head, rela, tail) + selected_edges = torch.LongTensor(selected_edges).cuda() + + # index to nodes + head_nodes, head_index = torch.unique(selected_edges[:, [0, 1]], dim=0, sorted=True, return_inverse=True) + tail_nodes, tail_index = torch.unique(selected_edges[:, [0, 3]], dim=0, sorted=True, return_inverse=True) + mask = selected_edges[:, 2] == (self.n_rel * 2) + _, old_idx = head_index[mask].sort() + old_nodes_new_idx = tail_index[mask][old_idx] + selected_edges = torch.cat([selected_edges, head_index.unsqueeze(1), tail_index.unsqueeze(1)], 1) + + return tail_nodes, selected_edges, old_nodes_new_idx + + def get_batch(self, batch_idx, steps=2, data='train'): + if data == 'train': + return self.tra_train[batch_idx] + if data == 'valid': + query, answer = np.array(self.valid_q), self.valid_a + n_ent = self.n_ent + if data == 'test': + query, answer = np.array(self.test_q), self.test_a + n_ent = self.n_ent_ind + + subs = [] + rels = [] + objs = [] + + subs = query[batch_idx, 0] + rels = query[batch_idx, 1] + objs = np.zeros((len(batch_idx), n_ent)) + for i in range(len(batch_idx)): + objs[i][answer[batch_idx[i]]] = 1 + return subs, rels, objs + + def shuffle_train(self, ): + rand_idx = np.random.permutation(self.n_train) + self.tra_train = self.tra_train[rand_idx] + + def get_filter(self, data='valid'): + filters = defaultdict(lambda: set()) + if data == 'valid': + for triple in self.tra_train: + h, r, t = triple + filters[(h, r)].add(t) + for triple in self.tra_valid: + h, r, t = triple + filters[(h, r)].add(t) + for triple in self.tra_test: + h, r, t = triple + filters[(h, r)].add(t) + else: + for triple in self.ind_train: + h, r, t = triple + filters[(h, r)].add(t) + for triple in self.ind_valid: + h, r, t = triple + filters[(h, r)].add(t) + for triple in self.ind_test: + h, r, t = triple + filters[(h, r)].add(t) + return filters diff --git a/openhgnn/dataset/AdapropT_dataset.py b/openhgnn/dataset/AdapropT_dataset.py new file mode 100644 index 00000000..e6f76759 --- /dev/null +++ b/openhgnn/dataset/AdapropT_dataset.py @@ -0,0 +1,220 @@ +import io +import os +import torch +from scipy.sparse import csr_matrix +import numpy as np +from collections import defaultdict +import requests +import zipfile + +class AdapropTDataLoader: + def __init__(self, args): + self.args = args + # self.task_dir = task_dir = args.data_path + # current_dir = os.getcwd() + # print(1111) + # print(current_dir) + # current_dir = os.path.join(current_dir, 'OpenHGNN') + # task_dir=os.path.join(current_dir,task_dir) + # print(task_dir) + # self.task_dir=task_dir + self.dir = './data' + path_ckp = os.path.join(self.dir, 'family') + self.dir = os.path.join(self.dir, 'family') + task_dir=self.dir + self.task_dir=self.dir + print(path_ckp) + folder = os.path.exists(path_ckp) + if not 
folder: # create the folder if it does not exist + os.makedirs(path_ckp) # makedirs also creates any missing parent directories + # download the data + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/family.zip" + response = requests.get(url) + with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: + myzip.extractall(path_ckp) + print("--- download data ---") + + else: + print("--- There is data! ---") + with open(os.path.join(task_dir, 'entities.txt')) as f: + self.entity2id = dict() + n_ent = 0 + for line in f: + entity = line.strip() + self.entity2id[entity] = n_ent + n_ent += 1 + + with open(os.path.join(task_dir, 'relations.txt')) as f: + self.relation2id = dict() + n_rel = 0 + for line in f: + relation = line.strip() + self.relation2id[relation] = n_rel + n_rel += 1 + + self.n_ent = n_ent + self.n_rel = n_rel + + # prepare triples + self.filters = defaultdict(lambda: set()) + self.fact_triple = self.read_triples('facts.txt') + self.train_triple = self.read_triples('train.txt') + self.valid_triple = self.read_triples('valid.txt') + self.test_triple = self.read_triples('test.txt') + self.all_triple = np.concatenate([np.array(self.fact_triple), np.array(self.train_triple)], axis=0) + self.tmp_all_triple = np.concatenate( + [np.array(self.fact_triple), np.array(self.train_triple), np.array(self.valid_triple), + np.array(self.test_triple)], axis=0) + + # add inverse + self.fact_data = self.double_triple(self.fact_triple) + self.train_data = np.array(self.double_triple(self.train_triple)) + self.valid_data = self.double_triple(self.valid_triple) + self.test_data = self.double_triple(self.test_triple) + + self.shuffle_train() + self.load_graph(self.fact_data) + self.load_test_graph(self.double_triple(self.fact_triple) + self.double_triple(self.train_triple)) + self.valid_q, self.valid_a = self.load_query(self.valid_data) + self.test_q, self.test_a = self.load_query(self.test_data) + + self.n_train = len(self.train_data) + self.n_valid = len(self.valid_q) + self.n_test = len(self.test_q) + + for filt in self.filters: + self.filters[filt] = list(self.filters[filt]) + + print('n_train:', self.n_train, 'n_valid:', self.n_valid, 'n_test:', self.n_test) + + def read_triples(self, filename): + triples = [] + with open(os.path.join(self.task_dir, filename)) as f: + for line in f: + h, r, t = line.strip().split() + h, r, t = self.entity2id[h], self.relation2id[r], self.entity2id[t] + triples.append([h, r, t]) + self.filters[(h, r)].add(t) + self.filters[(t, r + self.n_rel)].add(h) + return triples + + def double_triple(self, triples): + new_triples = [] + for triple in triples: + h, r, t = triple + new_triples.append([t, r + self.n_rel, h]) + return triples + new_triples + + def load_graph(self, triples): + # (e, r', e) + # r' = 2 * n_rel; r' is manually generated and does not exist in the original KG + # self.KG: shape=(self.n_fact, 3) + # M_sub shape=(self.n_fact, self.n_ent), stores the projection from head entity to triples + idd = np.concatenate([np.expand_dims(np.arange(self.n_ent), 1), 2 * self.n_rel * np.ones((self.n_ent, 1)), + np.expand_dims(np.arange(self.n_ent), 1)], 1) + + self.KG = np.concatenate([np.array(triples), idd], 0) + self.n_fact = len(self.KG) + self.M_sub = csr_matrix((np.ones((self.n_fact,)), (np.arange(self.n_fact), self.KG[:, 0])), + shape=(self.n_fact, self.n_ent)) + + def load_test_graph(self, triples): + idd = np.concatenate([np.expand_dims(np.arange(self.n_ent), 1), 2 * self.n_rel * np.ones((self.n_ent, 1)), + np.expand_dims(np.arange(self.n_ent), 1)], 1) + + self.tKG =
np.concatenate([np.array(triples), idd], 0) + self.tn_fact = len(self.tKG) + self.tM_sub = csr_matrix((np.ones((self.tn_fact,)), (np.arange(self.tn_fact), self.tKG[:, 0])), + shape=(self.tn_fact, self.n_ent)) + + def load_query(self, triples): + trip_hr = defaultdict(lambda: list()) + + for trip in triples: + h, r, t = trip + trip_hr[(h, r)].append(t) + + queries = [] + answers = [] + for key in trip_hr: + queries.append(key) + answers.append(np.array(trip_hr[key])) + return queries, answers + + def get_neighbors(self, nodes, batchsize, mode='train'): + if mode == 'train': + KG = self.KG + M_sub = self.M_sub + else: + KG = self.tKG + M_sub = self.tM_sub + + # nodes: [N_ent_of_all_batch_last, 2] with (batch_idx, node_idx) + # [N_ent, N_ent_of_all_batch_last] + node_1hot = csr_matrix((np.ones(len(nodes)), (nodes[:, 1], nodes[:, 0])), shape=(self.n_ent, nodes.shape[0])) + # [N_fact, N_ent] * [N_ent, N_ent_of_all_batch_last] -> [N_fact, N_ent_of_all_batch_last] + edge_1hot = M_sub.dot(node_1hot) + # [2, N_edge_of_all_batch] with (fact_idx, batch_idx) + edges = np.nonzero(edge_1hot) + # {batch_idx} + {head, rela, tail} -> concat -> [N_edge_of_all_batch, 4] with (batch_idx, head, rela, tail) + sampled_edges = np.concatenate([np.expand_dims(edges[1], 1), KG[edges[0]]], axis=1) + sampled_edges = torch.LongTensor(sampled_edges).cuda() + + # indexing nodes | within/out of a batch | relative index + # note that node_idx is the absolute nodes idx in original KG + # head_nodes: [N_ent_of_all_batch_last, 2] with (batch_idx, node_idx) + # tail_nodes: [N_ent_of_all_batch_this, 2] with (batch_idx, node_idx) + # head_index: [N_edge_of_all_batch] with relative node idx + # tail_index: [N_edge_of_all_batch] with relative node idx + head_nodes, head_index = torch.unique(sampled_edges[:, [0, 1]], dim=0, sorted=True, return_inverse=True) + tail_nodes, tail_index = torch.unique(sampled_edges[:, [0, 3]], dim=0, sorted=True, return_inverse=True) + + # [N_edge_of_all_batch, 4] -> [N_edge_of_all_batch, 6] with (batch_idx, head, rela, tail, head_index, tail_index) + # node that the head_index and tail_index are of this layer + sampled_edges = torch.cat([sampled_edges, head_index.unsqueeze(1), tail_index.unsqueeze(1)], 1) + + # get new index for nodes in last layer + mask = sampled_edges[:, 2] == (self.n_rel * 2) + # old_nodes_new_idx: [N_ent_of_all_batch_last] + old_nodes_new_idx = tail_index[mask].sort()[0] + + return tail_nodes, sampled_edges, old_nodes_new_idx + + def get_batch(self, batch_idx, steps=2, data='train'): + if data == 'train': + return np.array(self.train_data)[batch_idx] + if data == 'valid': + query, answer = np.array(self.valid_q), np.array(self.valid_a) + if data == 'test': + query, answer = np.array(self.test_q), np.array(self.test_a) + + subs = [] + rels = [] + objs = [] + subs = query[batch_idx, 0] + rels = query[batch_idx, 1] + objs = np.zeros((len(batch_idx), self.n_ent)) + for i in range(len(batch_idx)): + objs[i][answer[batch_idx[i]]] = 1 + return subs, rels, objs + + def shuffle_train(self): + all_triple = self.all_triple + n_all = len(all_triple) + rand_idx = np.random.permutation(n_all) + all_triple = all_triple[rand_idx] + + bar = int(n_all * self.args.fact_ratio) + self.fact_data = np.array(self.double_triple(all_triple[:bar].tolist())) + self.train_data = np.array(self.double_triple(all_triple[bar:].tolist())) + + if self.args.remove_1hop_edges: + print('==> removing 1-hop links...') + tmp_index = np.ones((self.n_ent, self.n_ent)) + tmp_index[self.train_data[:, 0], self.train_data[:, 
2]] = 0 + save_facts = tmp_index[self.fact_data[:, 0], self.fact_data[:, 2]].astype(bool) + self.fact_data = self.fact_data[save_facts] + print('==> done') + + self.n_train = len(self.train_data) + self.load_graph(self.fact_data) \ No newline at end of file diff --git a/openhgnn/dataset/Ingram_dataset.py b/openhgnn/dataset/Ingram_dataset.py new file mode 100644 index 00000000..7f866356 --- /dev/null +++ b/openhgnn/dataset/Ingram_dataset.py @@ -0,0 +1,237 @@ +import numpy as np +import random +import torch +import dgl +import time +import os +import igraph +import requests +import zipfile +import io + + +def remove_duplicate(x): + return list(dict.fromkeys(x)) + + +class UnionFind: + def __init__(self, n): + self.n = n + self.parent = list(range(n)) + + def find(self, x): + if self.parent[x] != x: + self.parent[x] = self.find(self.parent[x]) + return self.parent[x] + + def union(self, x, y): + self.parent[self.find(x)] = self.find(y) + + def connected(self, x, y): + return self.find(x) == self.find(y) + + +def kruskal(g): + uf = UnionFind(g.num_nodes()) + mst_edges = [] + mst_weights = [] + edge_index = [] + edges, weights = g.edges(), g.edata['w'] + indices = torch.argsort(weights) + for i in indices: + u, v = edges[0][i], edges[1][i] + if not uf.connected(u, v): + mst_edges.append((u, v)) + edge_index.append(int(i)) + uf.union(u, v) + if len(mst_edges) == g.num_nodes() - 1: + break + mst_g = dgl.graph(mst_edges) + edge_index = torch.tensor(edge_index) + return mst_g, edge_index + + +class Ingram_KG_TrainData(): + def __init__(self, path, dataset_name, *args, **kwargs): + super(Ingram_KG_TrainData, self).__init__(*args, **kwargs) + # to be changed before release + self.path = 'openhgnn/data/' + dataset_name + '/' + self.rel_info = {} # (h,t):[r1,r2,...] + self.pair_info = {} # r:[(h,t),(h,t),...] + self.spanning = [] # [(h,t),(h,t),...], + self.remaining = [] # [(h,t),(h,t),...], + self.ent2id = None # ent2id + self.rel2id = None # rel2id + self.id2ent, self.id2rel, self.triplets = self.read_triplet(self.path + 'train.txt') + self.num_triplets = len(self.triplets) + self.num_ent, self.num_rel = len(self.id2ent), len(self.id2rel) + self.dataset_name = dataset_name + + def read_triplet(self, path): + + + path_ckp = self.path + print(path_ckp) + folder = os.path.exists(path_ckp) + if not folder: + os.makedirs(path_ckp) + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/NL-100.zip" + response = requests.get(url) + with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: + myzip.extractall(self.path) + print("--- download data ---") + + else: + print("--- There is data!
---") + + + id2ent, id2rel, triplets = [], [], [] + with open(path, 'r') as f: + for line in f.readlines(): + h, r, t = line.strip().split('\t') + id2ent.append(h) + id2ent.append(t) + id2rel.append(r) + triplets.append((h, r, t)) + id2ent = remove_duplicate(id2ent) + id2rel = remove_duplicate(id2rel) + self.ent2id = {ent: idx for idx, ent in enumerate(id2ent)} + self.rel2id = {rel: idx for idx, rel in enumerate(id2rel)} + triplets = [(self.ent2id[h], self.rel2id[r], self.ent2id[t]) for h, r, t in triplets] + for (h, r, t) in triplets: + if (h, t) in self.rel_info: + self.rel_info[(h, t)].append(r) + else: + self.rel_info[(h, t)] = [r] + if r in self.pair_info: + self.pair_info[r].append((h, t)) + else: + self.pair_info[r] = [(h, t)] + G = igraph.Graph.TupleList(np.array(triplets)[:, 0::2]) + G_ent = igraph.Graph.TupleList(np.array(triplets)[:, 0::2], directed=True) + spanning = G_ent.spanning_tree() + G_ent.delete_edges(spanning.get_edgelist()) + print(spanning.es) + for e in spanning.es: + e1, e2 = e.tuple + e1 = spanning.vs[e1]["name"] + e2 = spanning.vs[e2]["name"] + self.spanning.append((e1, e2)) + + spanning_set = set(self.spanning) + + print("-----Train Data Statistics-----") + print(f"{len(self.ent2id)} entities, {len(self.rel2id)} relations") + print(f"{len(triplets)} triplets") + self.triplet2idx = {triplet: idx for idx, triplet in enumerate(triplets)} + self.triplets_with_inv = np.array([(t, r + len(id2rel), h) for h, r, t in triplets] + triplets) + return id2ent, id2rel, triplets + + def split_transductive(self, p): + msg, sup = [], [] + rels_encountered = np.zeros(self.num_rel) + + remaining_triplet_indexes = np.ones(self.num_triplets) + + for h, t in self.spanning: + r = random.choice(self.rel_info[(h, t)]) + msg.append((h, r, t)) + remaining_triplet_indexes[self.triplet2idx[(h, r, t)]] = 0 + rels_encountered[r] = 1 + for r in (1 - rels_encountered).nonzero()[0].tolist(): + h, t = random.choice(self.pair_info[int(r)]) + msg.append((h, r, t)) + remaining_triplet_indexes[self.triplet2idx[(h, r, t)]] = 0 + + start = time.time() + sup = [self.triplets[idx] for idx, tf in enumerate(remaining_triplet_indexes) if tf] + msg = np.array(msg) + random.shuffle(sup) + sup = np.array(sup) + add_num = max(int(self.num_triplets * p) - len(msg), 0) + msg = np.concatenate([msg, sup[:add_num]]) + sup = sup[add_num:] + + msg_inv = np.fliplr(msg).copy() + msg_inv[:, 1] += self.num_rel + msg = np.concatenate([msg, msg_inv]) + + return msg, sup + + +class Ingram_KG_TestData(): + def __init__(self, path, dataset_name, data_type="valid"): + self.path = 'openhgnn/data/' + dataset_name + '/' + self.data_type = data_type + self.ent2id = None + self.rel2id = None + self.id2ent, self.id2rel, self.msg_triplets, self.sup_triplets, self.filter_dict = self.read_triplet() + self.num_ent, self.num_rel = len(self.id2ent), len(self.id2rel) + + def read_triplet(self): + id2ent, id2rel, msg_triplets, sup_triplets = [], [], [], [] + total_triplets = [] + + + with open(self.path + "msg.txt", 'r') as f: + for line in f.readlines(): + h, r, t = line.strip().split('\t') + id2ent.append(h) + id2ent.append(t) + id2rel.append(r) + msg_triplets.append((h, r, t)) + total_triplets.append((h, r, t)) + + id2ent = remove_duplicate(id2ent) + id2rel = remove_duplicate(id2rel) + self.ent2id = {ent: idx for idx, ent in enumerate(id2ent)} + self.rel2id = {rel: idx for idx, rel in enumerate(id2rel)} + num_rel = len(self.rel2id) + msg_triplets = [(self.ent2id[h], self.rel2id[r], self.ent2id[t]) for h, r, t in msg_triplets] + 
msg_inv_triplets = [(t, r + num_rel, h) for h, r, t in msg_triplets] + + with open(self.path + self.data_type + ".txt", 'r') as f: + for line in f.readlines(): + h, r, t = line.strip().split('\t') + sup_triplets.append((self.ent2id[h], self.rel2id[r], self.ent2id[t])) + assert (self.ent2id[h], self.rel2id[r], self.ent2id[t]) not in msg_triplets, \ + (self.ent2id[h], self.rel2id[r], self.ent2id[t]) + total_triplets.append((h, r, t)) + for data_type in ['valid', 'test']: + if data_type == self.data_type: + continue + with open(self.path + data_type + ".txt", 'r') as f: + for line in f.readlines(): + h, r, t = line.strip().split('\t') + assert (self.ent2id[h], self.rel2id[r], self.ent2id[t]) not in msg_triplets, \ + (self.ent2id[h], self.rel2id[r], self.ent2id[t]) + total_triplets.append((h, r, t)) + + filter_dict = {} + for triplet in total_triplets: + h, r, t = triplet + if ('_', self.rel2id[r], self.ent2id[t]) not in filter_dict: + filter_dict[('_', self.rel2id[r], self.ent2id[t])] = [self.ent2id[h]] + else: + filter_dict[('_', self.rel2id[r], self.ent2id[t])].append(self.ent2id[h]) + + if (self.ent2id[h], '_', self.ent2id[t]) not in filter_dict: + filter_dict[(self.ent2id[h], '_', self.ent2id[t])] = [self.rel2id[r]] + else: + filter_dict[(self.ent2id[h], '_', self.ent2id[t])].append(self.rel2id[r]) + + if (self.ent2id[h], self.rel2id[r], '_') not in filter_dict: + filter_dict[(self.ent2id[h], self.rel2id[r], '_')] = [self.ent2id[t]] + else: + filter_dict[(self.ent2id[h], self.rel2id[r], '_')].append(self.ent2id[t]) + + print(f"-----{self.data_type.capitalize()} Data Statistics-----") + print(f"Message set has {len(msg_triplets)} triplets") + print(f"Supervision set has {len(sup_triplets)} triplets") + print(f"{len(self.ent2id)} entities, " + \ + f"{len(self.rel2id)} relations, " + \ + f"{len(total_triplets)} triplets") + + msg_triplets = msg_triplets + msg_inv_triplets + + return id2ent, id2rel, np.array(msg_triplets), np.array(sup_triplets), filter_dict diff --git a/openhgnn/dataset/LTE_dataset.py b/openhgnn/dataset/LTE_dataset.py new file mode 100644 index 00000000..e69de29b diff --git a/openhgnn/dataset/LinkPredictionDataset.py b/openhgnn/dataset/LinkPredictionDataset.py index a19fb042..9842f496 100644 --- a/openhgnn/dataset/LinkPredictionDataset.py +++ b/openhgnn/dataset/LinkPredictionDataset.py @@ -1,15 +1,28 @@ +import os.path + import dgl import math -import random +import re +from copy import deepcopy import numpy as np import torch as th +import itertools +import random +from random import shuffle, choice +from collections import Counter +from os.path import join as joinpath +from os.path import isfile from dgl.data.knowledge_graph import load_data from . import BaseDataset, register_dataset -from . import AcademicDataset, HGBDataset, OHGBDataset +from . 
import AcademicDataset, HGBDataset, OHGBDataset, NBF_Dataset from ..utils import add_reverse_edges +from collections import defaultdict +import os +from scipy.sparse import csr_matrix __all__ = ['LinkPredictionDataset', 'HGB_LinkPrediction'] + @register_dataset('link_prediction') class LinkPredictionDataset(BaseDataset): """ @@ -76,24 +89,24 @@ def get_split(self, val_ratio=0.1, test_ratio=0.2): else: if 'valid_mask' not in self.g.edges[etype].data: train_idx = self.g.edges[etype].data['train_mask'] - random_int = th.randperm(int(train_idx.sum())) - val_index = random_int[:int(train_idx.sum() * val_ratio)] + random_int = th.randperm(int(train_idx.sum( ))) + val_index = random_int[:int(train_idx.sum( ) * val_ratio)] val_edge = self.g.find_edges(val_index, etype) else: - val_mask = self.g.edges[etype].data['valid_mask'].squeeze() - val_index = th.nonzero(val_mask).squeeze() + val_mask = self.g.edges[etype].data['valid_mask'].squeeze( ) + val_index = th.nonzero(val_mask).squeeze( ) val_edge = self.g.find_edges(val_index, etype) - test_mask = self.g.edges[etype].data['test_mask'].squeeze() - test_index = th.nonzero(test_mask).squeeze() + test_mask = self.g.edges[etype].data['test_mask'].squeeze( ) + test_index = th.nonzero(test_mask).squeeze( ) test_edge = self.g.find_edges(test_index, etype) val_edge_dict[etype] = val_edge test_edge_dict[etype] = test_edge out_ntypes.append(etype[0]) out_ntypes.append(etype[2]) - #self.val_label = train_graph.edges[etype[1]].data['label'][val_index] + # self.val_label = train_graph.edges[etype[1]].data['label'][val_index] self.test_label = train_graph.edges[etype[1]].data['label'][test_index] train_graph = dgl.remove_edges(train_graph, th.cat((val_index, test_index)), etype) @@ -112,17 +125,17 @@ def get_split(self, val_ratio=0.1, test_ratio=0.2): @register_dataset('demo_link_prediction') class Test_LinkPrediction(LinkPredictionDataset): def __init__(self, dataset_name): - super(Test_LinkPrediction, self).__init__() + super(Test_LinkPrediction, self).__init__( ) self.g = self.load_HIN('./openhgnn/debug/data.bin') self.target_link = 'user-item' self.has_feature = False self.meta_paths_dict = None - self.preprocess() + self.preprocess( ) # self.generate_negative() def preprocess(self): test_mask = self.g.edges[self.target_link].data['test_mask'] - index = th.nonzero(test_mask).squeeze() + index = th.nonzero(test_mask).squeeze( ) self.test_edge = self.g.find_edges(index, self.target_link) self.pos_test_graph = dgl.heterograph({('user', 'user-item', 'item'): self.test_edge}, {ntype: self.g.number_of_nodes(ntype) for ntype in ['user', 'item']}) @@ -134,10 +147,10 @@ def preprocess(self): def generate_negative(self): k = 99 - e = self.pos_test_graph.edges() + e = self.pos_test_graph.edges( ) neg_src = [] neg_dst = [] - for i in range(self.pos_test_graph.number_of_edges()): + for i in range(self.pos_test_graph.number_of_edges( )): src = e[0][i] exp = self.pos_test_graph.successors(src) dst = th.randint(high=self.g.number_of_nodes('item'), size=(k,)) @@ -164,8 +177,8 @@ def load_link_pred(self, path): v_list = [] label_list = [] with open(path) as f: - for i in f.readlines(): - u, v, label = i.strip().split(', ') + for i in f.readlines( ): + u, v, label = i.strip( ).split(', ') u_list.append(int(u)) v_list.append(int(v)) label_list.append(int(label)) @@ -176,7 +189,7 @@ def load_HIN(self, dataset_name): if dataset_name == 'academic4HetGNN': # which is used in HetGNN dataset = AcademicDataset(name='academic4HetGNN', raw_dir='') - g = dataset[0].long() + g = 
dataset[0].long( ) self.train_batch = self.load_link_pred('./openhgnn/dataset/' + dataset_name + '/a_a_list_train.txt') self.test_batch = self.load_link_pred('./openhgnn/dataset/' + dataset_name + '/a_a_list_test.txt') @@ -188,29 +201,29 @@ def load_HIN(self, dataset_name): self.node_type = ['user', 'item'] elif dataset_name == 'amazon4SLICE': dataset = AcademicDataset(name='amazon4SLICE', raw_dir='') - g = dataset[0].long() + g = dataset[0].long( ) elif dataset_name == 'MTWM': dataset = AcademicDataset(name='MTWM', raw_dir='') - g = dataset[0].long() + g = dataset[0].long( ) g = add_reverse_edges(g) self.target_link = [('user', 'user-buy-spu', 'spu')] self.target_link_r = [('spu', 'user-buy-spu-rev', 'user')] self.meta_paths_dict = { - 'UPU1':[('user','user-buy-poi','poi'),('poi','user-buy-poi-rev','user')], - 'UPU2':[('user','user-click-poi','poi'),('poi','user-click-poi-rev','user')], - 'USU':[('user','user-buy-spu','spu'),('spu','user-buy-spu-rev','user')], - 'UPSPU1': [('user','user-buy-poi','poi'),('poi','poi-contain-spu','spu'), - ('spu','poi-contain-spu-rev','poi'),('poi','user-buy-poi-rev','user') + 'UPU1': [('user', 'user-buy-poi', 'poi'), ('poi', 'user-buy-poi-rev', 'user')], + 'UPU2': [('user', 'user-click-poi', 'poi'), ('poi', 'user-click-poi-rev', 'user')], + 'USU': [('user', 'user-buy-spu', 'spu'), ('spu', 'user-buy-spu-rev', 'user')], + 'UPSPU1': [('user', 'user-buy-poi', 'poi'), ('poi', 'poi-contain-spu', 'spu'), + ('spu', 'poi-contain-spu-rev', 'poi'), ('poi', 'user-buy-poi-rev', 'user') ], - 'UPSPU2':[ - ('user','user-click-poi','poi'), ('poi','poi-contain-spu','spu'), - ('spu','poi-contain-spu-rev','poi'),('poi','user-click-poi-rev','user') - ] + 'UPSPU2': [ + ('user', 'user-click-poi', 'poi'), ('poi', 'poi-contain-spu', 'spu'), + ('spu', 'poi-contain-spu-rev', 'poi'), ('poi', 'user-click-poi-rev', 'user') + ] } self.node_type = ['user', 'spu'] elif dataset_name == 'HGBl-ACM': dataset = HGBDataset(name='HGBn-ACM', raw_dir='') - g = dataset[0].long() + g = dataset[0].long( ) self.has_feature = True self.target_link = [('paper', 'paper-ref-paper', 'paper')] self.node_type = ['author', 'paper', 'subject', 'term'] @@ -233,7 +246,7 @@ def load_HIN(self, dataset_name): } elif dataset_name == 'HGBl-DBLP': dataset = HGBDataset(name='HGBn-DBLP', raw_dir='') - g = dataset[0].long() + g = dataset[0].long( ) self.has_feature = True self.target_link = [('author', 'author-paper', 'paper')] self.node_type = ['author', 'paper', 'venue', 'term'] @@ -250,7 +263,7 @@ def load_HIN(self, dataset_name): elif dataset_name == 'HGBl-IMDB': dataset = HGBDataset(name='HGBn-IMDB', raw_dir='') - g = dataset[0].long() + g = dataset[0].long( ) self.has_feature = True # self.target_link = [('author', 'author-paper', 'paper')] # self.node_type = ['author', 'paper', 'subject', 'term'] @@ -268,9 +281,9 @@ def load_HIN(self, dataset_name): 'AMA': [('actor', 'actor->movie', 'movie'), ('movie', 'movie->actor', 'actor')], 'AMDMA': [('actor', 'actor->movie', 'movie'), ('movie', 'movie->director', 'director'), ('director', 'director->movie', 'movie'), ('movie', 'movie->actor', 'actor')] - } + } return g - + def get_split(self, val_ratio=0.1, test_ratio=0.2): if self.dataset_name == 'academic4HetGNN': return None, None, None, None, None @@ -304,7 +317,7 @@ def __init__(self, dataset_name, *args, **kwargs): self.target_link_r = None if dataset_name == 'HGBl-amazon': dataset = HGBDataset(name=dataset_name, raw_dir='') - g = dataset[0].long() + g = dataset[0].long( ) self.has_feature = False self.target_link = 
[('product', 'product-product-0', 'product'), ('product', 'product-product-1', 'product')] @@ -319,7 +332,7 @@ def __init__(self, dataset_name, *args, **kwargs): elif dataset_name == 'HGBl-LastFM': dataset = HGBDataset(name=dataset_name, raw_dir='') - g = dataset[0].long() + g = dataset[0].long( ) self.has_feature = False self.target_link = [('user', 'user-artist', 'artist')] self.node_type = ['user', 'artist', 'tag'] @@ -337,7 +350,7 @@ def __init__(self, dataset_name, *args, **kwargs): elif dataset_name == 'HGBl-PubMed': dataset = HGBDataset(name=dataset_name, raw_dir='') - g = dataset[0].long() + g = dataset[0].long( ) self.has_feature = True self.target_link = [('1', '1_to_1', '1')] self.node_type = ['0', '1', '2', '3'] @@ -351,7 +364,7 @@ def __init__(self, dataset_name, *args, **kwargs): } self.g = g - self.shift_dict = self.calculate_node_shift() + self.shift_dict = self.calculate_node_shift( ) def load_link_pred(self, path): return @@ -376,14 +389,14 @@ def get_split(self): train_graph = self.g val_ratio = 0.1 for i, etype in enumerate(self.target_link): - train_mask = self.g.edges[etype].data['train_mask'].squeeze() - train_index = th.nonzero(train_mask).squeeze() + train_mask = self.g.edges[etype].data['train_mask'].squeeze( ) + train_index = th.nonzero(train_mask).squeeze( ) random_int = th.randperm(len(train_index))[:int(len(train_index) * val_ratio)] val_index = train_index[random_int] val_edge = self.g.find_edges(val_index, etype) - test_mask = self.g.edges[etype].data['test_mask'].squeeze() - test_index = th.nonzero(test_mask).squeeze() + test_mask = self.g.edges[etype].data['test_mask'].squeeze( ) + test_index = th.nonzero(test_mask).squeeze( ) test_edge = self.g.find_edges(test_index, etype) val_edge_dict[etype] = val_edge @@ -405,7 +418,7 @@ def get_split(self): return train_graph, val_graph, test_graph, None, None def save_results(self, hg, score, file_path): - with hg.local_scope(): + with hg.local_scope( ): src_list = [] dst_list = [] edge_type_list = [] @@ -433,28 +446,27 @@ def __init__(self, dataset_name, *args, **kwargs): self.has_feature = True if dataset_name == 'ohgbl-MTWM': dataset = OHGBDataset(name=dataset_name, raw_dir='') - g = dataset[0].long() + g = dataset[0].long( ) self.target_link = [('user', 'user-buy-spu', 'spu')] self.target_link_r = [('spu', 'user-buy-spu-rev', 'user')] self.node_type = ['user', 'spu'] elif dataset_name == 'ohgbl-yelp1': dataset = OHGBDataset(name=dataset_name, raw_dir='') - g = dataset[0].long() + g = dataset[0].long( ) self.target_link = [('user', 'user-buy-business', 'business')] self.target_link_r = [('business', 'user-buy-business-rev', 'user')] elif dataset_name == 'ohgbl-yelp2': dataset = OHGBDataset(name=dataset_name, raw_dir='') - g = dataset[0].long() + g = dataset[0].long( ) self.target_link = [('business', 'described-with', 'phrase')] self.target_link_r = [('business', 'described-with-rev', 'phrase')] elif dataset_name == 'ohgbl-Freebase': dataset = OHGBDataset(name=dataset_name, raw_dir='') - g = dataset[0].long() - self.target_link = [('BOOK','BOOK-and-BOOK','BOOK')] - self.target_link_r = [('BOOK','BOOK-and-BOOK-rev','BOOK')] + g = dataset[0].long( ) + self.target_link = [('BOOK', 'BOOK-and-BOOK', 'BOOK')] + self.target_link_r = [('BOOK', 'BOOK-and-BOOK-rev', 'BOOK')] self.g = g - - + def build_graph_from_triplets(num_nodes, num_rels, triplets): """ Create a DGL graph. The graph is bidirectional because RGCN authors use reversed relations. 
@@ -467,7 +479,7 @@ def build_graph_from_triplets(num_nodes, num_rels, triplets): src, dst = np.concatenate((src, dst)), np.concatenate((dst, src)) rel = np.concatenate((rel, rel + num_rels)) edges = sorted(zip(dst, src, rel)) - dst, src, rel = np.array(edges).transpose() + dst, src, rel = np.array(edges).transpose( ) g.add_edges(src, dst) norm = comp_deg_norm(g) print("# nodes: {}, # edges: {}".format(num_nodes, len(src))) @@ -475,18 +487,408 @@ def build_graph_from_triplets(num_nodes, num_rels, triplets): def comp_deg_norm(g): - g = g.local_var() - in_deg = g.in_degrees(range(g.number_of_nodes())).float().numpy() + g = g.local_var( ) + in_deg = g.in_degrees(range(g.number_of_nodes( ))).float( ).numpy( ) norm = 1.0 / in_deg norm[np.isinf(norm)] = 0 return norm +@register_dataset('kg_sub_link_prediction') +class KG_RedDataset(LinkPredictionDataset): + def __init__(self, dataset_name, *args, **kwargs): + super(KG_RedDataset, self).__init__(*args, **kwargs) + self.trans_dir = os.path.join('openhgnn/dataset/data', dataset_name) + self.ind_dir = self.trans_dir + '_ind' + + folder = os.path.exists(self.trans_dir) + if not folder: + os.makedirs(self.trans_dir) + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/fb237_v1.zip" + response = requests.get(url) + with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: + myzip.extractall(self.trans_dir) + print("--- download data ---") + + else: + print("--- There is data! ---") + + folder = os.path.exists(self.ind_dir) + if not folder: + os.makedirs(self.ind_dir) + # 下载数据 + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/fb237_v1_ind.zip" + response = requests.get(url) + with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: + myzip.extractall(self.ind_dir) + print("--- download data ---") + + else: + print("--- There is data! 
---") + + with open(os.path.join(self.trans_dir, 'entities.txt')) as f: + self.entity2id = dict() + for line in f: + entity, eid = line.strip().split() + self.entity2id[entity] = int(eid) + + with open(os.path.join(self.trans_dir, 'relations.txt')) as f: + self.relation2id = dict() + id2relation = [] + for line in f: + relation, rid = line.strip().split() + self.relation2id[relation] = int(rid) + id2relation.append(relation) + + with open(os.path.join(self.ind_dir, 'entities.txt')) as f: + self.entity2id_ind = dict() + for line in f: + entity, eid = line.strip().split() + self.entity2id_ind[entity] = int(eid) + + for i in range(len(self.relation2id)): + id2relation.append(id2relation[i] + '_inv') + id2relation.append('idd') + self.id2relation = id2relation + + self.n_ent = len(self.entity2id) + self.n_rel = len(self.relation2id) + self.n_ent_ind = len(self.entity2id_ind) + + self.tra_train = self.read_triples(self.trans_dir, 'train.txt') + self.tra_valid = self.read_triples(self.trans_dir, 'valid.txt') + self.tra_test = self.read_triples(self.trans_dir, 'test.txt') + self.ind_train = self.read_triples(self.ind_dir, 'train.txt', 'inductive') + self.ind_valid = self.read_triples(self.ind_dir, 'valid.txt', 'inductive') + self.ind_test = self.read_triples(self.ind_dir, 'test.txt', 'inductive') + + self.val_filters = self.get_filter('valid') + self.tst_filters = self.get_filter('test') + + for filt in self.val_filters: + self.val_filters[filt] = list(self.val_filters[filt]) + for filt in self.tst_filters: + self.tst_filters[filt] = list(self.tst_filters[filt]) + + self.tra_KG, self.tra_sub = self.load_graph(self.tra_train) + self.ind_KG, self.ind_sub = self.load_graph(self.ind_train, 'inductive') + + self.tra_train = np.array(self.tra_valid) + self.tra_val_qry, self.tra_val_ans = self.load_query(self.tra_test) + self.ind_val_qry, self.ind_val_ans = self.load_query(self.ind_valid) + self.ind_tst_qry, self.ind_tst_ans = self.load_query(self.ind_test) + self.valid_q, self.valid_a = self.tra_val_qry, self.tra_val_ans + self.test_q, self.test_a = self.ind_val_qry + self.ind_tst_qry, self.ind_val_ans + self.ind_tst_ans + + self.n_train = len(self.tra_train) + self.n_valid = len(self.valid_q) + self.n_test = len(self.test_q) + + + def read_triples(self, directory, filename, mode='transductive'): + triples = [] + with open(os.path.join(directory, filename)) as f: + for line in f: + h, r, t = line.strip().split() + if mode == 'transductive': + h, r, t = self.entity2id[h], self.relation2id[r], self.entity2id[t] + else: + h, r, t = self.entity2id_ind[h], self.relation2id[r], self.entity2id_ind[t] + + triples.append([h, r, t]) + triples.append([t, r + self.n_rel, h]) + return triples + + def load_graph(self, triples, mode='transductive'): + n_ent = self.n_ent if mode == 'transductive' else self.n_ent_ind + + KG = np.array(triples) + idd = np.concatenate([np.expand_dims(np.arange(n_ent), 1), 2 * self.n_rel * np.ones((n_ent, 1)), + np.expand_dims(np.arange(n_ent), 1)], 1) + KG = np.concatenate([KG, idd], 0) + + n_fact = KG.shape[0] + + M_sub = csr_matrix((np.ones((n_fact,)), (np.arange(n_fact), KG[:, 0])), shape=(n_fact, n_ent)) + return KG, M_sub + + def load_query(self, triples): + triples.sort(key=lambda x: (x[0], x[1])) + trip_hr = defaultdict(lambda: list()) + + for trip in triples: + h, r, t = trip + trip_hr[(h, r)].append(t) + + queries = [] + answers = [] + for key in trip_hr: + queries.append(key) + answers.append(np.array(trip_hr[key])) + return queries, answers + + def get_neighbors(self, nodes, 
mode='transductive'): + # nodes: n_node x 2 with (batch_idx, node_idx) + + if mode == 'transductive': + KG = self.tra_KG + M_sub = self.tra_sub + n_ent = self.n_ent + else: + KG = self.ind_KG + M_sub = self.ind_sub + n_ent = self.n_ent_ind + + node_1hot = csr_matrix((np.ones(len(nodes)), (nodes[:, 1], nodes[:, 0])), shape=(n_ent, nodes.shape[0])) + edge_1hot = M_sub.dot(node_1hot) + edges = np.nonzero(edge_1hot) + sampled_edges = np.concatenate([np.expand_dims(edges[1], 1), KG[edges[0]]], + axis=1) # (batch_idx, head, rela, tail) + sampled_edges = th.LongTensor(sampled_edges) + # index to nodes + head_nodes, head_index = th.unique(sampled_edges[:, [0, 1]], dim=0, sorted=True, return_inverse=True) + tail_nodes, tail_index = th.unique(sampled_edges[:, [0, 3]], dim=0, sorted=True, return_inverse=True) + + mask = sampled_edges[:, 2] == (self.n_rel * 2) + _, old_idx = head_index[mask].sort() + old_nodes_new_idx = tail_index[mask][old_idx] + + sampled_edges = th.cat([sampled_edges, head_index.unsqueeze(1), tail_index.unsqueeze(1)], 1) + + return tail_nodes, sampled_edges, old_nodes_new_idx + + def get_batch(self, batch_idx, steps=2, data='train'): + if data == 'train': + return self.tra_train[batch_idx] + if data == 'valid': + # print(self.) + query, answer = np.array(self.valid_q), self.valid_a # np.array(self.valid_a) + n_ent = self.n_ent + if data == 'test': + query, answer = np.array(self.test_q), self.test_a # np.array(self.test_a) + n_ent = self.n_ent_ind + + subs = [] + rels = [] + objs = [] + + subs = query[batch_idx, 0] + rels = query[batch_idx, 1] + objs = np.zeros((len(batch_idx), n_ent)) + for i in range(len(batch_idx)): + objs[i][answer[batch_idx[i]]] = 1 + return subs, rels, objs + + def shuffle_train(self, ): + rand_idx = np.random.permutation(self.n_train) + self.tra_train = self.tra_train[rand_idx] + + def get_filter(self, data='valid'): + filters = defaultdict(lambda: set()) + if data == 'valid': + for triple in self.tra_train: + h, r, t = triple + filters[(h, r)].add(t) + for triple in self.tra_valid: + h, r, t = triple + filters[(h, r)].add(t) + for triple in self.tra_test: + h, r, t = triple + filters[(h, r)].add(t) + else: + for triple in self.ind_train: + h, r, t = triple + filters[(h, r)].add(t) + for triple in self.ind_valid: + h, r, t = triple + filters[(h, r)].add(t) + for triple in self.ind_test: + h, r, t = triple + filters[(h, r)].add(t) + return filters + + +@register_dataset('kg_subT_link_prediction') +class KG_RedTDataset(LinkPredictionDataset): + def __init__(self, dataset_name, *args, **kwargs): + super(KG_RedTDataset, self).__init__(*args, **kwargs) + self.task_dir = os.path.join('openhgnn/dataset/data', dataset_name) + task_dir = self.task_dir + folder = os.path.exists(self.task_dir) + if not folder: + os.makedirs(self.task_dir) + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/family.zip" + response = requests.get(url) + with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: + myzip.extractall(self.task_dir) + print("--- download data ---") + + else: + print("--- There is data! 
---") + + with open(os.path.join(task_dir, 'entities.txt')) as f: + self.entity2id = dict() + n_ent = 0 + for line in f: + entity = line.strip() + self.entity2id[entity] = n_ent + n_ent += 1 + + with open(os.path.join(task_dir, 'relations.txt')) as f: + self.relation2id = dict() + n_rel = 0 + for line in f: + relation = line.strip() + self.relation2id[relation] = n_rel + n_rel += 1 + + self.n_ent = n_ent + self.n_rel = n_rel + + self.filters = defaultdict(lambda: set()) + + self.fact_triple = self.read_triples('facts.txt') + self.train_triple = self.read_triples('train.txt') + self.valid_triple = self.read_triples('valid.txt') + self.test_triple = self.read_triples('test.txt') + + self.fact_data = self.double_triple(self.fact_triple) + self.train_data = np.array(self.double_triple(self.train_triple)) + self.valid_data = self.double_triple(self.valid_triple) + self.test_data = self.double_triple(self.test_triple) + + self.load_graph(self.fact_data) + self.load_test_graph(self.double_triple(self.fact_triple) + self.double_triple(self.train_triple)) + + self.valid_q, self.valid_a = self.load_query(self.valid_data) + self.test_q, self.test_a = self.load_query(self.test_data) + + self.n_train = len(self.train_data) + self.n_valid = len(self.valid_q) + self.n_test = len(self.test_q) + + for filt in self.filters: + self.filters[filt] = list(self.filters[filt]) + + print('n_train:', self.n_train, 'n_valid:', self.n_valid, 'n_test:', self.n_test) + + def read_triples(self, filename): + triples = [] + with open(os.path.join(self.task_dir, filename)) as f: + for line in f: + h, r, t = line.strip().split() + h, r, t = self.entity2id[h], self.relation2id[r], self.entity2id[t] + triples.append([h, r, t]) + self.filters[(h, r)].add(t) + self.filters[(t, r + self.n_rel)].add(h) + return triples + + def double_triple(self, triples): + new_triples = [] + for triple in triples: + h, r, t = triple + new_triples.append([t, r + self.n_rel, h]) + return triples + new_triples + + def load_graph(self, triples): + idd = np.concatenate([np.expand_dims(np.arange(self.n_ent), 1), 2 * self.n_rel * np.ones((self.n_ent, 1)), + np.expand_dims(np.arange(self.n_ent), 1)], 1) + + self.KG = np.concatenate([np.array(triples), idd], 0) + self.n_fact = len(self.KG) + self.M_sub = csr_matrix((np.ones((self.n_fact,)), (np.arange(self.n_fact), self.KG[:, 0])), + shape=(self.n_fact, self.n_ent)) + + def load_test_graph(self, triples): + idd = np.concatenate([np.expand_dims(np.arange(self.n_ent), 1), 2 * self.n_rel * np.ones((self.n_ent, 1)), + np.expand_dims(np.arange(self.n_ent), 1)], 1) + + self.tKG = np.concatenate([np.array(triples), idd], 0) + self.tn_fact = len(self.tKG) + self.tM_sub = csr_matrix((np.ones((self.tn_fact,)), (np.arange(self.tn_fact), self.tKG[:, 0])), + shape=(self.tn_fact, self.n_ent)) + + def load_query(self, triples): + triples.sort(key=lambda x: (x[0], x[1])) + trip_hr = defaultdict(lambda: list()) + + for trip in triples: + h, r, t = trip + trip_hr[(h, r)].append(t) + + queries = [] + answers = [] + for key in trip_hr: + queries.append(key) + answers.append(np.array(trip_hr[key])) + return queries, answers + + def get_neighbors(self, nodes, mode='train'): + if mode == 'train': + KG = self.KG + M_sub = self.M_sub + else: + KG = self.tKG + M_sub = self.tM_sub + + # nodes: n_node x 2 with (batch_idx, node_idx) + node_1hot = csr_matrix((np.ones(len(nodes)), (nodes[:, 1], nodes[:, 0])), shape=(self.n_ent, nodes.shape[0])) + edge_1hot = M_sub.dot(node_1hot) + edges = np.nonzero(edge_1hot) + sampled_edges = 
np.concatenate([np.expand_dims(edges[1], 1), KG[edges[0]]], + axis=1) # (batch_idx, head, rela, tail) + sampled_edges = torch.LongTensor(sampled_edges).cuda() + + # index to nodes + head_nodes, head_index = torch.unique(sampled_edges[:, [0, 1]], dim=0, sorted=True, return_inverse=True) + tail_nodes, tail_index = torch.unique(sampled_edges[:, [0, 3]], dim=0, sorted=True, return_inverse=True) + + sampled_edges = torch.cat([sampled_edges, head_index.unsqueeze(1), tail_index.unsqueeze(1)], 1) + + mask = sampled_edges[:, 2] == (self.n_rel * 2) + _, old_idx = head_index[mask].sort() + old_nodes_new_idx = tail_index[mask][old_idx] + + return tail_nodes, sampled_edges, old_nodes_new_idx + + def get_batch(self, batch_idx, steps=2, data='train'): + if data == 'train': + return np.array(self.train_data)[batch_idx] + if data == 'valid': + query, answer = np.array(self.valid_q), self.valid_a + if data == 'test': + query, answer = np.array(self.test_q), self.test_a + + subs = [] + rels = [] + objs = [] + + subs = query[batch_idx, 0] + rels = query[batch_idx, 1] + objs = np.zeros((len(batch_idx), self.n_ent)) + for i in range(len(batch_idx)): + objs[i][answer[batch_idx[i]]] = 1 + return subs, rels, objs + + def shuffle_train(self, ): + fact_triple = np.array(self.fact_triple) + train_triple = np.array(self.train_triple) + all_triple = np.concatenate([fact_triple, train_triple], axis=0) + n_all = len(all_triple) + rand_idx = np.random.permutation(n_all) + all_triple = all_triple[rand_idx] + + # increase the ratio of fact_data, e.g., 3/4->4/5, can increase the performance + self.fact_data = self.double_triple(all_triple[:n_all * 3 // 4].tolist()) + self.train_data = np.array(self.double_triple(all_triple[n_all * 3 // 4:].tolist())) + self.n_train = len(self.train_data) + self.load_graph(self.fact_data) @register_dataset('kg_link_prediction') class KG_LinkPrediction(LinkPredictionDataset): """ From `RGCN `_, WN18 & FB15k face a data leakage. 
""" + def __init__(self, dataset_name, *args, **kwargs): super(KG_LinkPrediction, self).__init__(*args, **kwargs) if dataset_name in ['wn18', 'FB15k', 'FB15k-237']: @@ -504,8 +906,8 @@ def __init__(self, dataset_name, *args, **kwargs): self.target_link = self.test_hg.canonical_etypes def _build_hg(self, g, mode): - sub_g = dgl.edge_subgraph(g, g.edata[mode+'_edge_mask'], relabel_nodes=False) - src, dst = sub_g.edges() + sub_g = dgl.edge_subgraph(g, g.edata[mode + '_edge_mask'], relabel_nodes=False) + src, dst = sub_g.edges( ) etype = sub_g.edata['etype'] edge_dict = {} @@ -519,9 +921,11 @@ def _build_hg(self, g, mode): def modify_size(self, eval_percent, dataset_type): if dataset_type == 'valid': - self.valid_triplets = th.tensor(random.sample(self.valid_triplets.tolist(), math.ceil(self.valid_triplets.shape[0]*eval_percent))) + self.valid_triplets = th.tensor( + random.sample(self.valid_triplets.tolist( ), math.ceil(self.valid_triplets.shape[0] * eval_percent))) elif dataset_type == 'test': - self.test_triplets = th.tensor(random.sample(self.test_triplets.tolist(), math.ceil(self.test_triplets.shape[0]*eval_percent))) + self.test_triplets = th.tensor( + random.sample(self.test_triplets.tolist( ), math.ceil(self.test_triplets.shape[0] * eval_percent))) def get_graph_directed_from_triples(self, triples, format='graph'): s = th.LongTensor(triples[:, 0]) @@ -541,7 +945,7 @@ def get_triples(self, g, mask_mode): :param mask_mode: should be one of 'train_mask', 'val_mask', 'test_mask :return: ''' - edges = g.edges() + edges = g.edges( ) etype = g.edata['etype'] mask = g.edata.pop(mask_mode) return th.stack((edges[0][mask], etype[mask], edges[1][mask])) @@ -569,7 +973,7 @@ def split_graph(self, g, mode='train'): ------- hg: DGLHeterograph """ - edges = g.edges() + edges = g.edges( ) etype = g.edata['etype'] if mode == 'train': mask = g.edata['train_mask'] @@ -601,8 +1005,260 @@ def build_g(self, train): hg = dgl.heterograph(edge_dict, {self.category: self.num_nodes}) return hg +import torch +import struct +import os +import json +import logging +from scipy.sparse import csc_matrix +from scipy.special import softmax +from tqdm import tqdm +import pickle +import scipy.sparse as ssp +import lmdb +import requests +import zipfile +import io +from torch.utils.data import Dataset +import networkx as nx +from ..utils.Grail_utils import * +class SubGraphDataset(Dataset): + def __init__(self, db_path, db_name_pos, db_name_neg, raw_data_paths, included_relations=None, + add_traspose_rels=False, num_neg_samples_per_link=1, use_kge_embeddings=False, dataset='', + kge_model='', file_name=''): + + self.main_env = lmdb.open(db_path, readonly= True, max_dbs=3, lock=False) + self.db_pos = self.main_env.open_db(db_name_pos.encode()) + self.db_neg = self.main_env.open_db(db_name_neg.encode()) + self.node_features, self.kge_entity2id = get_kge_embeddings(dataset, kge_model) if use_kge_embeddings else (None, None) + self.num_neg_samples_per_link = num_neg_samples_per_link + self.file_name = file_name + self.add_traspose_rels = add_traspose_rels + + ssp_graph, __, __, __, id2entity, id2relation = process_files(raw_data_paths, included_relations) + self.num_rels = len(ssp_graph) + + # Add transpose matrices to handle both directions of relations. 
+ if add_traspose_rels: + ssp_graph_t = [adj.T for adj in ssp_graph] + ssp_graph += ssp_graph_t + + # the effective number of relations after adding symmetric adjacency matrices and/or self connections + self.aug_num_rels = len(ssp_graph) + self.graph = ssp_multigraph_to_dgl(ssp_graph) + self.ssp_graph = ssp_graph + self.id2entity = id2entity + self.id2relation = id2relation + + self.max_n_label = np.array([0, 0]) + with self.main_env.begin() as txn: + #a = txn.get('max_n_label_sub'.encode()) + #print(a) + self.max_n_label[0] = int.from_bytes(txn.get('max_n_label_sub'.encode()), byteorder='little') + self.max_n_label[1] = int.from_bytes(txn.get('max_n_label_obj'.encode()), byteorder='little') + + self.avg_subgraph_size = struct.unpack('f', txn.get('avg_subgraph_size'.encode())) + self.min_subgraph_size = struct.unpack('f', txn.get('min_subgraph_size'.encode())) + self.max_subgraph_size = struct.unpack('f', txn.get('max_subgraph_size'.encode())) + self.std_subgraph_size = struct.unpack('f', txn.get('std_subgraph_size'.encode())) + + self.avg_enc_ratio = struct.unpack('f', txn.get('avg_enc_ratio'.encode())) + self.min_enc_ratio = struct.unpack('f', txn.get('min_enc_ratio'.encode())) + self.max_enc_ratio = struct.unpack('f', txn.get('max_enc_ratio'.encode())) + self.std_enc_ratio = struct.unpack('f', txn.get('std_enc_ratio'.encode())) + + self.avg_num_pruned_nodes = struct.unpack('f', txn.get('avg_num_pruned_nodes'.encode())) + self.min_num_pruned_nodes = struct.unpack('f', txn.get('min_num_pruned_nodes'.encode())) + self.max_num_pruned_nodes = struct.unpack('f', txn.get('max_num_pruned_nodes'.encode())) + self.std_num_pruned_nodes = struct.unpack('f', txn.get('std_num_pruned_nodes'.encode())) + + logging.info(f"Max distance from sub : {self.max_n_label[0]}, Max distance from obj : {self.max_n_label[1]}") + + # logging.info('=====================') + # logging.info(f"Subgraph size stats: \n Avg size {self.avg_subgraph_size}, \n Min size {self.min_subgraph_size}, \n Max size {self.max_subgraph_size}, \n Std {self.std_subgraph_size}") + + # logging.info('=====================') + # logging.info(f"Enclosed nodes ratio stats: \n Avg size {self.avg_enc_ratio}, \n Min size {self.min_enc_ratio}, \n Max size {self.max_enc_ratio}, \n Std {self.std_enc_ratio}") + + # logging.info('=====================') + # logging.info(f"# of pruned nodes stats: \n Avg size {self.avg_num_pruned_nodes}, \n Min size {self.min_num_pruned_nodes}, \n Max size {self.max_num_pruned_nodes}, \n Std {self.std_num_pruned_nodes}") + + with self.main_env.begin(db=self.db_pos) as txn: + self.num_graphs_pos = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little') + with self.main_env.begin(db=self.db_neg) as txn: + self.num_graphs_neg = int.from_bytes(txn.get('num_graphs'.encode()), byteorder='little') + + self.__getitem__(0) + + def __getitem__(self, index): + with self.main_env.begin(db=self.db_pos) as txn: + str_id = '{:08}'.format(index).encode('ascii') + nodes_pos, r_label_pos, g_label_pos, n_labels_pos = deserialize(txn.get(str_id)).values() + subgraph_pos = self._prepare_subgraphs(nodes_pos, r_label_pos, n_labels_pos) + subgraphs_neg = [] + r_labels_neg = [] + g_labels_neg = [] + with self.main_env.begin(db=self.db_neg) as txn: + for i in range(self.num_neg_samples_per_link): + str_id = '{:08}'.format(index + i * (self.num_graphs_pos)).encode('ascii') + nodes_neg, r_label_neg, g_label_neg, n_labels_neg = deserialize(txn.get(str_id)).values() + subgraphs_neg.append(self._prepare_subgraphs(nodes_neg, r_label_neg, 
n_labels_neg)) + r_labels_neg.append(r_label_neg) + g_labels_neg.append(g_label_neg) + + return subgraph_pos, g_label_pos, r_label_pos, subgraphs_neg, g_labels_neg, r_labels_neg + + def __len__(self): + return self.num_graphs_pos + + def _prepare_subgraphs(self, nodes, r_label, n_labels): + if not isinstance(self.graph, dgl.DGLGraph): + subgraph = dgl.graph(self.graph.subgraph(nodes)) + else: + subgraph = self.graph.subgraph(nodes) + #subgraph.edata['type'] = self.graph.edata['type'][self.graph.subgraph(nodes).parent_eid] + subgraph.edata['type'] = self.graph.edata['type'][subgraph.edata[dgl.EID]] + subgraph.edata['label'] = torch.tensor(r_label * np.ones(subgraph.edata['type'].shape), dtype=torch.long) + #print("请输出: ") + #print(subgraph) + #edges_btw_roots = subgraph.edge_id(0, 1, return_array=True) + #edges_btw_roots = subgraph.edge_ids(0, 1) + edges_btw_roots = torch.tensor([]) + try: + edges_btw_roots = subgraph.edge_ids(torch.tensor([0]),torch.tensor([1])) + # edges_btw_roots = np.array([edges_btw_roots]) + except: + #print("Error") + edges_btw_roots = torch.tensor([]) + edges_btw_roots = edges_btw_roots.numpy() + rel_link = np.nonzero(subgraph.edata['type'][edges_btw_roots] == r_label) + if rel_link.squeeze().nelement() == 0: + subgraph = dgl.add_edges(subgraph, 0, 1) + subgraph.edata['type'][-1] = torch.tensor(r_label).type(torch.LongTensor) + subgraph.edata['label'][-1] = torch.tensor(r_label).type(torch.LongTensor) + + + + # map the id read by GraIL to the entity IDs as registered by the KGE embeddings + kge_nodes = [self.kge_entity2id[self.id2entity[n]] for n in nodes] if self.kge_entity2id else None + n_feats = self.node_features[kge_nodes] if self.node_features is not None else None + subgraph = self._prepare_features_new(subgraph, n_labels, n_feats) + + return subgraph + + def _prepare_features(self, subgraph, n_labels, n_feats=None): + # One hot encode the node label feature and concat to n_featsure + n_nodes = subgraph.number_of_nodes() + label_feats = np.zeros((n_nodes, self.max_n_label[0] + 1)) + label_feats[np.arange(n_nodes), n_labels] = 1 + label_feats[np.arange(n_nodes), self.max_n_label[0] + 1 + n_labels[:, 1]] = 1 + n_feats = np.concatenate((label_feats, n_feats), axis=1) if n_feats else label_feats + subgraph.ndata['feat'] = torch.FloatTensor(n_feats) + self.n_feat_dim = n_feats.shape[1] # Find cleaner way to do this -- i.e. set the n_feat_dim + return subgraph + + def _prepare_features_new(self, subgraph, n_labels, n_feats=None): + # One hot encode the node label feature and concat to n_featsure + n_nodes = subgraph.number_of_nodes() + label_feats = np.zeros((n_nodes, self.max_n_label[0] + 1 + self.max_n_label[1] + 1)) + label_feats[np.arange(n_nodes), n_labels[:, 0]] = 1 + label_feats[np.arange(n_nodes), self.max_n_label[0] + 1 + n_labels[:, 1]] = 1 + # label_feats = np.zeros((n_nodes, self.max_n_label[0] + 1 + self.max_n_label[1] + 1)) + # label_feats[np.arange(n_nodes), 0] = 1 + # label_feats[np.arange(n_nodes), self.max_n_label[0] + 1] = 1 + n_feats = np.concatenate((label_feats, n_feats), axis=1) if n_feats is not None else label_feats + subgraph.ndata['feat'] = torch.FloatTensor(n_feats) + + head_id = np.argwhere([label[0] == 0 and label[1] == 1 for label in n_labels]) + tail_id = np.argwhere([label[0] == 1 and label[1] == 0 for label in n_labels]) + n_ids = np.zeros(n_nodes) + n_ids[head_id] = 1 # head + n_ids[tail_id] = 2 # tail + subgraph.ndata['id'] = torch.FloatTensor(n_ids) + + self.n_feat_dim = n_feats.shape[1] # Find cleaner way to do this -- i.e. 
set the n_feat_dim + return subgraph + + +@register_dataset('grail_link_prediction') +class Grail_LinkPrediction(LinkPredictionDataset): + def __init__(self, dataset_name, *args, **kwargs): + super(Grail_LinkPrediction, self).__init__(*args, **kwargs) + + self.args = kwargs['args'] + self.args.db_path = f'./openhgnn/dataset/data/{self.args.dataset}/subgraphs_en_{self.args.enclosing_sub_graph}_neg_{self.args.num_neg_samples_per_link}_hop_{self.args.hop}' + + self.args.train_file = "train" + self.args.valid_file = "valid" + self.args.file_paths = { + 'train': './openhgnn/dataset/data/{}/{}.txt'.format(self.args.dataset, self.args.train_file), + 'valid': './openhgnn/dataset/data/{}/{}.txt'.format(self.args.dataset, self.args.valid_file) + } + + relation2id_path = f'./openhgnn/dataset/data/{self.args.dataset}/relation2id.json' + + self.data_folder = f'./openhgnn/dataset/data/{self.args.dataset}' + if not os.path.exists(self.data_folder): + os.makedirs(self.data_folder) # makedirs 创建文件时如果路径不存在会创建这个路径 + url = f'https://github.com/kkteru/grail/blob/master/data/{self.args.dataset}' + self.download_folder(url,self.data_folder) + print("--- download data ---") + + else: + print("--- There is data! ---") + + if not os.path.exists(self.data_folder+'_ind'): + os.makedirs(self.data_folder+'_ind') # makedirs 创建文件时如果路径不存在会创建这个路径 + url = f'https://github.com/kkteru/grail/blob/master/data/{self.args.dataset}_ind' + self.download_folder(url,self.data_folder+'_ind') + print("--- download data ---") + + else: + print("--- There is data! ---") + + if not os.path.isdir(self.args.db_path): + generate_subgraph_datasets(self.args, relation2id_path) + + + with open(relation2id_path) as f: + self.relation2id = json.load(f) + self.train = SubGraphDataset(self.args.db_path, 'train_pos', 'train_neg', self.args.file_paths,add_traspose_rels=self.args.add_traspose_rels,num_neg_samples_per_link=self.args.num_neg_samples_per_link,use_kge_embeddings=self.args.use_kge_embeddings, dataset=self.args.dataset,kge_model=self.args.kge_model, file_name=self.args.train_file) + self.valid = SubGraphDataset(self.args.db_path, 'valid_pos', 'valid_neg', self.args.file_paths, + add_traspose_rels=self.args.add_traspose_rels, + num_neg_samples_per_link=self.args.num_neg_samples_per_link, + use_kge_embeddings=self.args.use_kge_embeddings, dataset=self.args.dataset, + kge_model=self.args.kge_model, file_name= self.args.valid_file) + + def download_folder(self,url, save_path): + response = requests.get(url) + if response.status_code == 200: + # 确保保存路径存在 + os.makedirs(save_path, exist_ok=True) -class kg_sampler(): + # 解析响应内容 + content = response.content.decode('utf-8') + lines = content.splitlines() + + for line in lines: + # 提取文件名 + file_name = line.split('/')[-1] + + # 构建文件的完整URL + file_url = url + '/' + file_name + + # 构建文件的保存路径 + file_save_path = os.path.join(save_path, file_name) + + # 下载文件 + self.download_file(file_url, file_save_path) + + def download_file(self,url, save_path): + response = requests.get(url) + if response.status_code == 200: + with open(save_path, 'wb') as file: + file.write(response.content) + +class kg_sampler( ): def __init__(self, ): self.sampler = 'uniform' return @@ -624,10 +1280,10 @@ def generate_sampled_graph_and_labels(self, triplets, sample_size, split_size, # relabel nodes to have consecutive node ids edges = triplets[edges] - src, rel, dst = edges.transpose() + src, rel, dst = edges.transpose( ) uniq_v, edges = np.unique((src, dst), return_inverse=True) src, dst = np.reshape(edges, (2, -1)) - relabeled_edges 
= np.stack((src, rel, dst)).transpose() + relabeled_edges = np.stack((src, rel, dst)).transpose( ) # negative sampling samples, labels = negative_sampling(relabeled_edges, len(uniq_v), @@ -698,3 +1354,1255 @@ def sample_edge_uniform(adj_list, degrees, n_triplets, sample_size): """Sample edges uniformly from all the edges.""" all_edges = np.arange(n_triplets) return np.random.choice(all_edges, sample_size, replace=False) + + +# --- ExpressGNN --- + + +# grounded rule stats code +BAD = 0 # sample not valid +FULL_OBSERVERED = 1 # sample valid, but rule contains only observed vars and does not have negation for all atoms +GOOD = 2 # sample valid + + +@register_dataset('express_gnn') +class ExpressGNNDataset(BaseDataset): + def __init__(self, dataset_name, *args, **kwargs): + super( ).__init__(*args, **kwargs) + self.args = kwargs['args'] + self.PRED_DICT = {} + self.dataset_name = dataset_name + self.const_dict = ConstantDict() + self.batchsize = self.args.batchsize + self.shuffle_sampling = self.args.shuffle_sampling + data_root = 'openhgnn' + data_root = os.path.join(data_root, 'dataset') + data_root = os.path.join(data_root, 'data') + data_root = os.path.join(data_root, self.dataset_name) + ext_rule_path = None + folder = os.path.exists(data_root) + print(data_root) + print('folder', folder) + if not folder: # 判断是否存在文件夹如果不存在则创建为文件夹 + os.makedirs(data_root) # makedirs 创建文件时如果路径不存在会创建这个路径 + # 下载数据 + url = f"https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/{dataset_name}.zip" + response = requests.get(url) + with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: + myzip.extractall(data_root) + print("--- download data ---") + + else: + print("--- There is data! ---") + + # Decide the way dataset will be load, set 1 to load FBWN dataset + load_method = 0 + # print(dataset_name[0:13]) + if dataset_name[0:13] == 'EXP_FB15k-237': + load_method = 1 + else: + load_method = 0 + guss_fb = 'EXP_FB15k' in data_root + if guss_fb != (load_method == 1): + print("WARNING: set load_method to 1 if you load Freebase dataset, otherwise 0") + + # FBWN dataset + if load_method == 1: + fact_path_ls = [joinpath(data_root, 'facts.txt'), + joinpath(data_root, 'train.txt')] + query_path = joinpath(data_root, 'test.txt') + pred_path = joinpath(data_root, 'relations.txt') + const_path = joinpath(data_root, 'entities.txt') + valid_path = joinpath(data_root, 'valid.txt') + + rule_path = joinpath(data_root, 'cleaned_rules_weight_larger_than_0.9.txt') + print(rule_path) + print(os.getcwd()) + # print(fact_path_ls + [query_path, pred_path, const_path, valid_path, rule_path]) + # assert all(map(isfile, fact_path_ls + [query_path, pred_path, const_path, valid_path, rule_path])) + + # assuming only one type + TYPE_SET.update(['type']) + + # add all const + for line in iterline(const_path): + self.const_dict.add_const('type', line) + + # add all pred + for line in iterline(pred_path): + self.PRED_DICT[line] = Predicate(line, ['type', 'type']) + + # add all facts + fact_ls = [] + for fact_path in fact_path_ls: + for line in iterline(fact_path): + parts = line.split('\t') + + assert len(parts) == 3, print(parts) + + e1, pred_name, e2 = parts + + assert self.const_dict.has_const('type', e1) and self.const_dict.has_const('type', e2) + assert pred_name in self.PRED_DICT + + fact_ls.append(Fact(pred_name, [e1, e2], 1)) + + # add all validations + valid_ls = [] + for line in iterline(valid_path): + parts = line.split('\t') + + assert len(parts) == 3, print(parts) + + e1, pred_name, e2 = parts + + assert 
self.const_dict.has_const('type', e1) and self.const_dict.has_const('type', e2) + assert pred_name in self.PRED_DICT + + valid_ls.append(Fact(pred_name, [e1, e2], 1)) + + # add all queries + query_ls = [] + for line in iterline(query_path): + parts = line.split('\t') + + assert len(parts) == 3, print(parts) + + e1, pred_name, e2 = parts + + assert self.const_dict.has_const('type', e1) and self.const_dict.has_const('type', e2) + assert pred_name in self.PRED_DICT + + query_ls.append(Fact(pred_name, [e1, e2], 1)) + + # add all rules + rule_ls = [] + strip_items = lambda ls: list(map(lambda x: x.strip( ), ls)) + first_atom_reg = re.compile(r'([\d.]+) (!?)([^(]+)\((.*)\)') + atom_reg = re.compile(r'(!?)([^(]+)\((.*)\)') + for line in iterline(rule_path): + + atom_str_ls = strip_items(line.split(' v ')) + assert len(atom_str_ls) > 1, 'rule length must be greater than 1, but get %s' % line + + atom_ls = [] + rule_weight = 0.0 + for i, atom_str in enumerate(atom_str_ls): + if i == 0: + m = first_atom_reg.match(atom_str) + assert m is not None, 'matching atom failed for %s' % atom_str + rule_weight = float(m.group(1)) + neg = m.group(2) == '!' + pred_name = m.group(3).strip( ) + var_name_ls = strip_items(m.group(4).split(',')) + else: + m = atom_reg.match(atom_str) + assert m is not None, 'matching atom failed for %s' % atom_str + neg = m.group(1) == '!' + pred_name = m.group(2).strip( ) + var_name_ls = strip_items(m.group(3).split(',')) + + atom = Atom(neg, pred_name, var_name_ls, self.PRED_DICT[pred_name].var_types) + atom_ls.append(atom) + + rule = Formula(atom_ls, rule_weight) + rule_ls.append(rule) + else: + if dataset_name == 'Cora' or dataset_name == 'kinship': + data_root = joinpath(data_root, 'S' + str(self.args.load_s)) + elif dataset_name == 'uw_cse': + if self.args.load_s == 1: + data_root = joinpath(data_root, 'ai') + elif self.args.load_s == 2: + data_root = joinpath(data_root, 'graphics') + elif self.args.load_s == 3: + data_root = joinpath(data_root, 'language') + elif self.args.load_s == 4: + data_root = joinpath(data_root, 'systems') + elif self.args.load_s == 5: + data_root = joinpath(data_root, 'theory') + else: + print('Warning: Invalid load_s') + else: + print('Warning: Invalid dataset for load_method = 0') + rpath = joinpath(data_root, 'rules') if ext_rule_path is None else ext_rule_path + fact_ls, rule_ls, query_ls = self.preprocess_kinship(joinpath(data_root, 'predicates'), + joinpath(data_root, 'facts'), + rpath, + joinpath(data_root, 'queries')) + valid_ls = [] + + self.const_sort_dict = dict( + [(type_name, sorted(list(self.const_dict[type_name]))) for type_name in self.const_dict.constants.keys( )]) + + if load_method == 1: + self.const2ind = dict([(const, i) for i, const in enumerate(self.const_sort_dict['type'])]) + + # linear in size of facts + self.fact_dict = dict((pred_name, set( )) for pred_name in self.PRED_DICT) + self.test_fact_dict = dict((pred_name, set( )) for pred_name in self.PRED_DICT) + self.valid_dict = dict((pred_name, set( )) for pred_name in self.PRED_DICT) + + self.ht_dict = dict((pred_name, [dict( ), dict( )]) for pred_name in self.PRED_DICT) + self.ht_dict_train = dict((pred_name, [dict( ), dict( )]) for pred_name in self.PRED_DICT) + + def add_ht(pn, c_ls, ht_dict): + if load_method == 0: + if c_ls[0] in ht_dict[pn][0]: + ht_dict[pn][0][c_ls[0]].add(c_ls[0]) + else: + ht_dict[pn][0][c_ls[0]] = {c_ls[0]} + elif load_method == 1: + if c_ls[0] in ht_dict[pn][0]: + ht_dict[pn][0][c_ls[0]].add(c_ls[1]) + else: + ht_dict[pn][0][c_ls[0]] = {c_ls[1]} + 
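+            # also index the fact by its tail entity (ht_dict[pred][1] maps
+            # tail -> {heads}), so a grounding can be completed from either
+            # endpoint of a binary relation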
+ if c_ls[1] in ht_dict[pn][1]: + ht_dict[pn][1][c_ls[1]].add(c_ls[0]) + else: + ht_dict[pn][1][c_ls[1]] = {c_ls[0]} + + const_cnter = Counter() + for fact in fact_ls: + self.fact_dict[fact.pred_name].add((fact.val, tuple(fact.const_ls))) + add_ht(fact.pred_name, fact.const_ls, self.ht_dict) + add_ht(fact.pred_name, fact.const_ls, self.ht_dict_train) + const_cnter.update(fact.const_ls) + + for fact in valid_ls: + self.valid_dict[fact.pred_name].add((fact.val, tuple(fact.const_ls))) + add_ht(fact.pred_name, fact.const_ls, self.ht_dict) + + # the sorted list version + self.fact_dict_2 = dict((pred_name, sorted(list(self.fact_dict[pred_name]))) + for pred_name in self.fact_dict.keys( )) + self.valid_dict_2 = dict((pred_name, sorted(list(self.valid_dict[pred_name]))) + for pred_name in self.valid_dict.keys( )) + + self.rule_ls = rule_ls + + # pred_atom-key dict + self.atom_key_dict_ls = [] + for rule in self.rule_ls: + atom_key_dict = dict( ) + + for atom in rule.atom_ls: + atom_dict = dict((var_name, dict( )) for var_name in atom.var_name_ls) + + for i, var_name in enumerate(atom.var_name_ls): + + if atom.pred_name not in self.fact_dict: + continue + + for v in self.fact_dict[atom.pred_name]: + if v[1][i] not in atom_dict[var_name]: + atom_dict[var_name][v[1][i]] = [v] + else: + atom_dict[var_name][v[1][i]] += [v] + + # happens if predicate occurs more than once in one rule then we merge the set + if atom.pred_name in atom_key_dict: + for k, v in atom_dict.items( ): + if k not in atom_key_dict[atom.pred_name]: + atom_key_dict[atom.pred_name][k] = v + else: + atom_key_dict[atom.pred_name] = atom_dict + + self.atom_key_dict_ls.append(atom_key_dict) + + self.test_fact_ls = [] + self.valid_fact_ls = [] + + for fact in query_ls: + self.test_fact_ls.append((fact.val, fact.pred_name, tuple(fact.const_ls))) + self.test_fact_dict[fact.pred_name].add((fact.val, tuple(fact.const_ls))) + add_ht(fact.pred_name, fact.const_ls, self.ht_dict) + + for fact in valid_ls: + self.valid_fact_ls.append((fact.val, fact.pred_name, tuple(fact.const_ls))) + self.num_rules = len(rule_ls) + + self.rule_gens = None + self.reset( ) + + def generate_gnd_pred(self, pred_name): + """ + return a list of all instantiations of a predicate function, this can be extremely large + :param pred_name: + string + :return: + """ + + assert pred_name in self.PRED_DICT + + pred = self.PRED_DICT[pred_name] + subs = itertools.product(*[self.const_sort_dict[var_type] for var_type in pred.var_types]) + + return [(pred_name, sub) for sub in subs] + + def generate_gnd_rule(self, rule): + + subs = itertools.product(*[self.const_sort_dict[rule.rule_vars[k]] for k in rule.rule_vars.keys( )]) + sub = next(subs, None) + + while sub is not None: + + latent_vars = [] + latent_neg_mask = [] + observed_neg_mask = [] + + for atom in rule.atom_ls: + grounding = tuple(sub[rule.key2ind[var_name]] for var_name in atom.var_name_ls) + pos_gnding, neg_gnding = (1, grounding), (0, grounding) + + if pos_gnding in self.fact_dict[atom.pred_name]: + observed_neg_mask.append(0 if atom.neg else 1) + elif neg_gnding in self.fact_dict[atom.pred_name]: + observed_neg_mask.append(1 if atom.neg else 0) + else: + latent_vars.append((atom.pred_name, grounding)) + latent_neg_mask.append(1 if atom.neg else 0) + + isfullneg = (sum(latent_neg_mask) == len(latent_neg_mask)) and \ + (sum(observed_neg_mask) > 0) + + yield latent_vars, [latent_neg_mask, observed_neg_mask], isfullneg + + sub = next(subs, None) + + def get_batch(self, epoch_mode=False, filter_latent=True): + """ + 
return the ind-th batch of ground formula and latent variable indicators + :return: + + Parameters + ---------- + filter_latent + epoch_mode + """ + + batch_neg_mask = [[] for _ in range(len(self.rule_ls))] + batch_latent_var_inds = [[] for _ in range(len(self.rule_ls))] + observed_rule_cnts = [0.0 for _ in range(len(self.rule_ls))] + flat_latent_vars = dict( ) + + cnt = 0 + + inds = list(range(len(self.rule_ls))) + + while cnt < self.batchsize: + + if self.shuffle_sampling: + shuffle(inds) + + hasdata = False + for ind in inds: + latent_vars, neg_mask, isfullneg = next(self.rule_gens[ind], (None, None, None)) + + if latent_vars is None: + if epoch_mode: + continue + else: + self.rule_gens[ind] = self.generate_gnd_rule(self.rule_ls[ind]) + latent_vars, neg_mask, isfullneg = next(self.rule_gens[ind]) + + if epoch_mode: + hasdata = True + + # if rule is fully latent + if (len(neg_mask[1]) == 0) and filter_latent: + continue + + # if rule fully observed + if len(latent_vars) == 0: + observed_rule_cnts[ind] += 0 if isfullneg else 1 + cnt += 1 + if cnt >= self.batchsize: + break + else: + continue + + batch_neg_mask[ind].append(neg_mask) + + for latent_var in latent_vars: + if latent_var not in flat_latent_vars: + flat_latent_vars[latent_var] = len(flat_latent_vars) + + batch_latent_var_inds[ind].append([flat_latent_vars[e] for e in latent_vars]) + + cnt += 1 + + if cnt >= self.batchsize: + break + + if epoch_mode and (hasdata is False): + break + + flat_list = sorted([(k, v) for k, v in flat_latent_vars.items( )], key=lambda x: x[1]) + flat_list = [e[0] for e in flat_list] + + return batch_neg_mask, flat_list, batch_latent_var_inds, observed_rule_cnts + + def _instantiate_pred(self, atom, atom_dict, sub, rule, observed_prob): + + key2ind = rule.key2ind + rule_vars = rule.rule_vars + + # substitute with observed fact + if np.random.rand( ) < observed_prob: + + fact_choice_set = None + for var_name in atom.var_name_ls: + const = sub[key2ind[var_name]] + if const is None: + choice_set = itertools.chain.from_iterable([v for k, v in atom_dict[var_name].items( )]) + else: + if const in atom_dict[var_name]: + choice_set = atom_dict[var_name][const] + else: + choice_set = [] + + if fact_choice_set is None: + fact_choice_set = set(choice_set) + else: + fact_choice_set = fact_choice_set.intersection(set(choice_set)) + + if len(fact_choice_set) == 0: + break + + if len(fact_choice_set) == 0: + for var_name in atom.var_name_ls: + if sub[key2ind[var_name]] is None: + sub[key2ind[var_name]] = choice(self.const_sort_dict[rule_vars[var_name]]) + else: + val, const_ls = choice(sorted(list(fact_choice_set))) + for var_name, const in zip(atom.var_name_ls, const_ls): + sub[key2ind[var_name]] = const + + # substitute with random facts + else: + for var_name in atom.var_name_ls: + if sub[key2ind[var_name]] is None: + sub[key2ind[var_name]] = choice(self.const_sort_dict[rule_vars[var_name]]) + + def _gen_mask(self, rule, sub, closed_world): + + latent_vars = [] + observed_vars = [] + latent_neg_mask = [] + observed_neg_mask = [] + + for atom in rule.atom_ls: + grounding = tuple(sub[rule.key2ind[var_name]] for var_name in atom.var_name_ls) + pos_gnding, neg_gnding = (1, grounding), (0, grounding) + + if pos_gnding in self.fact_dict[atom.pred_name]: + observed_vars.append((1, atom.pred_name)) + observed_neg_mask.append(0 if atom.neg else 1) + elif neg_gnding in self.fact_dict[atom.pred_name]: + observed_vars.append((0, atom.pred_name)) + observed_neg_mask.append(1 if atom.neg else 0) + else: + if closed_world and 
(len(self.test_fact_dict[atom.pred_name]) == 0): + observed_vars.append((0, atom.pred_name)) + observed_neg_mask.append(1 if atom.neg else 0) + else: + latent_vars.append((atom.pred_name, grounding)) + latent_neg_mask.append(1 if atom.neg else 0) + + return latent_vars, observed_vars, latent_neg_mask, observed_neg_mask + + def _get_rule_stat(self, observed_vars, latent_vars, observed_neg_mask, filter_latent, filter_observed): + + is_full_latent = len(observed_vars) == 0 + is_full_observed = len(latent_vars) == 0 + + if is_full_latent and filter_latent: + return BAD + + if is_full_observed: + + if filter_observed: + return BAD + + is_full_neg = sum(observed_neg_mask) == 0 + + if is_full_neg: + return BAD + + else: + return FULL_OBSERVERED + + # if observed var already yields 1 + if sum(observed_neg_mask) > 0: + return BAD + + return GOOD + + def _inst_var(self, sub, var2ind, var2type, at, ht_dict, gen_latent): + + if len(at.var_name_ls) != 2: + raise KeyError + + must_latent = gen_latent + + if must_latent: + + tmp = [sub[var2ind[vn]] for vn in at.var_name_ls] + + for i, subi in enumerate(tmp): + if subi is None: + tmp[i] = random.choice(self.const_sort_dict[var2type[at.var_name_ls[i]]]) + + islatent = (tmp[0] not in ht_dict[0]) or (tmp[1] not in ht_dict[0][tmp[0]]) + for i, vn in enumerate(at.var_name_ls): + sub[var2ind[vn]] = tmp[i] + return [self.const2ind[subi] for subi in tmp], islatent, islatent or at.neg + + vn0 = at.var_name_ls[0] + sub0 = sub[var2ind[vn0]] + vn1 = at.var_name_ls[1] + sub1 = sub[var2ind[vn1]] + + if sub0 is None: + + if sub1 is None: + if len(ht_dict[0]) > 0: + sub0 = random.choice(tuple(ht_dict[0].keys( ))) + sub1 = random.choice(tuple(ht_dict[0][sub0])) + sub[var2ind[vn0]] = sub0 + sub[var2ind[vn1]] = sub1 + return [self.const2ind[sub0], self.const2ind[sub1]], False, at.neg + + else: + if sub1 in ht_dict[1]: + sub0 = random.choice(tuple(ht_dict[1][sub1])) + sub[var2ind[vn0]] = sub0 + return [self.const2ind[sub0], self.const2ind[sub1]], False, at.neg + else: + sub0 = random.choice(self.const_sort_dict[var2type[vn0]]) + sub[var2ind[vn0]] = sub0 + return [self.const2ind[sub0], self.const2ind[sub1]], True, True + + else: + + if sub1 is None: + if sub0 in ht_dict[0]: + sub1 = random.choice(tuple(ht_dict[0][sub0])) + sub[var2ind[vn1]] = sub1 + return [self.const2ind[sub0], self.const2ind[sub1]], False, at.neg + else: + sub1 = random.choice(self.const_sort_dict[var2type[vn1]]) + sub[var2ind[vn1]] = sub1 + return [self.const2ind[sub0], self.const2ind[sub1]], True, True + + else: + islatent = (sub0 not in ht_dict[0]) or (sub1 not in ht_dict[0][sub0]) + return [self.const2ind[sub0], self.const2ind[sub1]], islatent, islatent or at.neg + + def get_batch_fast(self, batchsize, observed_prob=0.9): + + prob_decay = 0.5 + + for rule in self.rule_ls: + + var2ind = rule.key2ind + var2type = rule.rule_vars + samples = [[atom.pred_name, []] for atom in rule.atom_ls] + neg_mask = [[atom.pred_name, []] for atom in rule.atom_ls] + latent_mask = [[atom.pred_name, []] for atom in rule.atom_ls] + obs_var = [[atom.pred_name, []] for atom in rule.atom_ls] + + cnt = 0 + while cnt <= batchsize: + + sub = [None] * len(rule.rule_vars) # substitutions + + sample_buff = [[] for _ in rule.atom_ls] + neg_mask_buff = [[] for _ in rule.atom_ls] + latent_mask_buff = [[] for _ in rule.atom_ls] + + atom_inds = list(range(len(rule.atom_ls))) + shuffle(atom_inds) + succ = True + cur_threshold = observed_prob + obs_list = [] + + for atom_ind in atom_inds: + atom = rule.atom_ls[atom_ind] + pred_ht_dict = 
self.ht_dict_train[atom.pred_name] + + gen_latent = np.random.rand( ) > cur_threshold + c_ls, islatent, atom_succ = self._inst_var(sub, var2ind, var2type, + atom, pred_ht_dict, gen_latent) + + if not islatent: + obs_var[atom_ind][1].append(c_ls) + + cur_threshold *= prob_decay + succ = succ and atom_succ + obs_list.append(not islatent) + + if succ: + sample_buff[atom_ind].append(c_ls) + latent_mask_buff[atom_ind].append(1 if islatent else 0) + neg_mask_buff[atom_ind].append(0 if atom.neg else 1) + + if succ and any(obs_list): + for i in range(len(rule.atom_ls)): + samples[i][1].extend(sample_buff[i]) + latent_mask[i][1].extend(latent_mask_buff[i]) + neg_mask[i][1].extend(neg_mask_buff[i]) + + cnt += 1 + + yield samples, neg_mask, latent_mask, obs_var + + def get_batch_by_q(self, batchsize, observed_prob=1.0, validation=False): + + samples_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + neg_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + latent_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + obs_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + neg_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + cnt = 0 + + num_ents = len(self.const2ind) + ind2const = self.const_sort_dict['type'] + + def gen_fake(c1, c2, pn): + for _ in range(10): + c1_fake = random.randint(0, num_ents - 1) + c2_fake = random.randint(0, num_ents - 1) + if np.random.rand( ) > 0.5: + if ind2const[c1_fake] not in self.ht_dict_train[pn][1][ind2const[c2]]: + return c1_fake, c2 + else: + if ind2const[c2_fake] not in self.ht_dict_train[pn][0][ind2const[c1]]: + return c1, c2_fake + return None, None + + if validation: + fact_ls = self.valid_fact_ls + else: + fact_ls = self.test_fact_ls + + for val, pred_name, consts in fact_ls: + + for rule_i, rule in enumerate(self.rule_ls): + + # find rule with pred_name as head + if rule.atom_ls[-1].pred_name != pred_name: + continue + + samples = samples_by_r[rule_i] + neg_mask = neg_mask_by_r[rule_i] + latent_mask = latent_mask_by_r[rule_i] + obs_var = obs_var_by_r[rule_i] + neg_var = neg_var_by_r[rule_i] + + var2ind = rule.key2ind + var2type = rule.rule_vars + + sub = [None] * len(rule.rule_vars) # substitutions + vn0, vn1 = rule.atom_ls[-1].var_name_ls + sub[var2ind[vn0]] = consts[0] + sub[var2ind[vn1]] = consts[1] + + sample_buff = [[] for _ in rule.atom_ls] + neg_mask_buff = [[] for _ in rule.atom_ls] + latent_mask_buff = [[] for _ in rule.atom_ls] + + atom_inds = list(range(len(rule.atom_ls) - 1)) + shuffle(atom_inds) + succ = True + obs_list = [] + + for atom_ind in atom_inds: + atom = rule.atom_ls[atom_ind] + pred_ht_dict = self.ht_dict_train[atom.pred_name] + + gen_latent = np.random.rand( ) > observed_prob + c_ls, islatent, atom_succ = self._inst_var(sub, var2ind, var2type, + atom, pred_ht_dict, gen_latent) + + assert atom_succ + + if not islatent: + obs_var[atom_ind][1].append(c_ls) + c1, c2 = gen_fake(c_ls[0], c_ls[1], atom.pred_name) + if c1 is not None: + neg_var[atom_ind][1].append([c1, c2]) + + succ = succ and atom_succ + obs_list.append(not islatent) + + sample_buff[atom_ind].append(c_ls) + latent_mask_buff[atom_ind].append(1 if islatent else 0) + neg_mask_buff[atom_ind].append(0 if atom.neg else 1) + + if succ and any(obs_list): + for i in range(len(rule.atom_ls)): + samples[i][1].extend(sample_buff[i]) + latent_mask[i][1].extend(latent_mask_buff[i]) + 
neg_mask[i][1].extend(neg_mask_buff[i]) + + samples[-1][1].append([self.const2ind[consts[0]], self.const2ind[consts[1]]]) + latent_mask[-1][1].append(1) + neg_mask[-1][1].append(1) + + cnt += 1 + + if cnt >= batchsize: + yield samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r + + samples_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + neg_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + latent_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + obs_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + neg_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + cnt = 0 + + yield samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r + + def get_batch_by_q_v2(self, batchsize, observed_prob=1.0): + + samples_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + neg_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + latent_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + obs_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + neg_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + cnt = 0 + + num_ents = len(self.const2ind) + ind2const = self.const_sort_dict['type'] + + def gen_fake(c1, c2, pn): + for _ in range(10): + c1_fake = random.randint(0, num_ents - 1) + c2_fake = random.randint(0, num_ents - 1) + if np.random.rand( ) > 0.5: + if ind2const[c1_fake] not in self.ht_dict_train[pn][1][ind2const[c2]]: + return c1_fake, c2 + else: + if ind2const[c2_fake] not in self.ht_dict_train[pn][0][ind2const[c1]]: + return c1, c2_fake + return None, None + + for val, pred_name, consts in self.test_fact_ls: + + for rule_i, rule in enumerate(self.rule_ls): + + # find rule with pred_name as head + if rule.atom_ls[-1].pred_name != pred_name: + continue + + samples = samples_by_r[rule_i] + neg_mask = neg_mask_by_r[rule_i] + latent_mask = latent_mask_by_r[rule_i] + + var2ind = rule.key2ind + var2type = rule.rule_vars + + sub_ls = [[None for _ in range(len(rule.rule_vars))] for _ in range(2)] # substitutions + + vn0, vn1 = rule.atom_ls[-1].var_name_ls + sub_ls[0][var2ind[vn0]] = consts[0] + sub_ls[0][var2ind[vn1]] = consts[1] + + c1, c2 = gen_fake(self.const2ind[consts[0]], self.const2ind[consts[1]], pred_name) + if c1 is not None: + sub_ls[1][var2ind[vn0]] = ind2const[c1] + sub_ls[1][var2ind[vn1]] = ind2const[c2] + else: + sub_ls.pop(1) + + pos_query_succ = False + + for sub_ind, sub in enumerate(sub_ls): + + sample_buff = [[] for _ in rule.atom_ls] + neg_mask_buff = [[] for _ in rule.atom_ls] + latent_mask_buff = [[] for _ in rule.atom_ls] + + atom_inds = list(range(len(rule.atom_ls) - 1)) + shuffle(atom_inds) + succ = True + obs_list = [] + + for atom_ind in atom_inds: + atom = rule.atom_ls[atom_ind] + pred_ht_dict = self.ht_dict_train[atom.pred_name] + + gen_latent = np.random.rand( ) > observed_prob + if sub_ind == 1: + gen_latent = np.random.rand( ) > 0.5 + c_ls, islatent, atom_succ = self._inst_var(sub, var2ind, var2type, + atom, pred_ht_dict, gen_latent) + + assert atom_succ + + succ = succ and atom_succ + obs_list.append(not islatent) + + sample_buff[atom_ind].append(c_ls) + latent_mask_buff[atom_ind].append(1 if islatent else 0) + neg_mask_buff[atom_ind].append(0 if atom.neg else 1) + + if succ: + if 
any(obs_list) or ((sub_ind == 1) and pos_query_succ): + + for i in range(len(rule.atom_ls)): + samples[i][1].extend(sample_buff[i]) + latent_mask[i][1].extend(latent_mask_buff[i]) + neg_mask[i][1].extend(neg_mask_buff[i]) + + if sub_ind == 0: + samples[-1][1].append([self.const2ind[consts[0]], self.const2ind[consts[1]]]) + latent_mask[-1][1].append(1) + neg_mask[-1][1].append(1) + pos_query_succ = True + cnt += 1 + else: + samples[-1][1].append([c1, c2]) + latent_mask[-1][1].append(0) # sample a negative fact at head + neg_mask[-1][1].append(1) + + if cnt >= batchsize: + yield samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r + + samples_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + neg_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + latent_mask_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + obs_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + neg_var_by_r = [[[atom.pred_name, []] for atom in rule.atom_ls] for rule in self.rule_ls] + cnt = 0 + + yield samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r + + def get_batch_rnd(self, observed_prob=0.7, filter_latent=True, closed_world=False, filter_observed=False): + """ + return a batch of gnd formulae by random sampling with controllable bias towards those containing + observed variables. The overall sampling logic is that: + 1) rnd sample a rule from rule_ls + 2) shuffle the predicates contained in the rule + 3) for each of these predicates, with (observed_prob) it will be instantiated as observed variable, and + for (1-observed_prob) if will be simply uniformly instantiated. + 3.1) if observed var, then sample from the knowledge base, which is self.fact_dict, if failed for any + reason, go to 3.2) + 3.2) if uniformly sample, then for each logic variable in the predicate, instantiate it with a uniform + sample from the corresponding constant dict + + :param observed_prob: + probability of instantiating a predicate as observed variable + :param filter_latent: + filter out ground formula containing only latent vars + :param closed_world: + if set True, reduce the sampling space of all predicates not in the test_dict to the set specified in + fact_dict + :param filter_observed: + filter out ground formula containing only observed vars + :return: + + """ + + batch_neg_mask = [[] for _ in range(len(self.rule_ls))] + batch_latent_var_inds = [[] for _ in range(len(self.rule_ls))] + batch_observed_vars = [[] for _ in range(len(self.rule_ls))] + observed_rule_cnts = [0.0 for _ in range(len(self.rule_ls))] + flat_latent_vars = dict( ) + + cnt = 0 + + inds = list(range(len(self.rule_ls))) + + while cnt < self.batchsize: + + # randomly sample a formula + if self.shuffle_sampling: + shuffle(inds) + + for ind in inds: + + rule = self.rule_ls[ind] + atom_key_dict = self.atom_key_dict_ls[ind] + sub = [None] * len(rule.rule_vars) # substitutions + + # randomly sample an atom from the formula + atom_inds = list(range(len(rule.atom_ls))) + shuffle(atom_inds) + for atom_ind in atom_inds: + atom = rule.atom_ls[atom_ind] + atom_dict = atom_key_dict[atom.pred_name] + + # instantiate the predicate + self._instantiate_pred(atom, atom_dict, sub, rule, observed_prob) + + # if variable substitution is complete already then exit + if not (None in sub): + break + + # generate latent and observed var labels and their negation masks + latent_vars, observed_vars, \ + 
latent_neg_mask, observed_neg_mask = self._gen_mask(rule, sub, closed_world) + + # check sampled ground rule status + stat_code = self._get_rule_stat(observed_vars, latent_vars, observed_neg_mask, + filter_latent, filter_observed) + + # is a valid sample with only observed vars and does not have negation on all of them + if stat_code == FULL_OBSERVERED: + observed_rule_cnts[ind] += 1 + + cnt += 1 + + # is a valid sample + elif stat_code == GOOD: + batch_neg_mask[ind].append([latent_neg_mask, observed_neg_mask]) + + for latent_var in latent_vars: + if latent_var not in flat_latent_vars: + flat_latent_vars[latent_var] = len(flat_latent_vars) + + batch_latent_var_inds[ind].append([flat_latent_vars[e] for e in latent_vars]) + batch_observed_vars[ind].append(observed_vars) + + cnt += 1 + + # not a valid sample + else: + continue + + if cnt >= self.batchsize: + break + + flat_list = sorted([(k, v) for k, v in flat_latent_vars.items( )], key=lambda x: x[1]) + flat_list = [e[0] for e in flat_list] + + return batch_neg_mask, flat_list, batch_latent_var_inds, observed_rule_cnts, batch_observed_vars + + def reset(self): + self.rule_gens = [self.generate_gnd_rule(rule) for rule in self.rule_ls] + + def get_stats(self): + + num_ents = sum([len(v) for k, v in self.const_sort_dict.items( )]) + num_rels = len(self.PRED_DICT) + num_facts = sum([len(v) for k, v in self.fact_dict.items( )]) + num_queries = len(self.test_fact_ls) + + num_gnd_atom = 0 + for pred_name, pred in self.PRED_DICT.items( ): + cnt = 1 + for var_type in pred.var_types: + cnt *= len(self.const_sort_dict[var_type]) + num_gnd_atom += cnt + + num_gnd_rule = 0 + for rule in self.rule_ls: + cnt = 1 + for var_type in rule.rule_vars.values( ): + cnt *= len(self.const_sort_dict[var_type]) + num_gnd_rule += cnt + + return num_ents, num_rels, num_facts, num_queries, num_gnd_atom, num_gnd_rule + + def preprocess_kinship(self, ppath, fpath, rpath, qpath): + """ + + :param ppath: + predicate file path + :param fpath: + facts file path + :param rpath: + rule file path + :param qpath: + query file path + + :return: + + """ + assert all(map(isfile, [ppath, fpath, rpath, qpath])) + + strip_items = lambda ls: list(map(lambda x: x.strip( ), ls)) + + pred_reg = re.compile(r'(.*)\((.*)\)') + + with open(ppath) as f: + for line in f: + + # skip empty lines + if line.strip( ) == '': + continue + + m = pred_reg.match(line.strip( )) + assert m is not None, 'matching predicate failed for %s' % line + + name, var_types = m.group(1), m.group(2) + var_types = list(map(lambda x: x.strip( ), var_types.split(','))) + + self.PRED_DICT[name] = Predicate(name, var_types) + TYPE_SET.update(var_types) + + fact_ls = [] + fact_reg = re.compile(r'(!?)(.*)\((.*)\)') + with open(fpath) as f: + for line in f: + + # skip empty lines + if line.strip( ) == '': + continue + + m = fact_reg.match(line.strip( )) + assert m is not None, 'matching fact failed for %s' % line + + val = 0 if m.group(1) == '!' 
else 1 + name, consts = m.group(2), m.group(3) + consts = strip_items(consts.split(',')) + + fact_ls.append(Fact(name, consts, val)) + + for var_type in self.PRED_DICT[name].var_types: + self.const_dict.add_const(var_type, consts.pop(0)) + + rule_ls = [] + first_atom_reg = re.compile(r'([\d.]+) (!?)([\w\d]+)\((.*)\)') + atom_reg = re.compile(r'(!?)([\w\d]+)\((.*)\)') + with open(rpath) as f: + for line in f: + + # skip empty lines + if line.strip( ) == '': + continue + + atom_str_ls = strip_items(line.strip( ).split(' v ')) + assert len(atom_str_ls) > 1, 'rule length must be greater than 1, but get %s' % line + + atom_ls = [] + rule_weight = 0.0 + for i, atom_str in enumerate(atom_str_ls): + if i == 0: + m = first_atom_reg.match(atom_str) + assert m is not None, 'matching atom failed for %s' % atom_str + rule_weight = float(m.group(1)) + neg = m.group(2) == '!' + pred_name = m.group(3).strip( ) + var_name_ls = strip_items(m.group(4).split(',')) + else: + m = atom_reg.match(atom_str) + assert m is not None, 'matching atom failed for %s' % atom_str + neg = m.group(1) == '!' + pred_name = m.group(2).strip( ) + var_name_ls = strip_items(m.group(3).split(',')) + + atom = Atom(neg, pred_name, var_name_ls, self.PRED_DICT[pred_name].var_types) + atom_ls.append(atom) + + rule = Formula(atom_ls, rule_weight) + rule_ls.append(rule) + + query_ls = [] + with open(qpath) as f: + for line in f: + + # skip empty lines + if line.strip( ) == '': + continue + + m = fact_reg.match(line.strip( )) + assert m is not None, 'matching fact failed for %s' % line + + val = 0 if m.group(1) == '!' else 1 + name, consts = m.group(2), m.group(3) + consts = strip_items(consts.split(',')) + + query_ls.append(Fact(name, consts, val)) + + for var_type in self.PRED_DICT[name].var_types: + self.const_dict.add_const(var_type, consts.pop(0)) + + return fact_ls, rule_ls, query_ls + + +TYPE_SET = set( ) + + +def iterline(fpath): + with open(fpath) as f: + + for line in f: + + line = line.strip( ) + if line == '': + continue + + yield line + + +class ConstantDict: + + def __init__(self): + self.constants = {} + + def add_const(self, const_type, const): + """ + + :param const_type: + string + :param const: + string + """ + + # if const_type not in TYPE_DICT: + # TYPE_DICT[const_type] = len(TYPE_DICT) + + if const_type in self.constants: + self.constants[const_type].add(const) + else: + self.constants[const_type] = {const} + + def __getitem__(self, key): + return self.constants[key] + + def has_const(self, key, const): + if key in self.constants: + return const in self[key] + else: + return False + + +class Predicate: + + def __init__(self, name, var_types): + """ + + :param name: + string + :param var_types: + list of strings + """ + self.name = name + self.var_types = var_types + self.num_args = len(var_types) + + def __repr__(self): + return '%s(%s)' % (self.name, ','.join(self.var_types)) + + +class Fact: + def __init__(self, pred_name, const_ls, val): + self.pred_name = pred_name + self.const_ls = deepcopy(const_ls) + self.val = val + + def __repr__(self): + return self.pred_name + '(%s)' % ','.join(self.const_ls) + + +class Atom: + def __init__(self, neg, pred_name, var_name_ls, var_type_ls): + self.neg = neg + self.pred_name = pred_name + self.var_name_ls = var_name_ls + self.var_type_ls = var_type_ls + + def __repr__(self): + return ('!' if self.neg else '') + self.pred_name + '(%s)' % ','.join(self.var_name_ls) + + +class Formula: + """ + only support clause form with disjunction, e.g. ! 
+ """ + + def __init__(self, atom_ls, weight): + self.weight = weight + self.atom_ls = atom_ls + self.rule_vars = dict( ) + + for atom in self.atom_ls: + self.rule_vars.update(zip(atom.var_name_ls, atom.var_type_ls)) + self.key2ind = dict(zip(self.rule_vars.keys( ), range(len(self.rule_vars.keys( ))))) + + def evaluate(self): + pass + + def __repr__(self): + return ' v '.join(list(map(repr, self.atom_ls))) + + +class ConstantDict: + + def __init__(self): + self.constants = {} + + def add_const(self, const_type, const): + """ + + :param const_type: + string + :param const: + string + """ + + # if const_type not in TYPE_DICT: + # TYPE_DICT[const_type] = len(TYPE_DICT) + + if const_type in self.constants: + self.constants[const_type].add(const) + else: + self.constants[const_type] = {const} + + def __getitem__(self, key): + return self.constants[key] + + def has_const(self, key, const): + if key in self.constants: + return const in self[key] + else: + return False + +@register_dataset('NBF_link_prediction') +class NBF_LinkPrediction(LinkPredictionDataset): + r""" + The NBF dataset will be used in task *link prediction*. + + """ + + def __init__(self, dataset_name ,*args, **kwargs): # dataset_name in ['NBF_WN18RR','NBF_FB15k-237'] + + self.dataset = NBF_Dataset(root='./openhgnn/dataset/', name=dataset_name[4:], version="v1") + + + +import os +import requests +import zipfile +import io +@register_dataset('DisenKGAT_link_prediction') +class DisenKGAT_LinkPrediction(LinkPredictionDataset): + def __init__(self, dataset ,*args, **kwargs): # dataset "DisenKGAT" + self.logger = kwargs.get("Logger") + self.args = kwargs.get("args") + self.current_dir = os.path.dirname(os.path.abspath(__file__)) + self.dataset_name = dataset + self.raw_dir = os.path.join(self.current_dir, self.dataset_name ,"raw_dir" ) + self.processed_dir = os.path.join(self.current_dir, self.dataset_name ,"processed_dir" ) + + if not os.path.exists(self.raw_dir): + os.makedirs(self.raw_dir) + self.download() + else: + print("raw_dir already exists") + + def download(self): + + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/{}.zip".format(self.dataset_name) + response = requests.get(url) + with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: + myzip.extractall(self.raw_dir) + print("--- download finished---") + + + diff --git a/openhgnn/dataset/NBF_dataset.py b/openhgnn/dataset/NBF_dataset.py new file mode 100644 index 00000000..af434597 --- /dev/null +++ b/openhgnn/dataset/NBF_dataset.py @@ -0,0 +1,321 @@ +import os +import torch,requests,zipfile,io +import os.path as osp +from typing import Any, Callable, List, Optional +from collections.abc import Sequence +import copy +import ssl +import sys +import urllib +import errno + + + +class NBF_Dataset(): + + def __init__(self, root, name, version, transform=None, pre_transform=None):# root/name/version == ~/WN18RR/v1 + if isinstance(root, str): + root = osp.expanduser(osp.normpath(root)) + self.root = root + self.name = name + self.version = version + self.transform = transform + self.pre_transform = pre_transform + + assert name in ["FB15k-237", "WN18RR"] + assert version in ["v1", "v2", "v3", "v4"] + + self.urls = { + "FB15k-237": [ + "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/train.txt", + "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/test.txt", + "https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s/train.txt", + 
"https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s/valid.txt" + ], + "WN18RR": [ + "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/train.txt", + "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/test.txt", + "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s/train.txt", + "https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s/valid.txt" + ] + } + + + + self._download() + self._process() + + self.train_data = torch.load(self.processed_paths[0]) + self.valid_data = torch.load(self.processed_paths[1]) + self.test_data = torch.load(self.processed_paths[2]) + print(self.processed_paths[0]) + + + def _download(self): + if files_exist(self.raw_paths): + return + + makedirs(self.raw_dir) + self.download() + + def _process(self): + + if files_exist(self.processed_paths): + return + + makedirs(self.processed_dir) + self.process() + + + @property + def num_relations(self): + #return int(self.data.edge_type.max()) + 1 + return max(int(self.train_data.edge_type.max()), + int(self.valid_data.edge_type.max()), + int(self.test_data.edge_type.max()), + ) + 1 + + @property + def raw_dir(self): + return os.path.join(self.root, self.name, self.version, "raw") + + @property + def processed_dir(self): + return os.path.join(self.root, self.name, self.version, "processed") + + @property + def processed_file_names(self): + return ["train_data.pt","valid_data.pt","test_data.pt"] + + @property + def raw_file_names(self): + return [ + "train_ind.txt", "test_ind.txt", "train.txt", "valid.txt" + ] + + + def download(self): + + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/{}.zip".format(self.name) + response = requests.get(url) + with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: + myzip.extractall(self.raw_dir) + print("--- download data finished---") + + + + @property + def raw_paths(self) : + r"""The absolute filepaths that must be present in order to skip + downloading.""" + files = self.raw_file_names + # Prevent a common source of error in which `file_names` are not + # defined as a property. + if isinstance(files, Callable): + files = files() + return [osp.join(self.raw_dir, f) for f in to_list(files)] + + + @property + def processed_paths(self) -> List[str]: + r"""The absolute filepaths that must be present in order to skip + processing.""" + files = self.processed_file_names + # Prevent a common source of error in which `file_names` are not + # defined as a property. 
+ if isinstance(files, Callable): + files = files() + return [osp.join(self.processed_dir, f) for f in to_list(files)] + + + + def process(self): + test_files = self.raw_paths[:2] + train_files = self.raw_paths[2:] + inv_train_entity_vocab = {} + inv_test_entity_vocab = {} + inv_relation_vocab = {} + triplets = [] + num_samples = [] + + for txt_file in train_files: + with open(txt_file, "r") as fin: + num_sample = 0 + for line in fin: + h_token, r_token, t_token = line.strip().split("\t") + if h_token not in inv_train_entity_vocab: + inv_train_entity_vocab[h_token] = len(inv_train_entity_vocab) + h = inv_train_entity_vocab[h_token] + if r_token not in inv_relation_vocab: + inv_relation_vocab[r_token] = len(inv_relation_vocab) + r = inv_relation_vocab[r_token] + if t_token not in inv_train_entity_vocab: + inv_train_entity_vocab[t_token] = len(inv_train_entity_vocab) + t = inv_train_entity_vocab[t_token] + triplets.append((h, t, r)) + num_sample += 1 + num_samples.append(num_sample) + count = 0 + for txt_file in test_files: + with open(txt_file, "r") as fin: + num_sample = 0 + for line in fin: + h_token, r_token, t_token = line.strip().split("\t") + if h_token not in inv_test_entity_vocab: + inv_test_entity_vocab[h_token] = len(inv_test_entity_vocab) + h = inv_test_entity_vocab[h_token] + count += 1 + if r_token in inv_relation_vocab: # assert r_token in inv_relation_vocab + r = inv_relation_vocab[r_token] + if t_token not in inv_test_entity_vocab: + inv_test_entity_vocab[t_token] = len(inv_test_entity_vocab) + t = inv_test_entity_vocab[t_token] + triplets.append((h, t, r)) + num_sample += 1 + num_samples.append(num_sample) + + triplets = torch.tensor(triplets) + + edge_index = triplets[:, :2].t() + edge_type = triplets[:, 2] + + num_relations = int(edge_type.max()) + 1 + + + train_fact_slice = slice(None, sum(num_samples[:1])) + test_fact_slice = slice(sum(num_samples[:2]), sum(num_samples[:3])) + train_fact_index = edge_index[:, train_fact_slice] + train_fact_type = edge_type[train_fact_slice] + test_fact_index = edge_index[:, test_fact_slice] + test_fact_type = edge_type[test_fact_slice] + # add flipped triplets for the fact graphs + train_fact_index = torch.cat([train_fact_index, train_fact_index.flip(0)], dim=-1) + train_fact_type = torch.cat([train_fact_type, train_fact_type + num_relations]) + test_fact_index = torch.cat([test_fact_index, test_fact_index.flip(0)], dim=-1) + test_fact_type = torch.cat([test_fact_type, test_fact_type + num_relations]) + + train_slice = slice(None, sum(num_samples[:1])) + valid_slice = slice(sum(num_samples[:1]), sum(num_samples[:2])) + test_slice = slice(sum(num_samples[:3]), sum(num_samples)) + + train_data = NBF_Data(edge_index=train_fact_index, edge_type=train_fact_type, + num_nodes=len(inv_train_entity_vocab),num_edges = train_fact_index.shape[1], + target_edge_index=edge_index[:, train_slice], target_edge_type=edge_type[train_slice]) + + valid_data = NBF_Data(edge_index=train_fact_index, edge_type=train_fact_type, + num_nodes=len(inv_train_entity_vocab),num_edges = train_fact_index.shape[1], + target_edge_index=edge_index[:, valid_slice], target_edge_type=edge_type[valid_slice]) + + test_data = NBF_Data(edge_index=test_fact_index, edge_type=test_fact_type, + num_nodes=len(inv_test_entity_vocab),num_edges = test_fact_index.shape[1], + target_edge_index=edge_index[:, test_slice], target_edge_type=edge_type[test_slice]) + + if self.pre_transform is not None: + train_data = self.pre_transform(train_data) + valid_data = self.pre_transform(valid_data) + 
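+ # Note: train_data and valid_data score their target triples against the transductive
+ # training fact graph, whereas test_data uses the separate inductive test fact graph
+ # (and its own entity vocabulary) constructed above.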
test_data = self.pre_transform(test_data) + + #torch.save((train_data, valid_data, test_data), self.processed_paths[0]) + torch.save(train_data, self.processed_paths[0]) + torch.save(valid_data, self.processed_paths[1]) + torch.save(test_data, self.processed_paths[2]) + + + + + def __repr__(self): + return "%s()" % self.name + + + + + +class NBF_Data(): + def __init__(self, + num_nodes,num_edges, + edge_index,edge_type, + target_edge_index,target_edge_type + ): + super().__init__() + self.num_nodes = num_nodes + self.num_edges = num_edges + self.edge_index = copy.copy(edge_index) + self.edge_type = copy.copy(edge_type) + self.target_edge_index = copy.copy(target_edge_index) + self.target_edge_type = copy.copy(target_edge_type) + + + +def files_exist(files: List[str]) -> bool: +# NOTE: We return `False` in case `files` is empty, leading to a +# re-processing of files on every instantiation. + return len(files) != 0 and all([osp.exists(f) for f in files]) + + + +def makedirs(path: str): + r"""Recursively creates a directory. + + Args: + path (str): The path to create. + """ + try: + os.makedirs(osp.expanduser(osp.normpath(path))) + except OSError as e: + if e.errno != errno.EEXIST and osp.isdir(path): + raise e + + + +def to_list(value: Any) -> Sequence: + if isinstance(value, Sequence) and not isinstance(value, str): + return value + else: + return [value] + + + +def download_url(url: str, folder: str, log: bool = True, + filename: Optional[str] = None): + r"""Downloads the content of an URL to a specific folder. + + Args: + url (str): The URL. + folder (str): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + + if filename is None: + filename = url.rpartition('/')[2] + filename = filename if filename[0] == '?' 
else filename.split('?')[0] + + path = osp.join(folder, filename) + + if osp.exists(path): # pragma: no cover + if log and 'pytest' not in sys.modules: + print(f'Using existing file {filename}', file=sys.stderr) + return path + + if log and 'pytest' not in sys.modules: + print(f'Downloading {url}', file=sys.stderr) + + makedirs(folder) + + context = ssl._create_unverified_context() + data = urllib.request.urlopen(url, context=context) + + with open(path, 'wb') as f: + # workaround for https://bugs.python.org/issue42853 + while True: + chunk = data.read(10 * 1024 * 1024) + if not chunk: + break + f.write(chunk) + + return path + + + + + \ No newline at end of file diff --git a/openhgnn/dataset/SACN_dataset.py b/openhgnn/dataset/SACN_dataset.py new file mode 100644 index 00000000..e69de29b diff --git a/openhgnn/dataset/__init__.py b/openhgnn/dataset/__init__.py index cdac3444..4b5553a5 100644 --- a/openhgnn/dataset/__init__.py +++ b/openhgnn/dataset/__init__.py @@ -10,7 +10,12 @@ from .adapter import AsLinkPredictionDataset, AsNodeClassificationDataset from .mg2vec_dataset import Mg2vecDataSet from .meirec_dataset import MeiRECDataset, get_data_loader - +from .AdapropT_dataset import AdapropTDataLoader +from .AdapropI_dataset import AdapropIDataLoader +from .LTE_dataset import * +from .SACN_dataset import * +from .NBF_dataset import NBF_Dataset +from .Ingram_dataset import Ingram_KG_TrainData, Ingram_KG_TestData DATASET_REGISTRY = {} @@ -50,32 +55,65 @@ def try_import_task_dataset(task): return True -common = ['Cora','Citeseer','Pubmed','Texas','Cornell'] +common = ['Cora', 'Citeseer', 'Pubmed', 'Texas', 'Cornell'] hgbl_datasets = ['HGBl-amazon', 'HGBl-LastFM', 'HGBl-PubMed'] hgbn_datasets = ['HGBn-ACM', 'HGBn-DBLP', 'HGBn-Freebase', 'HGBn-IMDB'] -kg_lp_datasets = ['wn18', 'FB15k', 'FB15k-237'] + +kg_lp_datasets = ['wn18', 'FB15k', 'EXP_FB15k-237', 'EXP_FB15k-237_data_ratio_0', 'EXP_FB15k-237_data_ratio_0', + 'EXP_FB15k-237_data_ratio_0.1', 'EXP_FB15k-237_data_ratio_0.2', 'EXP_FB15k-237_data_ratio_zero_shot', + 'kinship', 'uw_cse'] + +kg_sub_datasets = [f'fb237_v{i}' for i in range(1, 5)] +kg_sub_datasets += [f'nell_v{i}' for i in range(1, 5)] +kg_sub_datasets += [f'WN18RR_v{i}' for i in range(1,5)] +kg_subT_datasets = ['family'] ohgbl_datasets = ['ohgbl-MTWM', 'ohgbl-yelp1', 'ohgbl-yelp2', 'ohgbl-Freebase'] ohgbn_datasets = ['ohgbn-Freebase', 'ohgbn-yelp2', 'ohgbn-acm', 'ohgbn-imdb'] hypergraph_datasets = ['GPS', 'drug', 'MovieLens', 'wordnet', 'aminer4AEHCL'] + + def build_dataset(dataset, task, *args, **kwargs): + args =kwargs.get('args') + model = args.model if isinstance(dataset, DGLDataset): return dataset + #-------------------更改部分------------------- + if dataset == 'NL-100': + train_dataloader = Ingram_KG_TrainData('',dataset) + valid_dataloader = Ingram_KG_TestData('', dataset,'valid') + test_dataloader = Ingram_KG_TestData('',dataset,'test') + return train_dataloader,valid_dataloader,test_dataloader + # -------------------更改部分------------------- if dataset == 'meirec': train_dataloader = get_data_loader("train", batch_size=args[0]) test_dataloader = get_data_loader("test", batch_size=args[0]) return train_dataloader, test_dataloader + #-------------------更改部分------------------- + if dataset == 'AdapropT': + dataload=AdapropTDataLoader(args) + return dataload + # -------------------更改部分------------------- + #-------------------更改部分------------------- + if dataset == 'AdapropI': + dataload=AdapropIDataLoader(args) + return dataload + + if dataset == 'SACN' or dataset == 'LTE': + 
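+ # Assumption (not stated in the patch): SACN and LTE appear to build their data
+ # inside their own trainers, so build_dataset deliberately returns nothing for them.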
return if dataset in CLASS_DATASETS: return build_dataset_v2(dataset, task) if not try_import_task_dataset(task): exit(1) + _dataset = None if dataset in ['aifb', 'mutag', 'bgs', 'am']: _dataset = 'rdf_' + task elif dataset in ['acm4NSHE', 'acm4GTN', 'academic4HetGNN', 'acm_han', 'acm_han_raw', 'acm4HeCo', 'dblp', - 'dblp4MAGNN', 'imdb4GTN', 'acm4NARS', 'demo_graph', 'yelp4HeGAN', 'DoubanMovie', - 'Book-Crossing', 'amazon4SLICE', 'MTWM', 'HNE-PubMed', 'HGBl-ACM', 'HGBl-DBLP', 'HGBl-IMDB','amazon', 'yelp4HGSL']: + 'dblp4MAGNN', 'imdb4MAGNN', 'imdb4GTN', 'acm4NARS', 'demo_graph', 'yelp4HeGAN', 'DoubanMovie', + 'Book-Crossing', 'amazon4SLICE', 'MTWM', 'HNE-PubMed', 'HGBl-ACM', 'HGBl-DBLP', 'HGBl-IMDB', + 'amazon', 'yelp4HGSL']: _dataset = 'hin_' + task elif dataset in ['imdb4MAGNN']: _dataset = 'hin_' + task @@ -90,8 +128,19 @@ def build_dataset(dataset, task, *args, **kwargs): elif dataset in hgbl_datasets: _dataset = 'HGBl_link_prediction' elif dataset in kg_lp_datasets: + if model == 'ExpressGNN': + assert task == 'link_prediction' + _dataset = 'express_gnn' + return DATASET_REGISTRY[_dataset](dataset, logger=kwargs['logger'], args=kwargs['args']) + else: + assert task == 'link_prediction' + _dataset = 'kg_link_prediction' + elif dataset in kg_sub_datasets: + assert task == 'link_prediction' + _dataset = 'kg_sub_link_prediction' + elif dataset in kg_subT_datasets: assert task == 'link_prediction' - _dataset = 'kg_link_prediction' + _dataset = 'kg_subT_link_prediction' elif dataset in ['LastFM4KGCN']: _dataset = 'kgcn_recommendation' elif dataset in ['gowalla', 'yelp2018', 'amazon-book']: @@ -108,12 +157,26 @@ def build_dataset(dataset, task, *args, **kwargs): _dataset = 'mag_dataset' elif dataset in hypergraph_datasets: _dataset = task - elif dataset in ['LastFM_KGAT','yelp2018_KGAT','amazon-book_KGAT']: - change_name={'LastFM_KGAT':'last-fm','yelp2018_KGAT':'yelp2018','amazon-book_KGAT':'amazon-book'} - dataset=change_name[dataset] - _dataset='kgat_recommendation' + elif dataset in ['LastFM_KGAT', 'yelp2018_KGAT', 'amazon-book_KGAT']: + change_name = {'LastFM_KGAT': 'last-fm', 'yelp2018_KGAT': 'yelp2018', 'amazon-book_KGAT': 'amazon-book'} + dataset = change_name[dataset] + _dataset = 'kgat_recommendation' elif dataset in common: - _dataset = 'common_' + task + if model == 'ExpressGNN' and dataset == 'Cora': + assert task == 'link_prediction' + _dataset = 'express_gnn' + return DATASET_REGISTRY[_dataset](dataset, logger=kwargs['logger'], args=kwargs['args']) + else: + _dataset = 'common_' + task + elif dataset in ['NBF_WN18RR','NBF_FB15k-237']: + _dataset = 'NBF_' + task + elif dataset in ['DisenKGAT_WN18RR','DisenKGAT_FB15k-237']: + _dataset = 'DisenKGAT_' + task # == 'DisenKGAT_link_prediction' + return DATASET_REGISTRY[_dataset](dataset, logger=kwargs['logger'],args = kwargs.get('args')) + + if kwargs['args'].model=='Grail' or kwargs['args'].model=='ComPILE': + _dataset = 'grail_'+ task + return DATASET_REGISTRY[_dataset](dataset, logger=kwargs['logger'],args=kwargs['args']) return DATASET_REGISTRY[_dataset](dataset, logger=kwargs['logger']) @@ -122,8 +185,8 @@ def build_dataset(dataset, task, *args, **kwargs): "link_prediction": "openhgnn.dataset.LinkPredictionDataset", "recommendation": "openhgnn.dataset.RecommendationDataset", "edge_classification": "openhgnn.dataset.EdgeClassificationDataset", - "hypergraph":"openhgnn.dataset.HypergraphDataset", - "pretrain":"openhgnn.dataset.mag_dataset" + "hypergraph": "openhgnn.dataset.HypergraphDataset", + "pretrain": 
"openhgnn.dataset.mag_dataset" } from .NodeClassificationDataset import NodeClassificationDataset @@ -132,13 +195,14 @@ def build_dataset(dataset, task, *args, **kwargs): from .EdgeClassificationDataset import EdgeClassificationDataset from .HypergraphDataset import HGraphDataset + def build_dataset_v2(dataset, task): if dataset in CLASS_DATASETS: path = ".".join(CLASS_DATASETS[dataset].split(".")[:-1]) module = importlib.import_module(path) class_name = CLASS_DATASETS[dataset].split(".")[-1] dataset_class = getattr(module, class_name) - d = dataset_class() + d = dataset_class( ) if task == 'node_classification': target_ntype = getattr(d, 'category') if target_ntype is None: @@ -157,11 +221,11 @@ def build_dataset_v2(dataset, task): "imdb4GTN": "openhgnn.dataset.IMDB4GTNDataset", "alircd_small": "openhgnn.dataset.AliRCDSmallDataset", "alircd_session1": "openhgnn.dataset.AliRCDSession1Dataset", - "ICDM":"openhgnn.dataset.AliICDMDataset", + "ICDM": "openhgnn.dataset.AliICDMDataset", "ohgbn-alircd_session1": "openhgnn.dataset.AliRCDSession1Dataset", "alircd_session2": "openhgnn.dataset.AliRCDSession2Dataset", "ohgbn-alircd_session2": "openhgnn.dataset.AliRCDSession2Dataset", - "pretrain":"openhgnn.dataset.mag_dataset" + "pretrain": "openhgnn.dataset.mag_dataset" } __all__ = [ diff --git a/openhgnn/experiment.py b/openhgnn/experiment.py index 00643470..efe84d27 100644 --- a/openhgnn/experiment.py +++ b/openhgnn/experiment.py @@ -48,6 +48,8 @@ class Experiment(object): 'TransH': 'TransX_trainer', 'TransR': 'TransX_trainer', 'TransD': 'TransX_trainer', + 'RedGNN': 'RedGNN_trainer', + 'RedGNNT': 'RedGNNT_trainer', 'GIE': 'TransX_trainer', 'HAN': { 'node_classification': 'han_nc_trainer', @@ -62,6 +64,16 @@ class Experiment(object): 'SHGP': 'SHGP_trainer', 'HGCL': 'hgcltrainer', 'lightGCN': 'lightGCN_trainer', + 'Grail': 'Grail_trainer', + 'ComPILE': 'ComPILE_trainer', + 'AdapropT':'AdapropT_trainer', + 'AdapropI':'AdapropI_trainer', + 'LTE':'LTE_trainer', + 'SACN':'SACN_trainer', + 'ExpressGNN': 'ExpressGNN_trainer', + 'NBF':'NBF_trainer', + 'Ingram': 'Ingram_trainer', + 'DisenKGAT': 'DisenKGAT_trainer' } immutable_params = ['model', 'dataset', 'task'] diff --git a/openhgnn/layers/AdapropI.py b/openhgnn/layers/AdapropI.py new file mode 100644 index 00000000..c41abda0 --- /dev/null +++ b/openhgnn/layers/AdapropI.py @@ -0,0 +1,143 @@ +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..utils.utils import scatter + + +class GNNLayer(torch.nn.Module): + def __init__(self, in_dim, out_dim, attn_dim, n_rel, act=lambda x: x): + super(GNNLayer, self).__init__() + self.n_rel = n_rel + self.in_dim = in_dim + self.out_dim = out_dim + self.attn_dim = attn_dim + self.act = act + self.rela_embed = nn.Embedding(2 * n_rel + 1, in_dim) + self.Ws_attn = nn.Linear(in_dim, attn_dim, bias=False) + self.Wr_attn = nn.Linear(in_dim, attn_dim, bias=False) + self.Wqr_attn = nn.Linear(in_dim, attn_dim) + self.W_attn = nn.Linear(attn_dim, 1, bias=False) + self.W_h = nn.Linear(in_dim, out_dim, bias=False) + + def forward(self, q_sub, q_rel, hidden, edges, n_node, old_nodes_new_idx): + # edges: [batch_idx, head, rela, tail, old_idx, new_idx] + sub = edges[:, 4] + rel = edges[:, 2] + obj = edges[:, 5] + hs = hidden[sub] + hr = self.rela_embed(rel) + + r_idx = edges[:, 0] + h_qr = self.rela_embed(q_rel)[r_idx] + mess1 = hs + mess2 = mess1 + hr + alpha_2 = torch.sigmoid(self.W_attn(nn.ReLU()(self.Ws_attn(mess1) + self.Wr_attn(hr) + self.Wqr_attn(h_qr)))) + message = mess2 
* alpha_2 + message_agg = scatter(message, index=obj, dim=0, dim_size=n_node, reduce='sum') + + hidden_new = self.W_h(message_agg) + hidden_new = self.act(hidden_new) + + return hidden_new + + +class GNNModel(torch.nn.Module): + def __init__(self, params, loader): + super(GNNModel, self).__init__() + self.n_layer = params.n_layer + self.init_dim = params.init_dim + self.hidden_dim = params.hidden_dim + self.attn_dim = params.attn_dim + self.n_rel = params.n_rel + self.loader = loader + self.increase = params.increase + self.topk = params.topk + acts = {'relu': nn.ReLU(), 'tanh': torch.tanh, 'idd': lambda x: x} + act = acts[params.act] + dropout = params.dropout + + self.layers = [] + self.Ws_layers = [] + for i in range(self.n_layer): + self.layers.append(GNNLayer(self.hidden_dim, self.hidden_dim, self.attn_dim, self.n_rel, act=act)) + self.Ws_layers.append(nn.Linear(self.hidden_dim, 1, bias=False)) + self.layers = nn.ModuleList(self.layers) + self.Ws_layers = nn.ModuleList(self.Ws_layers) + + self.dropout = nn.Dropout(dropout) + self.W_final = nn.Linear(self.hidden_dim, 1, bias=False) # get score + self.gru = nn.GRU(self.hidden_dim, self.hidden_dim) + + def soft_to_hard(self, i, hidden, nodes, n_ent, batch_size, old_nodes_new_idx): + n_node = len(nodes) + bool_diff_node_idx = torch.ones(n_node).bool().cuda() + bool_diff_node_idx[old_nodes_new_idx] = False + bool_same_node_idx = ~bool_diff_node_idx + diff_nodes = nodes[bool_diff_node_idx] + diff_node_logits = self.Ws_layers[i](hidden[bool_diff_node_idx].detach()).squeeze(-1) + + soft_all = torch.ones((batch_size, n_ent)) * float('-inf') + soft_all = soft_all.cuda() + soft_all[diff_nodes[:, 0], diff_nodes[:, 1]] = diff_node_logits + soft_all = F.softmax(soft_all, dim=-1) + + diff_node_logits = self.topk * soft_all[diff_nodes[:, 0], diff_nodes[:, 1]] + _, argtopk = torch.topk(soft_all, k=self.topk, dim=-1) + r_idx = torch.arange(batch_size).unsqueeze(1).repeat(1, self.topk).cuda() + hard_all = torch.zeros((batch_size, n_ent)).bool().cuda() + hard_all[r_idx, argtopk] = True + bool_sampled_diff_nodes = hard_all[diff_nodes[:, 0], diff_nodes[:, 1]] + + hidden[bool_diff_node_idx][bool_sampled_diff_nodes] *= ( + 1 - diff_node_logits[bool_sampled_diff_nodes].detach() + diff_node_logits[ + bool_sampled_diff_nodes]).unsqueeze(1) + bool_same_node_idx[bool_diff_node_idx] = bool_sampled_diff_nodes + + return hidden, bool_same_node_idx + + def forward(self, subs, rels, mode='transductive'): + n = len(subs) + n_ent = self.loader.n_ent if mode == 'transductive' else self.loader.n_ent_ind + q_sub = torch.LongTensor(subs).cuda() + q_rel = torch.LongTensor(rels).cuda() + h0 = torch.zeros((1, n, self.hidden_dim)).cuda() + nodes = torch.cat([torch.arange(n).unsqueeze(1).cuda(), q_sub.unsqueeze(1)], 1) + hidden = torch.zeros(n, self.hidden_dim).cuda() + time_1 = 0 + time_2 = 0 + + for i in range(self.n_layer): + t_1 = time.time() + nodes, edges, old_nodes_new_idx = self.loader.get_neighbors(nodes.data.cpu().numpy(), mode) + time_1 += time.time() - t_1 + + t_2 = time.time() + hidden = self.layers[i](q_sub, q_rel, hidden, edges, nodes.size(0), old_nodes_new_idx) + h0 = torch.zeros(1, nodes.size(0), hidden.size(1)).cuda().index_copy_(1, old_nodes_new_idx, h0) + hidden = self.dropout(hidden) + hidden, h0 = self.gru(hidden.unsqueeze(0), h0) + hidden = hidden.squeeze(0) + + if i < self.n_layer - 1: + if self.increase: + hidde, bool_same_nodes = self.soft_to_hard(i, hidden, nodes, n_ent, n, old_nodes_new_idx) + else: + exit() + + nodes = nodes[bool_same_nodes] + hidden 
= hidden[bool_same_nodes] + h0 = h0[:, bool_same_nodes] + + time_2 += time.time() - t_2 + + self.time_1 = time_1 + self.time_2 = time_2 + scores = self.W_final(hidden).squeeze(-1) + scores_all = torch.zeros((n, n_ent)).cuda() + scores_all[[nodes[:, 0], nodes[:, 1]]] = scores + return scores_all + + + diff --git a/openhgnn/layers/AdapropT.py b/openhgnn/layers/AdapropT.py new file mode 100644 index 00000000..61d4d2c5 --- /dev/null +++ b/openhgnn/layers/AdapropT.py @@ -0,0 +1,196 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import time +import numpy as np +from ..utils.utils import scatter +from collections import defaultdict + + +class GNNLayer(torch.nn.Module): + def __init__(self, in_dim, out_dim, attn_dim, n_rel, n_ent, n_node_topk=-1, n_edge_topk=-1, tau=1.0, + act=lambda x: x): + super(GNNLayer, self).__init__() + self.n_rel = n_rel + self.n_ent = n_ent + self.in_dim = in_dim + self.out_dim = out_dim + self.attn_dim = attn_dim + self.act = act + self.n_node_topk = n_node_topk + self.n_edge_topk = n_edge_topk + self.tau = tau + self.rela_embed = nn.Embedding(2 * n_rel + 1, in_dim) + self.Ws_attn = nn.Linear(in_dim, attn_dim, bias=False) + self.Wr_attn = nn.Linear(in_dim, attn_dim, bias=False) + self.Wqr_attn = nn.Linear(in_dim, attn_dim) + self.w_alpha = nn.Linear(attn_dim, 1) + self.W_h = nn.Linear(in_dim, out_dim, bias=False) + self.W_samp = nn.Linear(in_dim, 1, bias=False) + + def train(self, mode=True): + if not isinstance(mode, bool): + raise ValueError("training mode is expected to be boolean") + self.training = mode + if self.training and self.tau > 0: + self.softmax = lambda x: F.gumbel_softmax(x, tau=self.tau, hard=False) + else: + self.softmax = lambda x: F.softmax(x, dim=1) + for module in self.children(): + module.train(mode) + return self + + def forward(self, q_sub, q_rel, hidden, edges, nodes, old_nodes_new_idx, batchsize): + # edges: [N_edge_of_all_batch, 6] + # with (batch_idx, head, rela, tail, head_idx, tail_idx) + # note that head_idx and tail_idx are relative index + sub = edges[:, 4] + rel = edges[:, 2] + obj = edges[:, 5] + hs = hidden[sub] + hr = self.rela_embed(rel) + r_idx = edges[:, 0] + h_qr = self.rela_embed(q_rel)[r_idx] + n_node = nodes.shape[0] + message = hs + hr + + # sample edges w.r.t. 
alpha + if self.n_edge_topk > 0: + alpha = self.w_alpha(nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wqr_attn(h_qr))).squeeze(-1) + edge_prob = F.gumbel_softmax(alpha, tau=1, hard=False) + topk_index = torch.argsort(edge_prob, descending=True)[:self.n_edge_topk] + edge_prob_hard = torch.zeros((alpha.shape[0])).cuda() + edge_prob_hard[topk_index] = 1 + alpha *= (edge_prob_hard - edge_prob.detach() + edge_prob) + alpha = torch.sigmoid(alpha).unsqueeze(-1) + + else: + alpha = torch.sigmoid(self.w_alpha( + nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wqr_attn(h_qr)))) # [N_edge_of_all_batch, 1] + + # aggregate message and then propagate + message = alpha * message + message_agg = scatter(message, index=obj, dim=0, dim_size=n_node, reduce='sum') + hidden_new = self.act(self.W_h(message_agg)) # [n_node, dim] + hidden_new = hidden_new.clone() + + # forward without node sampling + if self.n_node_topk <= 0: + return hidden_new + + # forward with node sampling + # indexing sampling operation + tmp_diff_node_idx = torch.ones(n_node) + tmp_diff_node_idx[old_nodes_new_idx] = 0 + bool_diff_node_idx = tmp_diff_node_idx.bool() + diff_node = nodes[bool_diff_node_idx] + + # project logit to fixed-size tensor via indexing + diff_node_logit = self.W_samp(hidden_new[bool_diff_node_idx]).squeeze(-1) # [all_batch_new_nodes] + + # save logit to node_scores for later indexing + node_scores = torch.ones((batchsize, self.n_ent)).cuda() * float('-inf') + node_scores[diff_node[:, 0], diff_node[:, 1]] = diff_node_logit + + # select top-k nodes + # (train mode) self.softmax == F.gumbel_softmax + # (eval mode) self.softmax == F.softmax + node_scores = self.softmax(node_scores) # [batchsize, n_ent] + topk_index = torch.topk(node_scores, self.n_node_topk, dim=1).indices.reshape(-1) + topk_batchidx = torch.arange(batchsize).repeat(self.n_node_topk, 1).T.reshape(-1) + batch_topk_nodes = torch.zeros((batchsize, self.n_ent)).cuda() + batch_topk_nodes[topk_batchidx, topk_index] = 1 + + # get sampled nodes' relative index + bool_sampled_diff_nodes_idx = batch_topk_nodes[diff_node[:, 0], diff_node[:, 1]].bool() + bool_same_node_idx = ~bool_diff_node_idx.cuda() + bool_same_node_idx[bool_diff_node_idx] = bool_sampled_diff_nodes_idx + + # update node embeddings + diff_node_prob_hard = batch_topk_nodes[diff_node[:, 0], diff_node[:, 1]] + diff_node_prob = node_scores[diff_node[:, 0], diff_node[:, 1]] + hidden_new[bool_diff_node_idx] *= (diff_node_prob_hard - diff_node_prob.detach() + diff_node_prob).unsqueeze(-1) + + # extract sampled nodes an their embeddings + new_nodes = nodes[bool_same_node_idx] + hidden_new = hidden_new[bool_same_node_idx] + + return hidden_new, new_nodes, bool_same_node_idx + + +class GNNModel(torch.nn.Module): + def __init__(self, params, loader): + super(GNNModel, self).__init__() + self.n_layer = params.n_layer + self.hidden_dim = params.hidden_dim + self.attn_dim = params.attn_dim + self.n_ent = params.n_ent + self.n_rel = params.n_rel + self.n_node_topk = params.n_node_topk + self.n_edge_topk = params.n_edge_topk + self.loader = loader + acts = {'relu': nn.ReLU(), 'tanh': torch.tanh, 'idd': lambda x: x} + act = acts[params.act] + + self.gnn_layers = [] + for i in range(self.n_layer): + i_n_node_topk = self.n_node_topk if 'int' in str(type(self.n_node_topk)) else self.n_node_topk[i] + self.gnn_layers.append(GNNLayer(self.hidden_dim, self.hidden_dim, self.attn_dim, self.n_rel, self.n_ent, \ + n_node_topk=i_n_node_topk, n_edge_topk=self.n_edge_topk, tau=params.tau, + act=act)) + + 
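+ # n_node_topk may be a single int (the same node budget for every layer) or a
+ # per-layer list; updateTopkNums() below can overwrite these budgets after construction.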
self.gnn_layers = nn.ModuleList(self.gnn_layers) + self.dropout = nn.Dropout(params.dropout) + self.W_final = nn.Linear(self.hidden_dim, 1, bias=False) + self.gate = nn.GRU(self.hidden_dim, self.hidden_dim) + + def updateTopkNums(self, topk_list): + assert len(topk_list) == self.n_layer + for idx in range(self.n_layer): + self.gnn_layers[idx].n_node_topk = topk_list[idx] + + def fixSamplingWeight(self): + def freeze(m): + m.requires_grad = False + + for i in range(self.n_layer): + self.gnn_layers[i].W_samp.apply(freeze) + + def forward(self, subs, rels, mode='train'): + n = len(subs) # n == B (Batchsize) + q_sub = torch.LongTensor(subs).cuda() # [B] + q_rel = torch.LongTensor(rels).cuda() # [B] + h0 = torch.zeros((1, n, self.hidden_dim)).cuda() # [1, B, dim] + nodes = torch.cat([torch.arange(n).unsqueeze(1).cuda(), q_sub.unsqueeze(1)], + 1) # [B, 2] with (batch_idx, node_idx) + hidden = torch.zeros(n, self.hidden_dim).cuda() # [B, dim] + + for i in range(self.n_layer): + # layers with sampling + # nodes (of i-th layer): [k1, 2] + # edges (of i-th layer): [k2, 6] + # old_nodes_new_idx (of previous layer): [k1'] + nodes, edges, old_nodes_new_idx = self.loader.get_neighbors(nodes.data.cpu().numpy(), n, mode=mode) + n_node = nodes.size(0) + + # GNN forward -> get hidden representation at i-th layer + # hidden: [k1, dim] + hidden, nodes, sampled_nodes_idx = self.gnn_layers[i](q_sub, q_rel, hidden, edges, nodes, old_nodes_new_idx, + n) + + # combine h0 and hi -> update hi with gate operation + h0 = torch.zeros(1, n_node, hidden.size(1)).cuda().index_copy_(1, old_nodes_new_idx, h0) + h0 = h0[0, sampled_nodes_idx, :].unsqueeze(0) + hidden = self.dropout(hidden) + hidden, h0 = self.gate(hidden.unsqueeze(0), h0) + hidden = hidden.squeeze(0) + + # readout + # [K, 2] (batch_idx, node_idx) K is w.r.t. 
n_nodes + scores = self.W_final(hidden).squeeze(-1) + # non-visited entities have 0 scores + scores_all = torch.zeros((n, self.loader.n_ent)).cuda() + # [B, n_all_nodes] + scores_all[[nodes[:, 0], nodes[:, 1]]] = scores + + return scores_all \ No newline at end of file diff --git a/openhgnn/layers/__init__.py b/openhgnn/layers/__init__.py index 27ff4fc3..60cdb110 100644 --- a/openhgnn/layers/__init__.py +++ b/openhgnn/layers/__init__.py @@ -5,6 +5,9 @@ from .HeteroGraphConv import HeteroGraphConv from .macro_layer import * from .micro_layer import * +from .AdapropT import * +from .AdapropI import * +from .rgcn_layer import * __all__ = [ 'HeteroEmbedLayer', @@ -19,7 +22,9 @@ 'SemanticAttention', 'CompConv', 'AttConv', - 'LSTMConv' + 'LSTMConv', + 'AdapropT', + 'AdapropI' ] classes = __all__ \ No newline at end of file diff --git a/openhgnn/layers/compgcn_layer.py b/openhgnn/layers/compgcn_layer.py new file mode 100644 index 00000000..3fa0d5bc --- /dev/null +++ b/openhgnn/layers/compgcn_layer.py @@ -0,0 +1,123 @@ +import torch +from torch import nn +import dgl +import dgl.function as fn + + +class CompGCNCov(nn.Module): + def __init__(self, in_channels, out_channels, act=lambda x: x, bias=True, drop_rate=0., opn='corr', num_base=-1, + num_rel=None, wni=False, wsi=False, use_bn=True, ltr=True): + super(CompGCNCov, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.act = act # activation function + self.device = None + self.rel = None + self.opn = opn + + self.use_bn = use_bn + self.ltr = ltr + + # relation-type specific parameter + self.in_w = self.get_param([in_channels, out_channels]) + self.out_w = self.get_param([in_channels, out_channels]) + self.loop_w = self.get_param([in_channels, out_channels]) + # transform embedding of relations to next layer + self.w_rel = self.get_param([in_channels, out_channels]) + self.loop_rel = self.get_param([1, in_channels]) # self-loop embedding + + self.drop = nn.Dropout(drop_rate) + self.bn = torch.nn.BatchNorm1d(out_channels) + self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None + if num_base > 0: + self.rel_wt = self.get_param([num_rel * 2, num_base]) + else: + self.rel_wt = None + + self.wni = wni + self.wsi = wsi + + def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + def message_func(self, edges): + edge_type = edges.data['type'] # [E, 1] + edge_num = edge_type.shape[0] + edge_data = self.comp( + edges.src['h'], self.rel[edge_type]) # [E, in_channel] + # NOTE: first half edges are all in-directions, last half edges are out-directions. 
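+ # Each half is therefore transformed with its own direction-specific weight
+ # (in_w for the original edges, out_w for the reversed ones); edge_data above is
+ # the composition phi(h_src, h_rel) selected by self.opn ('mult', 'sub' or 'corr').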
+ msg = torch.cat([torch.matmul(edge_data[:edge_num // 2, :], self.in_w), + torch.matmul(edge_data[edge_num // 2:, :], self.out_w)]) + msg = msg * edges.data['norm'].reshape(-1, 1) # [E, D] * [E, 1] + return {'msg': msg} + + def reduce_func(self, nodes): + return {'h': self.drop(nodes.data['h'])} + + def comp(self, h, edge_data): + def com_mult(a, b): + r1, i1 = a[..., 0], a[..., 1] + r2, i2 = b[..., 0], b[..., 1] + return torch.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1) + + def conj(a): + a[..., 1] = -a[..., 1] + return a + + def ccorr(a, b): + return torch.irfft(com_mult(conj(torch.rfft(a, 1)), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],)) + + if self.opn == 'mult': + return h * edge_data + elif self.opn == 'sub': + return h - edge_data + elif self.opn == 'corr': + return ccorr(h, edge_data.expand_as(h)) + else: + raise KeyError(f'composition operator {self.opn} not recognized.') + + def forward(self, g: dgl.DGLGraph, x, rel_repr, edge_type, edge_norm): + """ + :param g: dgl Graph, a graph without self-loop + :param x: input node features, [V, in_channel] + :param rel_repr: input relation features: 1. not using bases: [num_rel*2, in_channel] + 2. using bases: [num_base, in_channel] + :param edge_type: edge type, [E] + :param edge_norm: edge normalization, [E] + :return: x: output node features: [V, out_channel] + rel: output relation features: [num_rel*2, out_channel] + """ + self.device = x.device + g = g.local_var() + g.ndata['h'] = x + g.edata['type'] = edge_type + g.edata['norm'] = edge_norm + if self.rel_wt is None: + self.rel = rel_repr + else: + # [num_rel*2, num_base] @ [num_base, in_c] + self.rel = torch.mm(self.rel_wt, rel_repr) + g.update_all(self.message_func, fn.sum( + msg='msg', out='h'), self.reduce_func) + + if (not self.wni) and (not self.wsi): + x = (g.ndata.pop('h') + + torch.mm(self.comp(x, self.loop_rel), self.loop_w)) / 3 + else: + if self.wsi: + x = g.ndata.pop('h') / 2 + if self.wni: + x = torch.mm(self.comp(x, self.loop_rel), self.loop_w) + + if self.bias is not None: + x = x + self.bias + + if self.use_bn: + x = self.bn(x) + + if self.ltr: + return self.act(x), torch.matmul(self.rel, self.w_rel) + else: + return self.act(x), self.rel diff --git a/openhgnn/layers/rgcn_layer.py b/openhgnn/layers/rgcn_layer.py new file mode 100644 index 00000000..95af258e --- /dev/null +++ b/openhgnn/layers/rgcn_layer.py @@ -0,0 +1,415 @@ +""" +based on the implementation in DGL +(https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/relgraphconv.py) +""" + + +"""Torch Module for Relational graph convolution layer""" +# pylint: disable= no-member, arguments-differ, invalid-name + + + + +import functools +import numpy as np +import torch as th +from torch import nn +import dgl.function as fn +from dgl.nn.pytorch import utils +from dgl.base import DGLError +from dgl.subgraph import edge_subgraph +class RelGraphConv(nn.Module): + r"""Relational graph convolution layer. + + Relational graph convolution is introduced in "`Modeling Relational Data with Graph + Convolutional Networks `__" + and can be described in DGL as below: + + .. math:: + + h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}} + \sum_{j\in\mathcal{N}^r(i)}e_{j,i}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)}) + + where :math:`\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation + :math:`r`. :math:`e_{j,i}` is the normalizer. :math:`\sigma` is an activation + function. :math:`W_0` is the self-loop weight. + + The basis regularization decomposes :math:`W_r` by: + + .. 
math:: + + W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)} + + where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined + with coefficients :math:`a_{rb}^{(l)}`. + + The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B` + number of block diagonal matrices. We refer :math:`B` as the number of bases. + + The block regularization decomposes :math:`W_r` by: + + .. math:: + + W_r^{(l)} = \oplus_{b=1}^B Q_{rb}^{(l)} + + where :math:`B` is the number of bases, :math:`Q_{rb}^{(l)}` are block + bases with shape :math:`R^{(d^{(l+1)}/B)*(d^{l}/B)}`. + + Parameters + ---------- + in_feat : int + Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`. + out_feat : int + Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`. + num_rels : int + Number of relations. . + regularizer : str + Which weight regularizer to use "basis" or "bdd". + "basis" is short for basis-diagonal-decomposition. + "bdd" is short for block-diagonal-decomposition. + num_bases : int, optional + Number of bases. If is none, use number of relations. Default: ``None``. + bias : bool, optional + True if bias is added. Default: ``True``. + activation : callable, optional + Activation function. Default: ``None``. + self_loop : bool, optional + True to include self loop message. Default: ``True``. + low_mem : bool, optional + True to use low memory implementation of relation message passing function. Default: False. + This option trades speed with memory consumption, and will slowdown the forward/backward. + Turn it on when you encounter OOM problem during training or evaluation. Default: ``False``. + dropout : float, optional + Dropout rate. Default: ``0.0`` + layer_norm: float, optional + Add layer norm. 
Default: ``False`` + + Examples + -------- + >>> import dgl + >>> import numpy as np + >>> import torch as th + >>> from dgl.nn import RelGraphConv + >>> + >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) + >>> feat = th.ones(6, 10) + >>> conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2) + >>> conv.weight.shape + torch.Size([2, 10, 2]) + >>> etype = th.tensor(np.array([0,1,2,0,1,2]).astype(np.int64)) + >>> res = conv(g, feat, etype) + >>> res + tensor([[ 0.3996, -2.3303], + [-0.4323, -0.1440], + [ 0.3996, -2.3303], + [ 2.1046, -2.8654], + [-0.4323, -0.1440], + [-0.1309, -1.0000]], grad_fn=) + + >>> # One-hot input + >>> one_hot_feat = th.tensor(np.array([0,1,2,3,4,5]).astype(np.int64)) + >>> res = conv(g, one_hot_feat, etype) + >>> res + tensor([[ 0.5925, 0.0985], + [-0.3953, 0.8408], + [-0.9819, 0.5284], + [-1.0085, -0.1721], + [ 0.5962, 1.2002], + [ 0.0365, -0.3532]], grad_fn=) + """ + + def __init__(self, + in_feat, + out_feat, + num_rels, + regularizer="basis", + num_bases=None, + bias=True, + activation=None, + self_loop=True, + low_mem=False, + dropout=0.0, + layer_norm=False, + wni=False): + super(RelGraphConv, self).__init__() + self.in_feat = in_feat + self.out_feat = out_feat + self.num_rels = num_rels + self.regularizer = regularizer + self.num_bases = num_bases + if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases <= 0: + self.num_bases = self.num_rels + self.bias = bias + self.activation = activation + self.self_loop = self_loop + self.low_mem = low_mem + self.layer_norm = layer_norm + + self.wni = wni + + if regularizer == "basis": + # add basis weights + self.weight = nn.Parameter( + th.Tensor(self.num_bases, self.in_feat, self.out_feat)) + if self.num_bases < self.num_rels: + # linear combination coefficients + self.w_comp = nn.Parameter( + th.Tensor(self.num_rels, self.num_bases)) + nn.init.xavier_uniform_( + self.weight, gain=nn.init.calculate_gain('relu')) + if self.num_bases < self.num_rels: + nn.init.xavier_uniform_(self.w_comp, + gain=nn.init.calculate_gain('relu')) + # message func + self.message_func = self.basis_message_func + elif regularizer == "bdd": + if in_feat % self.num_bases != 0 or out_feat % self.num_bases != 0: + raise ValueError( + 'Feature size must be a multiplier of num_bases (%d).' + % self.num_bases + ) + # add block diagonal weights + self.submat_in = in_feat // self.num_bases + self.submat_out = out_feat // self.num_bases + + # assuming in_feat and out_feat are both divisible by num_bases + self.weight = nn.Parameter(th.Tensor( + self.num_rels, self.num_bases * self.submat_in * self.submat_out)) + nn.init.xavier_uniform_( + self.weight, gain=nn.init.calculate_gain('relu')) + # message func + self.message_func = self.bdd_message_func + else: + raise ValueError("Regularizer must be either 'basis' or 'bdd'") + + # bias + if self.bias: + self.h_bias = nn.Parameter(th.Tensor(out_feat)) + nn.init.zeros_(self.h_bias) + + # layer norm + if self.layer_norm: + self.layer_norm_weight = nn.LayerNorm( + out_feat, elementwise_affine=True) + + # weight for self loop + if self.self_loop: + self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat)) + nn.init.xavier_uniform_(self.loop_weight, + gain=nn.init.calculate_gain('relu')) + + self.dropout = nn.Dropout(dropout) + + def basis_message_func(self, edges, etypes): + """Message function for basis regularizer. + + Parameters + ---------- + edges : dgl.EdgeBatch + Input to DGL message UDF. + etypes : torch.Tensor or list[int] + Edge type data. 
Could be either: + + * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID. + Preferred format if ``lowmem == False``. + * An integer list. The i^th element is the number of edges of the i^th type. + This requires the input graph to store edges sorted by their type IDs. + Preferred format if ``lowmem == True``. + """ + if self.num_bases < self.num_rels: + # generate all weights from bases + weight = self.weight.view(self.num_bases, + self.in_feat * self.out_feat) + weight = th.matmul(self.w_comp, weight).view( + self.num_rels, self.in_feat, self.out_feat) + else: + weight = self.weight + + h = edges.src['h'] + device = h.device + + if h.dtype == th.int64 and h.ndim == 1: + # Each element is the node's ID. Use index select: weight[etypes, h, :] + # The following is a faster version of it. + if isinstance(etypes, list): + etypes = th.repeat_interleave(th.arange(len(etypes), device=device), + th.tensor(etypes, device=device)) + idim = weight.shape[1] + weight = weight.view(-1, weight.shape[2]) + flatidx = etypes * idim + h + msg = weight.index_select(0, flatidx) + elif self.low_mem: + # A more memory-friendly implementation. + # Calculate msg @ W_r before put msg into edge. + assert isinstance(etypes, list) + h_t = th.split(h, etypes) + msg = [] + for etype in range(self.num_rels): + if h_t[etype].shape[0] == 0: + continue + msg.append(th.matmul(h_t[etype], weight[etype])) + msg = th.cat(msg) + else: + # Use batched matmult + if isinstance(etypes, list): + etypes = th.repeat_interleave(th.arange(len(etypes), device=device), + th.tensor(etypes, device=device)) + weight = weight.index_select(0, etypes) + msg = th.bmm(h.unsqueeze(1), weight).squeeze(1) + + if 'norm' in edges.data: + msg = msg * edges.data['norm'] + return {'msg': msg} + + def bdd_message_func(self, edges, etypes): + """Message function for block-diagonal-decomposition regularizer. + + Parameters + ---------- + edges : dgl.EdgeBatch + Input to DGL message UDF. + etypes : torch.Tensor or list[int] + Edge type data. Could be either: + + * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID. + Preferred format if ``lowmem == False``. + * An integer list. The i^th element is the number of edges of the i^th type. + This requires the input graph to store edges sorted by their type IDs. + Preferred format if ``lowmem == True``. + """ + h = edges.src['h'] + device = h.device + + if h.dtype == th.int64 and h.ndim == 1: + raise TypeError( + 'Block decomposition does not allow integer ID feature.') + + if self.low_mem: + # A more memory-friendly implementation. + # Calculate msg @ W_r before put msg into edge. 
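+ # The low-memory path expects `etypes` to be a per-type edge-count list over a
+ # type-sorted graph; messages are then computed one relation at a time with the
+ # block-diagonal einsum below instead of a single large batched matmul.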
+ assert isinstance(etypes, list) + h_t = th.split(h, etypes) + msg = [] + for etype in range(self.num_rels): + if h_t[etype].shape[0] == 0: + continue + tmp_w = self.weight[etype].view( + self.num_bases, self.submat_in, self.submat_out) + tmp_h = h_t[etype].view(-1, self.num_bases, self.submat_in) + msg.append(th.einsum('abc,bcd->abd', tmp_h, + tmp_w).reshape(-1, self.out_feat)) + msg = th.cat(msg) + else: + # Use batched matmult + if isinstance(etypes, list): + etypes = th.repeat_interleave(th.arange(len(etypes), device=device), + th.tensor(etypes, device=device)) + weight = self.weight.index_select(0, etypes).view( + -1, self.submat_in, self.submat_out) + node = h.view(-1, 1, self.submat_in) + msg = th.bmm(node, weight).view(-1, self.out_feat) + if 'norm' in edges.data: + msg = msg * edges.data['norm'] + return {'msg': msg} + + def forward(self, g, feat, etypes, norm=None): + """Forward computation. + + Parameters + ---------- + g : DGLGraph + The graph. + feat : torch.Tensor + Input node features. Could be either + + * :math:`(|V|, D)` dense tensor + * :math:`(|V|,)` int64 vector, representing the categorical values of each + node. It then treat the input feature as an one-hot encoding feature. + etypes : torch.Tensor or list[int] + Edge type data. Could be either + + * An :math:`(|E|,)` dense tensor. Each element corresponds to the edge's type ID. + Preferred format if ``lowmem == False``. + * An integer list. The i^th element is the number of edges of the i^th type. + This requires the input graph to store edges sorted by their type IDs. + Preferred format if ``lowmem == True``. + norm : torch.Tensor, optional + Edge normalizer. Could be either + + * An :math:`(|E|, 1)` tensor storing the normalizer on each edge. + + Returns + ------- + torch.Tensor + New node features. + + Notes + ----- + Under the ``low_mem`` mode, DGL will sort the graph based on the edge types + and compute message passing one type at a time. DGL recommends sorts the + graph beforehand (and cache it if possible) and provides the integer list + format to the ``etypes`` argument. Use DGL's :func:`~dgl.to_homogeneous` API + to get a sorted homogeneous graph from a heterogeneous graph. Pass ``return_count=True`` + to it to get the ``etypes`` in integer list. + """ + if isinstance(etypes, th.Tensor): + if len(etypes) != g.num_edges(): + raise DGLError('"etypes" tensor must have length equal to the number of edges' + ' in the graph. But got {} and {}.'.format( + len(etypes), g.num_edges())) + if self.low_mem and not (feat.dtype == th.int64 and feat.ndim == 1): + # Low-mem optimization is not enabled for node ID input. When enabled, + # it first sorts the graph based on the edge types (the sorting will not + # change the node IDs). It then converts the etypes tensor to an integer + # list, where each element is the number of edges of the type. + # Sort the graph based on the etypes + sorted_etypes, index = th.sort(etypes) + g = edge_subgraph(g, index, relabel_nodes=False) + # Create a new etypes to be an integer list of number of edges. 
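+ # Illustration: for sorted_etypes = [0, 0, 1, 2, 2, 2] and num_rels = 3,
+ # pos = [0, 2, 3], so etypes becomes [2, 1, 3], i.e. the number of edges of each
+ # type, which is the integer-list format the message functions expect.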
+ pos = _searchsorted(sorted_etypes, th.arange( + self.num_rels, device=g.device)) + num = th.tensor([len(etypes)], device=g.device) + etypes = (th.cat([pos[1:], num]) - pos).tolist() + if norm is not None: + norm = norm[index] + + with g.local_scope(): + g.srcdata['h'] = feat + if norm is not None: + g.edata['norm'] = norm + if self.self_loop: + loop_message = utils.matmul_maybe_select(feat[:g.number_of_dst_nodes()], + self.loop_weight) + + if not self.wni: + # message passing + g.update_all(functools.partial(self.message_func, etypes=etypes), + fn.sum(msg='msg', out='h')) + # apply bias and activation + node_repr = g.dstdata['h'] + if self.layer_norm: + node_repr = self.layer_norm_weight(node_repr) + if self.bias: + node_repr = node_repr + self.h_bias + else: + node_repr = 0 + + if self.self_loop: + node_repr = node_repr + loop_message + if self.activation: + node_repr = self.activation(node_repr) + node_repr = self.dropout(node_repr) + return node_repr + + +_TORCH_HAS_SEARCHSORTED = getattr(th, 'searchsorted', None) + + +def _searchsorted(sorted_sequence, values): + # searchsorted is introduced to PyTorch in 1.6.0 + if _TORCH_HAS_SEARCHSORTED: + return th.searchsorted(sorted_sequence, values) + else: + device = values.device + return th.from_numpy(np.searchsorted(sorted_sequence.cpu().numpy(), + values.cpu().numpy())).to(device) diff --git a/openhgnn/models/AdapropI.py b/openhgnn/models/AdapropI.py new file mode 100644 index 00000000..7afc8cbe --- /dev/null +++ b/openhgnn/models/AdapropI.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn +import numpy as np +import time +from torch.optim import Adam +from torch.optim.lr_scheduler import ExponentialLR +from ..layers.AdapropI import GNNModel +from ..utils.AdapropI_utils import * +from tqdm import tqdm +from . 
import BaseModel, register_model + +@register_model('AdapropI') +class AdapropI(BaseModel): + @classmethod + def build_model_from_args(cls, config,loader): + return cls(config,loader) + + def __init__(self, config,loader): + super().__init__() + self.model = AdapropI_Base(config,loader) + + def forward(self, *args): + return self.model(*args) + + def extra_loss(self): + pass +class AdapropI_Base(object): + def __init__(self, args, loader): + self.model = GNNModel(args, loader) + self.model.cuda() + self.loader = loader + self.n_ent = loader.n_ent + self.n_ent_ind = loader.n_ent_ind + self.n_batch = args.n_batch + self.n_train = loader.n_train + self.n_valid = loader.n_valid + self.n_test = loader.n_test + self.n_layer = args.n_layer + self.optimizer = Adam(self.model.parameters(), lr=args.lr, weight_decay=args.lamb) + self.scheduler = ExponentialLR(self.optimizer, args.decay_rate) + self.smooth = 1e-5 + self.params = args + + def train_batch(self, ): + epoch_loss = 0 + i = 0 + batch_size = self.n_batch + n_batch = self.n_train // batch_size + (self.n_train % batch_size > 0) + self.model.train() + self.time_1 = 0 + self.time_2 = 0 + + for i in range(n_batch): + start = i * batch_size + end = min(self.n_train, (i + 1) * batch_size) + batch_idx = np.arange(start, end) + triple = self.loader.get_batch(batch_idx) + + self.model.zero_grad() + scores = self.model(triple[:, 0], triple[:, 1]) + pos_scores = scores[[torch.arange(len(scores)).cuda(), torch.LongTensor(triple[:, 2]).cuda()]] + self.time_1 += self.model.time_1 + self.time_2 += self.model.time_2 + t_2 = time.time() + + max_n = torch.max(scores, 1, keepdim=True)[0] + loss = torch.sum(- pos_scores + max_n + torch.log(torch.sum(torch.exp(scores - max_n), 1))) + + loss.backward() + self.optimizer.step() + self.time_2 += time.time() - t_2 + + for p in self.model.parameters(): + X = p.data.clone() + flag = X != X + X[flag] = np.random.random() + p.data.copy_(X) + epoch_loss += loss.item() + + self.loader.shuffle_train() + self.scheduler.step() + valid_mrr, test_mrr, out_str = self.evaluate() + return valid_mrr, test_mrr, out_str + + def evaluate(self, ): + batch_size = self.n_batch + n_data = self.n_valid + n_batch = n_data // batch_size + (n_data % batch_size > 0) + ranking = [] + masks = [] + self.model.eval() + time_3 = time.time() + for i in range(n_batch): + start = i * batch_size + end = min(n_data, (i + 1) * batch_size) + batch_idx = np.arange(start, end) + subs, rels, objs = self.loader.get_batch(batch_idx, data='valid') + scores = self.model(subs, rels).data.cpu().numpy() + filters = [] + for i in range(len(subs)): + filt = self.loader.val_filters[(subs[i], rels[i])] + filt_1hot = np.zeros((self.n_ent,)) + filt_1hot[np.array(filt)] = 1 + filters.append(filt_1hot) + + masks += [self.n_ent - len(filt)] * int(objs[i].sum()) + + filters = np.array(filters) + ranks = cal_ranks(scores, objs, filters) + ranking += ranks + + ranking = np.array(ranking) + v_mrr, v_mr, v_h1, v_h3, v_h10, v_h1050 = cal_performance(ranking, masks) + + n_data = self.n_test + n_batch = n_data // batch_size + (n_data % batch_size > 0) + ranking = [] + masks = [] + self.model.eval() + for i in range(n_batch): + start = i * batch_size + end = min(n_data, (i + 1) * batch_size) + batch_idx = np.arange(start, end) + subs, rels, objs = self.loader.get_batch(batch_idx, data='test') + scores = self.model(subs, rels, 'inductive').data.cpu().numpy() + filters = [] + for i in range(len(subs)): + filt = self.loader.tst_filters[(subs[i], rels[i])] + filt_1hot = 
np.zeros((self.n_ent_ind,)) + filt_1hot[np.array(filt)] = 1 + filters.append(filt_1hot) + masks += [self.n_ent_ind - len(filt)] * int(objs[i].sum()) + + filters = np.array(filters) + ranks = cal_ranks(scores, objs, filters) + ranking += ranks + ranking = np.array(ranking) + t_mrr, t_mr, t_h1, t_h3, t_h10, t_h1050 = cal_performance(ranking, masks) + time_3 = time.time() - time_3 + + out_str = '%.4f %.4f %.4f\t%.4f %.1f %.4f %.4f %.4f %.4f\t\t%.4f %.1f %.4f %.4f %.4f %.4f\n' % ( + self.time_1, self.time_2, time_3, v_mrr, v_mr, v_h1, v_h3, v_h10, v_h1050, t_mrr, t_mr, t_h1, t_h3, t_h10, + t_h1050) + return v_h10, t_h10, out_str diff --git a/openhgnn/models/AdapropT.py b/openhgnn/models/AdapropT.py new file mode 100644 index 00000000..7cd3f04f --- /dev/null +++ b/openhgnn/models/AdapropT.py @@ -0,0 +1,209 @@ +import torch +import numpy as np +import time +import os +from torch.optim import Adam +from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau +from ..layers.AdapropT import GNNModel +from ..utils.Adaprop_utils import * +from tqdm import tqdm +from . import BaseModel, register_model + +@register_model('AdapropT') +class AdapropT(BaseModel): + @classmethod + def build_model_from_args(cls, config,loader): + return cls(config,loader) + + def __init__(self, config,loader): + super().__init__() + self.model = AdapropT_Base(config,loader) + + def forward(self, *args): + return self.model(*args) + + def extra_loss(self): + pass + +class AdapropT_Base(object): + def __init__(self, args, loader): + self.model = GNNModel(args, loader) + self.model.cuda() + self.loader = loader + self.n_ent = loader.n_ent + self.n_rel = loader.n_rel + self.n_batch = args.n_batch + self.n_tbatch = args.n_tbatch + self.n_train = loader.n_train + self.n_valid = loader.n_valid + self.n_test = loader.n_test + self.n_layer = args.n_layer + self.args = args + self.optimizer = Adam(self.model.parameters(), lr=args.lr, weight_decay=args.lamb) + + if self.args.scheduler == 'exp': + self.scheduler = ExponentialLR(self.optimizer, args.decay_rate) + else: + raise NotImplementedError(f'==> [Error] {self.scheduler} scheduler is not supported yet.') + + self.t_time = 0 + self.lastSaveGNNPath = None + self.modelName = f'{args.n_layer}-layers' + for i in range(args.n_layer): + i_n_node_topk = args.n_node_topk if 'int' in str(type(args.n_node_topk)) else args.n_node_topk[i] + self.modelName += f'-{i_n_node_topk}' + print(f'==> model name: {self.modelName}') + + def _update(self): + self.optimizer = Adam(self.model.parameters(), lr=self.args.lr, weight_decay=self.args.lamb) + + def saveModelToFiles(self, best_metric, deleteLastFile=True): + savePath = f'{self.loader.task_dir}/saveModel/{self.modelName}-{best_metric}.pt' + print(f'Save checkpoint to : {savePath}') + torch.save({ + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'best_mrr': best_metric, + }, savePath) + + if deleteLastFile and self.lastSaveGNNPath != None: + print(f'Remove last checkpoint: {self.lastSaveGNNPath}') + os.remove(self.lastSaveGNNPath) + + self.lastSaveGNNPath = savePath + + def loadModel(self, filePath, layers=-1): + print(f'Load weight from {filePath}') + assert os.path.exists(filePath) + checkpoint = torch.load(filePath, map_location=torch.device(f'cuda:{self.args.gpu}')) + + if layers != -1: + extra_layers = self.model.gnn_layers[layers:] + self.model.gnn_layers = self.model.gnn_layers[:layers] + self.model.load_state_dict(checkpoint['model_state_dict']) + self.model.gnn_layers += 
extra_layers + else: + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + def train_batch(self, ): + epoch_loss = 0 + i = 0 + batch_size = self.n_batch + n_batch = self.loader.n_train // batch_size + (self.loader.n_train % batch_size > 0) + t_time = time.time() + self.model.train() + + for i in tqdm(range(n_batch)): + start = i * batch_size + end = min(self.loader.n_train, (i + 1) * batch_size) + batch_idx = np.arange(start, end) + triple = self.loader.get_batch(batch_idx) + + self.model.zero_grad() + scores = self.model(triple[:, 0], triple[:, 1]) + pos_scores = scores[[torch.arange(len(scores)).cuda(), torch.LongTensor(triple[:, 2]).cuda()]] + + max_n = torch.max(scores, 1, keepdim=True)[0] + loss = torch.sum(- pos_scores + max_n + torch.log(torch.sum(torch.exp(scores - max_n), 1))) + loss.backward() + self.optimizer.step() + + # avoid NaN + for p in self.model.parameters(): + X = p.data.clone() + flag = X != X + X[flag] = np.random.random() + p.data.copy_(X) + epoch_loss += loss.item() + + self.t_time += time.time() - t_time + + if self.args.scheduler == 'exp': + self.scheduler.step() + + self.loader.shuffle_train() + + return + + def evaluate(self, verbose=True, eval_val=True, eval_test=False, recordDistance=False): + batch_size = self.n_tbatch + n_data = self.n_valid + n_batch = n_data // batch_size + (n_data % batch_size > 0) + ranking = [] + self.model.eval() + i_time = time.time() + + # - - - - - - val set - - - - - - + if not eval_val: + v_mrr, v_h1, v_h10 = 0, 0, 0 + else: + iterator = tqdm(range(n_batch)) if verbose else range(n_batch) + for i in iterator: + start = i * batch_size + end = min(n_data, (i + 1) * batch_size) + batch_idx = np.arange(start, end) + subs, rels, objs = self.loader.get_batch(batch_idx, data='valid') + scores = self.model(subs, rels, mode='valid') + scores = scores.data.cpu().numpy() + + filters = [] + for i in range(len(subs)): + filt = self.loader.filters[(subs[i], rels[i])] + filt_1hot = np.zeros((self.n_ent,)) + filt_1hot[np.array(filt)] = 1 + filters.append(filt_1hot) + + # scores / objs / filters: [batch_size, n_ent] + filters = np.array(filters) + ranks = cal_ranks(scores, objs, filters) + ranking += ranks + + ranking = np.array(ranking) + v_mrr, v_h1, v_h10 = cal_performance(ranking) + + # - - - - - - test set - - - - - - + if not eval_test: + t_mrr, t_h1, t_h10 = -1, -1, -1 + else: + n_data = self.n_test + n_batch = n_data // batch_size + (n_data % batch_size > 0) + ranking = [] + self.model.eval() + iterator = tqdm(range(n_batch)) if verbose else range(n_batch) + + for i in iterator: + start = i * batch_size + end = min(n_data, (i + 1) * batch_size) + batch_idx = np.arange(start, end) + subs, rels, objs = self.loader.get_batch(batch_idx, data='test') + scores = self.model(subs, rels, mode='test') + scores = scores.data.cpu().numpy() + + filters = [] + for i in range(len(subs)): + filt = self.loader.filters[(subs[i], rels[i])] + filt_1hot = np.zeros((self.n_ent,)) + filt_1hot[np.array(filt)] = 1 + filters.append(filt_1hot) + + filters = np.array(filters) + ranks = cal_ranks(scores, objs, filters) + ranking += ranks + + ranking = np.array(ranking) + t_mrr, t_h1, t_h10 = cal_performance(ranking) + + i_time = time.time() - i_time + out_str = '[VALID] MRR:%.4f H@1:%.4f H@10:%.4f\t [TEST] MRR:%.4f H@1:%.4f H@10:%.4f \t[TIME] train:%.4f inference:%.4f\n' % ( + v_mrr, v_h1, v_h10, t_mrr, t_h1, t_h10, self.t_time, i_time) + + result_dict = {} + result_dict['v_mrr'] = v_mrr + 
result_dict['v_h1'] = v_h1 + result_dict['v_h10'] = v_h10 + result_dict['t_mrr'] = t_mrr + result_dict['t_h1'] = t_h1 + result_dict['t_h10'] = t_h10 + + return result_dict, out_str \ No newline at end of file diff --git a/openhgnn/models/ComPILE.py b/openhgnn/models/ComPILE.py new file mode 100644 index 00000000..6eecd6ba --- /dev/null +++ b/openhgnn/models/ComPILE.py @@ -0,0 +1,647 @@ +import os + +import math +import numpy as np +import torch +import torch.nn as nn +import dgl.function as fn +import torch.nn.functional as F +from . import BaseModel, register_model +import torch.nn.functional as F +from torch.nn import Identity +from dgl import mean_nodes +import abc + +@register_model('ComPILE') +class ComPILE(BaseModel): + @classmethod + def build_model_from_args(cls, args, relation2id): + return cls(args,relation2id) + + def __init__(self, args, relation2id): + super(ComPILE, self).__init__() + super().__init__() + self.relation2id = relation2id + self.max_label_value = args.max_label_value + self.params = args + self.latent_dim = self.params.emb_dim + self.output_dim = 1 + self.node_emb = self.params.inp_dim + self.relation_emb = self.params.rel_emb_dim + self.edge_emb = self.node_emb * 2 + self.relation_emb + self.hidden_size = self.params.emb_dim + self.num_relation = self.params.num_rels + + self.final_relation_embeddings = nn.Parameter(torch.randn(self.params.num_rels, self.params.rel_emb_dim)) + self.relation_to_edge = nn.Linear(self.params.rel_emb_dim, self.hidden_size) + + # self.linear1 = nn.Linear(self.params.emb_dim + self.relation_emb + 2*self.params.emb_dim, 16) + self.linear1 = nn.Linear(self.params.emb_dim, 16) + self.linear2 = nn.Linear(16, 1) + + self.node_fdim = self.node_emb + self.edge_fdim = self.edge_emb + + self.bias = False + self.depth = 3 + self.dropout = 0.5 + self.layers_per_message = 1 + self.undirected = False + self.node_messages = False + self.args = args + # Dropout + self.dropout_layer = nn.Dropout(p=self.dropout) + + # Activation + self.act_func = get_activation_function('ReLU') + + # Cached zeros + self.cached_zero_vector = nn.Parameter(torch.zeros(self.hidden_size), requires_grad=False) + # Input + input_dim = self.node_fdim + self.W_i_node = nn.Linear(input_dim, self.hidden_size, bias=self.bias) + input_dim = self.edge_fdim + self.W_i_edge = nn.Linear(input_dim, self.hidden_size, bias=self.bias) + + w_h_input_size_node = self.hidden_size + self.edge_fdim + self.W_h_node = nn.Linear(w_h_input_size_node, self.hidden_size, bias=self.bias) + + self.input_attention1 = nn.Linear(self.hidden_size * 2, self.hidden_size, bias=self.bias) + self.input_attention2 = nn.Linear(self.hidden_size, 1, bias=self.bias) + + w_h_input_size_edge = self.hidden_size + for depth in range(self.depth - 1): + self._modules['W_h_edge_{}'.format(depth)] = nn.Linear(w_h_input_size_edge, self.hidden_size, + bias=self.bias) + # self._modules['W_h_edge_{}'.format(depth)] = nn.Linear(w_h_input_size_edge * 3 + self.params.rel_emb_dim, self.hidden_size, bias=self.bias) + self._modules['Attention1_{}'.format(depth)] = nn.Linear(self.hidden_size + self.relation_emb, + self.hidden_size, bias=self.bias) + self._modules['Attention2_{}'.format(depth)] = nn.Linear(self.hidden_size, 1, bias=self.bias) + + self.W_o = nn.Linear(self.hidden_size * 2, self.hidden_size) + + self.gru = BatchGRU(self.hidden_size) + + self.communicate_mlp = nn.Linear(self.hidden_size * 3, self.hidden_size, bias=self.bias) + + for depth in range(self.depth - 1): + self._modules['W_h_node_{}'.format(depth)] = 
nn.Linear(self.hidden_size, self.hidden_size, bias=self.bias) + + def forward(self, subgraph): + + target_relation = [] + for i in range(len(subgraph)): + graph = subgraph[i] + target = graph.edata['label'][-1].squeeze() + target_relation.append(self.final_relation_embeddings[target, :].unsqueeze(0)) + target_relation = torch.cat(target_relation, dim=0) + graph_embed, source_embed, target_embed = self.batch_subgraph(subgraph) + # print(graph_embed.shape, source_embed.shape, target_embed.shape, target_relation.shape) + # conv_input = torch.cat((source_embed, target_relation, target_embed, graph_embed), dim=1) + # conv_input = torch.cat([graph_embed, source_embed + target_relation -target_embed], dim=-1) + + # conv_input = (graph_embed) + torch.tanh(source_embed + target_relation -target_embed) + + conv_input = torch.tanh(source_embed + target_relation - target_embed) + out_conv = (self.linear1(conv_input)) + out_conv = self.linear2(out_conv) + return out_conv + + def batch_subgraph(self, subgraph): + + graph_sizes = []; + node_feat = [] + list_num_nodes = np.zeros((len(subgraph),), dtype=np.int32) + list_num_edges = np.zeros((len(subgraph),), dtype=np.int32) + node_count = 0; + edge_count = 0; + edge_feat = [] + total_edge = []; + source_node = []; + target_node = [] + total_target_relation = []; + total_edge2 = [] + total_source = []; + total_target = [] + for i in range(len(subgraph)): + graph = subgraph[i] + node_embedding = graph.ndata['feat'] + node_feat.append(node_embedding) + + graph_sizes.append(graph.number_of_nodes()) + list_num_nodes[i] = graph.number_of_nodes() + list_num_edges[i] = graph.number_of_edges() + + nodes = list((graph.nodes()).data.numpy()) + source = list((graph.edges()[0]).data.numpy()) + target = list((graph.edges()[1]).data.numpy()) + relation = graph.edata['type'] + relation_now = self.final_relation_embeddings[relation, :] + + target_relation = graph.edata['label'] + target_relation_now = self.final_relation_embeddings[target_relation, :] + total_target_relation.append(target_relation_now) + + mapping = dict(zip(nodes, [i for i in range(node_count, node_count + list_num_nodes[i])])) + + source_map_now = np.array([mapping[v] for v in source]) - node_count + target_map_now = np.array([mapping[v] for v in target]) - node_count + source_embed = node_embedding[source_map_now, :] + target_embed = node_embedding[target_map_now, :] + source_embed = source_embed.to(device=self.final_relation_embeddings.device) + target_embed = target_embed.to(device=self.final_relation_embeddings.device) + + edge_embed = torch.cat([source_embed, relation_now, target_embed], dim=1) + # edge_embed = source_embed + relation_now - target_embed + edge_feat.append(edge_embed) + + source_now = (graph.ndata['id'] == 1).nonzero().squeeze() + node_count + target_now = (graph.ndata['id'] == 2).nonzero().squeeze() + node_count + source_node.append(source_now) + target_node.append(target_now) + + target_now = target_now.unsqueeze(0).repeat(list_num_edges[i], 1).long() + source_now = source_now.unsqueeze(0).repeat(list_num_edges[i], 1).long() + total_source.append(source_now); + total_target.append(target_now) + + node_count += list_num_nodes[i] + + source_map = torch.LongTensor(np.array([mapping[v] for v in source])).unsqueeze(0) + target_map = torch.LongTensor(np.array([mapping[v] for v in target])).unsqueeze(0) + + edge_pair = torch.cat([target_map, torch.LongTensor( + np.array(range(edge_count, edge_count + list_num_edges[i]))).unsqueeze(0)], dim=0) + + edge_pair2 = torch.cat([source_map, 
torch.LongTensor( + np.array(range(edge_count, edge_count + list_num_edges[i]))).unsqueeze(0)], dim=0) + + edge_count += list_num_edges[i] + total_edge.append(edge_pair) + total_edge2.append(edge_pair2) + + source_node = np.array(source_node); + target_node = np.array(target_node) + + total_edge = torch.cat(total_edge, dim=1) + total_edge2 = torch.cat(total_edge2, dim=1) + total_target_relation = torch.cat(total_target_relation, dim=0) + total_source = torch.cat(total_source, dim=0) + total_target = torch.cat(total_target, dim=0) + + total_num_nodes = np.sum(list_num_nodes) + total_num_edges = np.sum(list_num_edges) + + e2n_value = torch.FloatTensor(torch.ones(total_edge.shape[1])) + e2n_sp = torch.sparse.FloatTensor(total_edge, e2n_value, torch.Size([total_num_nodes, total_num_edges])) + e2n_sp2 = torch.sparse.FloatTensor(total_edge2, e2n_value, torch.Size([total_num_nodes, total_num_edges])) + # e2n_sp = F.normalize(e2n_sp, dim=2, p=1) + + node_feat = torch.cat(node_feat, dim=0) + e2n_sp = e2n_sp.to(device=self.final_relation_embeddings.device) + e2n_sp2 = e2n_sp2.to(device=self.final_relation_embeddings.device) + node_feat = node_feat.to(device=self.final_relation_embeddings.device) + + edge_feat = torch.cat(edge_feat, dim=0) + graph_embed, source_embed, target_embed = self.gnn(node_feat, edge_feat, e2n_sp, e2n_sp2, graph_sizes, + total_target_relation, total_source, total_target, + source_node, target_node, list(list_num_edges)) + + return graph_embed, source_embed, target_embed + + def gnn(self, node_feat, edge_feat, e2n_sp, e2n_sp2, graph_sizes, target_relation, total_source, total_target, + source_node, target_node, edge_sizes=None, node_degs=None): + + input_node = self.W_i_node(node_feat) # num_nodes x hidden_size + input_node = self.act_func(input_node) + message_node = input_node.clone() + relation_embed = (edge_feat[:, self.node_emb: self.node_emb + self.relation_emb]) + + input_edge = self.W_i_edge(edge_feat) # num_edges x hidden_size + message_edge = self.act_func(input_edge) + input_edge = self.act_func(input_edge) + + graph_source_embed = message_node[total_source, :].squeeze(1) + graph_target_embed = message_node[total_target, :].squeeze(1) + graph_edge_embed = graph_source_embed + target_relation - graph_target_embed + edge_target_message = gnn_spmm(e2n_sp.t(), message_node) + edge_source_message = gnn_spmm(e2n_sp2.t(), message_node) + edge_message = edge_source_message + relation_embed - edge_target_message + # print(total_source.shape, total_target.shape, graph_source_embed.shape) + attention = torch.cat([graph_edge_embed, edge_message], dim=1) + attention = torch.relu(self.input_attention1(attention)) + attention = torch.sigmoid(self.input_attention2(attention)) + + # Message passing + for depth in range(self.depth - 1): + # agg_message = index_select_ND(message_edge, a2b) + # agg_message = agg_message.sum(dim=1) * agg_message.max(dim=1)[0] + # agg_message = gnn_spmm(e2n_sp, message_edge)/e2n_sp.sum(1, keepdim=True) + message_edge = (message_edge * attention) + agg_message = gnn_spmm(e2n_sp, message_edge) + message_node = message_node + agg_message + message_node = self.act_func(self._modules['W_h_node_{}'.format(depth)](message_node)) + + # directed graph + # rev_message = message_edge[b2revb] # num_edges x hidden + # message_edge = message_node[b2a] - rev_message # num_edges x hidden + edge_target_message = gnn_spmm(e2n_sp.t(), message_node) + edge_source_message = gnn_spmm(e2n_sp2.t(), message_node) + # message_edge = torch.cat([message_edge, edge_source_message, 
relation_embed, edge_target_message], dim=-1) + message_edge = torch.relu( + message_edge + torch.tanh(edge_source_message + relation_embed - edge_target_message)) + message_edge = self._modules['W_h_edge_{}'.format(depth)](message_edge) + message_edge = self.act_func(input_edge + message_edge) + message_edge = self.dropout_layer(message_edge) # num_edges x hidden + + graph_source_embed = message_node[total_source, :].squeeze(1) + graph_target_embed = message_node[total_target, :].squeeze(1) + graph_edge_embed = graph_source_embed + target_relation - graph_target_embed + edge_message = edge_source_message + relation_embed - edge_target_message + attention = torch.cat([graph_edge_embed, edge_message], dim=1) + attention = torch.relu(self._modules['Attention1_{}'.format(depth)](attention)) + attention = torch.sigmoid(self._modules['Attention2_{}'.format(depth)](attention)) + + # communicate + # agg_message = index_select_ND(message_edge, a2b) + # agg_message = agg_message.sum(dim=1) * agg_message.max(dim=1)[0] + + # agg_message = gnn_spmm(e2n_sp, message_edge)/e2n_sp.sum(1, keepdim=True) + message_edge = (message_edge * attention) + agg_message = gnn_spmm(e2n_sp, message_edge) + agg_message2 = self.communicate_mlp(torch.cat([agg_message, message_node, input_node], 1)) + # ============================================================================= + # + # ============================================================================= + # readout + + # node_hiddens = agg_message2 + + a_message = torch.relu(self.gru(agg_message2, graph_sizes)) + node_hiddens = self.act_func(self.W_o(a_message)) # num_nodes x hidden + node_hiddens = self.dropout_layer(node_hiddens) # num_nodes x hidden + + # Readout + mol_vecs = [] + a_start = 0 + for a_size in graph_sizes: + if a_size == 0: + assert 0 + cur_hiddens = node_hiddens.narrow(0, a_start, a_size) + mol_vecs.append(cur_hiddens.mean(0)) + a_start += a_size + mol_vecs = torch.stack(mol_vecs, dim=0) + + source_embed = node_hiddens[source_node, :] + target_embed = node_hiddens[target_node, :] + + # print(mol_vecs.shape, source_embed.shape, target_embed.shape) + + return mol_vecs, source_embed, target_embed + + +from torch.autograd import Variable + + +class MySpMM(torch.autograd.Function): + + @staticmethod + def forward(ctx, sp_mat, dense_mat): + ctx.save_for_backward(sp_mat, dense_mat) + + return torch.mm(sp_mat, dense_mat) + + @staticmethod + def backward(ctx, grad_output): + sp_mat, dense_mat = ctx.saved_variables + grad_matrix1 = grad_matrix2 = None + + assert not ctx.needs_input_grad[0] + if ctx.needs_input_grad[1]: + grad_matrix2 = Variable(torch.mm(sp_mat.data.t(), grad_output.data)) + + return grad_matrix1, grad_matrix2 + + +def gnn_spmm(sp_mat, dense_mat): + return MySpMM.apply(sp_mat, dense_mat) + + +def get_activation_function(activation): + """ + Gets an activation function module given the name of the activation. + + :param activation: The name of the activation function. + :return: The activation function module. 
+ """ + if activation == 'ReLU': + return nn.ReLU() + elif activation == 'LeakyReLU': + return nn.LeakyReLU(0.1) + elif activation == 'PReLU': + return nn.PReLU() + elif activation == 'tanh': + return nn.Tanh() + elif activation == 'SELU': + return nn.SELU() + elif activation == 'ELU': + return nn.ELU() + else: + raise ValueError('Activation "{}" not supported.'.format(activation)) + + +class BatchGRU(nn.Module): + def __init__(self, hidden_size=300): + super(BatchGRU, self).__init__() + self.hidden_size = hidden_size + self.gru = nn.GRU(self.hidden_size, self.hidden_size, batch_first=True, + bidirectional=True) + self.bias = nn.Parameter(torch.Tensor(self.hidden_size)) + self.bias.data.uniform_(-1.0 / math.sqrt(self.hidden_size), + 1.0 / math.sqrt(self.hidden_size)) + + def forward(self, node, a_scope): + hidden = node + # print(hidden.shape) + message = F.relu(node + self.bias) + MAX_node_len = max(a_scope) + # padding + message_lst = [] + hidden_lst = [] + a_start = 0 + for i in a_scope: + i = int(i) + if i == 0: + assert 0 + cur_message = message.narrow(0, a_start, i) + cur_hidden = hidden.narrow(0, a_start, i) + hidden_lst.append(cur_hidden.max(0)[0].unsqueeze(0).unsqueeze(0)) + a_start += i + cur_message = torch.nn.ZeroPad2d((0, 0, 0, MAX_node_len - cur_message.shape[0]))(cur_message) + message_lst.append(cur_message.unsqueeze(0)) + + message_lst = torch.cat(message_lst, 0) + hidden_lst = torch.cat(hidden_lst, 1) + hidden_lst = hidden_lst.repeat(2, 1, 1) + cur_message, cur_hidden = self.gru(message_lst, hidden_lst) + + # unpadding + cur_message_unpadding = [] + kk = 0 + for a_size in a_scope: + a_size = int(a_size) + cur_message_unpadding.append(cur_message[kk, :a_size].view(-1, 2 * self.hidden_size)) + kk += 1 + cur_message_unpadding = torch.cat(cur_message_unpadding, 0) + + # message = torch.cat([torch.cat([message.narrow(0, 0, 1), message.narrow(0, 0, 1)], 1), + # cur_message_unpadding], 0) + # print(cur_message_unpadding.shape) + return cur_message_unpadding + + + +class RGCN(nn.Module): + def __init__(self, params): + super(RGCN, self).__init__() + + self.max_label_value = params.max_label_value + self.inp_dim = params.inp_dim + self.emb_dim = params.emb_dim + self.attn_rel_emb_dim = params.attn_rel_emb_dim + self.num_rels = params.num_rels + self.aug_num_rels = params.aug_num_rels + self.num_bases = params.num_bases + self.num_hidden_layers = params.num_gcn_layers + self.dropout = params.dropout + self.edge_dropout = params.edge_dropout + # self.aggregator_type = params.gnn_agg_type + self.has_attn = params.has_attn + + self.device = params.device + + if self.has_attn: + self.attn_rel_emb = nn.Embedding(self.num_rels, self.attn_rel_emb_dim, sparse=False) + else: + self.attn_rel_emb = None + + if params.gnn_agg_type == "sum": + self.aggregator = SumAggregator(self.emb_dim) + elif params.gnn_agg_type == "mlp": + self.aggregator = MLPAggregator(self.emb_dim) + elif params.gnn_agg_type == "gru": + self.aggregator = GRUAggregator(self.emb_dim) + + self.layers = nn.ModuleList() + #input layer + self.layers.append(RGCNBasisLayer(self.inp_dim, + self.emb_dim, + # self.input_basis_weights, + self.aggregator, + self.attn_rel_emb_dim, + self.aug_num_rels, + self.num_bases, + activation=F.relu, + dropout=self.dropout, + edge_dropout=self.edge_dropout, + is_input_layer=True, + has_attn=self.has_attn)) + #hidden layer + for idx in range(self.num_hidden_layers - 1): + self.layers.append(RGCNBasisLayer(self.emb_dim, + self.emb_dim, + # self.basis_weights, + self.aggregator, + 
self.attn_rel_emb_dim,
+                                              self.aug_num_rels,
+                                              self.num_bases,
+                                              activation=F.relu,
+                                              dropout=self.dropout,
+                                              edge_dropout=self.edge_dropout,
+                                              has_attn=self.has_attn))
+
+
+    def forward(self, g):
+        for layer in self.layers:
+            layer(g, self.attn_rel_emb)
+        return g.ndata.pop('h')
+
+class RGCNLayer(nn.Module):
+    def __init__(self, inp_dim, out_dim, aggregator, bias=None, activation=None, dropout=0.0, edge_dropout=0.0, is_input_layer=False):
+        super(RGCNLayer, self).__init__()
+        self.bias = bias
+        self.activation = activation
+
+        if self.bias:
+            self.bias = nn.Parameter(torch.Tensor(out_dim))
+            nn.init.xavier_uniform_(self.bias,
+                                    gain=nn.init.calculate_gain('relu'))
+
+        self.aggregator = aggregator
+
+        if dropout:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+
+        if edge_dropout:
+            self.edge_dropout = nn.Dropout(edge_dropout)
+        else:
+            self.edge_dropout = Identity()  # note: Identity() is used here, which differs slightly from the original model
+
+    # define how propagation is done in subclass
+    def propagate(self, g):
+        raise NotImplementedError
+
+    def forward(self, g, attn_rel_emb=None):
+
+        self.propagate(g, attn_rel_emb)
+
+        # apply bias and activation
+        node_repr = g.ndata['h']
+        if self.bias:
+            node_repr = node_repr + self.bias
+        if self.activation:
+            node_repr = self.activation(node_repr)
+        if self.dropout:
+            node_repr = self.dropout(node_repr)
+
+        g.ndata['h'] = node_repr
+
+        if self.is_input_layer:
+            g.ndata['repr'] = g.ndata['h'].unsqueeze(1)
+        else:
+            g.ndata['repr'] = torch.cat([g.ndata['repr'], g.ndata['h'].unsqueeze(1)], dim=1)
+
+class RGCNBasisLayer(RGCNLayer):
+    def __init__(self, inp_dim, out_dim, aggregator, attn_rel_emb_dim, num_rels, num_bases=-1, bias=None,
+                 activation=None, dropout=0.0, edge_dropout=0.0, is_input_layer=False, has_attn=False):
+        super(
+            RGCNBasisLayer,
+            self).__init__(
+            inp_dim,
+            out_dim,
+            aggregator,
+            bias,
+            activation,
+            dropout=dropout,
+            edge_dropout=edge_dropout,
+            is_input_layer=is_input_layer)
+        self.inp_dim = inp_dim
+        self.out_dim = out_dim
+        self.attn_rel_emb_dim = attn_rel_emb_dim
+        self.num_rels = num_rels
+        self.num_bases = num_bases
+        self.is_input_layer = is_input_layer
+        self.has_attn = has_attn
+
+        if self.num_bases <= 0 or self.num_bases > self.num_rels:
+            self.num_bases = self.num_rels
+
+        # add basis weights
+        # self.weight = basis_weights
+        self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.inp_dim, self.out_dim))
+        self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases))
+
+        if self.has_attn:
+            self.A = nn.Linear(2 * self.inp_dim + 2 * self.attn_rel_emb_dim, inp_dim)
+            self.B = nn.Linear(inp_dim, 1)
+
+        self.self_loop_weight = nn.Parameter(torch.Tensor(self.inp_dim, self.out_dim))
+
+        nn.init.xavier_uniform_(self.self_loop_weight, gain=nn.init.calculate_gain('relu'))
+        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
+        nn.init.xavier_uniform_(self.w_comp, gain=nn.init.calculate_gain('relu'))
+
+    def propagate(self, g, attn_rel_emb=None):
+        # generate all weights from bases
+        #torch.cuda.init()
+        weight = self.weight.view(self.num_bases,
+                                  self.inp_dim * self.out_dim)
+        weight = torch.matmul(self.w_comp, weight).view(
+            self.num_rels, self.inp_dim, self.out_dim)
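+        # weight now holds one [inp_dim, out_dim] matrix per relation, reconstructed
+        # from the shared bases through the per-relation coefficients in self.w_comp
+        # (R-GCN basis decomposition); msg_func below indexes into it by edge type.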
+        g.edata['w'] = self.edge_dropout(torch.ones(g.number_of_edges(), 1).to(weight.device))
+
+        input_ = 'feat' if self.is_input_layer else 'h'
+
+        def msg_func(edges):
+            w = weight.index_select(0, edges.data['type'])
+            msg = edges.data['w'] * torch.bmm(edges.src[input_].unsqueeze(1), w).squeeze(1)
+            curr_emb = torch.mm(edges.dst[input_], self.self_loop_weight)  # (B, F)
+
+            if self.has_attn:
+                e = torch.cat([edges.src[input_], edges.dst[input_], attn_rel_emb(edges.data['type']), attn_rel_emb(edges.data['label'])], dim=1)
+                a = torch.sigmoid(self.B(F.relu(self.A(e))))
+            else:
+                a = torch.ones((len(edges), 1)).to(device=w.device)
+
+            return {'curr_emb': curr_emb, 'msg': msg, 'alpha': a}
+
+        g.update_all(msg_func, self.aggregator, None)
+
+
+class Aggregator(nn.Module):
+    def __init__(self, emb_dim):
+        super(Aggregator, self).__init__()
+
+    def forward(self, node):
+        curr_emb = node.mailbox['curr_emb'][:, 0, :]  # (B, F)
+        nei_msg = torch.bmm(node.mailbox['alpha'].transpose(1, 2), node.mailbox['msg']).squeeze(1)  # (B, F)
+        # nei_msg, _ = torch.max(node.mailbox['msg'], 1)  # (B, F)
+
+        new_emb = self.update_embedding(curr_emb, nei_msg)
+
+        return {'h': new_emb}
+
+    @abc.abstractmethod
+    def update_embedding(self, curr_emb, nei_msg):
+        raise NotImplementedError
+
+
+class SumAggregator(Aggregator):
+    def __init__(self, emb_dim):
+        super(SumAggregator, self).__init__(emb_dim)
+
+    def update_embedding(self, curr_emb, nei_msg):
+        new_emb = nei_msg + curr_emb
+
+        return new_emb
+
+
+class MLPAggregator(Aggregator):
+    def __init__(self, emb_dim):
+        super(MLPAggregator, self).__init__(emb_dim)
+        self.linear = nn.Linear(2 * emb_dim, emb_dim)
+
+    def update_embedding(self, curr_emb, nei_msg):
+        inp = torch.cat((nei_msg, curr_emb), 1)
+        new_emb = F.relu(self.linear(inp))
+
+        return new_emb
+
+
+class GRUAggregator(Aggregator):
+    def __init__(self, emb_dim):
+        super(GRUAggregator, self).__init__(emb_dim)
+        self.gru = nn.GRUCell(emb_dim, emb_dim)
+
+    def update_embedding(self, curr_emb, nei_msg):
+        new_emb = self.gru(nei_msg, curr_emb)
+
+        return new_emb
+
+
diff --git a/openhgnn/models/DisenKGAT.py b/openhgnn/models/DisenKGAT.py
new file mode 100644
index 00000000..4c08d3f9
--- /dev/null
+++ b/openhgnn/models/DisenKGAT.py
@@ -0,0 +1,1764 @@
+from . 
import BaseModel, register_model +import numpy as np +from torch import Tensor +import torch +from torch.nn import functional as F +from torch.nn.init import xavier_normal_ +from torch.nn import Parameter +import torch.nn as nn +np.set_printoptions(precision=4) +from textwrap import indent +from typing import Any, Dict, List, Optional, Tuple, Union,Any +import numpy as np +import scipy.sparse + + + + + + + + +def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + index = broadcast(index, src, dim) + if out is None: + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + else: + return out.scatter_add_(dim, index, src) + +def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return scatter_sum(src, index, dim, out, dim_size) + +def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size) + +def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + out = scatter_sum(src, index, dim, out, dim_size) + dim_size = out.size(dim) + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.dim() + if index.dim() <= index_dim: + index_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter_sum(ones, index, index_dim, None, dim_size) + count[count < 1] = 1 + count = broadcast(count, out, dim) + if out.is_floating_point(): + out.true_divide_(count) + else: + out.div_(count, rounding_mode='floor') + return out + +def scatter_min( + src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size) + +def scatter_max( + src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size) + +def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, + reduce: str = "sum") -> torch.Tensor: + + if reduce == 'sum' or reduce == 'add': + return scatter_sum(src, index, dim, out, dim_size) + if reduce == 'mul': + return scatter_mul(src, index, dim, out, dim_size) + elif reduce == 'mean': + return scatter_mean(src, index, dim, out, dim_size) + elif reduce == 'min': + return scatter_min(src, index, dim, out, dim_size)[0] + elif reduce == 'max': + return scatter_max(src, index, dim, out, dim_size)[0] + else: + raise ValueError + +def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand(other.size()) + return src + +def 
segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None) -> torch.Tensor: + return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out) + +def segment_add_csr(src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None) -> torch.Tensor: + return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out) + +def segment_mean_csr(src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None) -> torch.Tensor: + return torch.ops.torch_scatter.segment_mean_csr(src, indptr, out) + +def segment_min_csr( + src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.segment_min_csr(src, indptr, out) + +def segment_max_csr( + src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.segment_max_csr(src, indptr, out) + +def segment_csr(src: torch.Tensor, indptr: torch.Tensor, + out: Optional[torch.Tensor] = None, + reduce: str = "sum") -> torch.Tensor: + + if reduce == 'sum' or reduce == 'add': + return segment_sum_csr(src, indptr, out) + elif reduce == 'mean': + return segment_mean_csr(src, indptr, out) + elif reduce == 'min': + return segment_min_csr(src, indptr, out)[0] + elif reduce == 'max': + return segment_max_csr(src, indptr, out)[0] + else: + raise ValueError + + + + +def is_torch_sparse_tensor(src: Any) -> bool: + if isinstance(src, Tensor): + if src.layout == torch.sparse_coo: + return True + if src.layout == torch.sparse_csr: + return True + if src.layout == torch.sparse_csc: + return True + return False + + + +from torch_sparse.storage import SparseStorage, get_layout +@torch.jit.script +class SparseTensor(object): + storage: SparseStorage + + def __init__( + self, + row: Optional[torch.Tensor] = None, + rowptr: Optional[torch.Tensor] = None, + col: Optional[torch.Tensor] = None, + value: Optional[torch.Tensor] = None, + sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None, + is_sorted: bool = False, + trust_data: bool = False, + ): + self.storage = SparseStorage( + row=row, + rowptr=rowptr, + col=col, + value=value, + sparse_sizes=sparse_sizes, + rowcount=None, + colptr=None, + colcount=None, + csr2csc=None, + csc2csr=None, + is_sorted=is_sorted, + trust_data=trust_data, + ) + + @classmethod + def from_storage(self, storage: SparseStorage): + out = SparseTensor( + row=storage._row, + rowptr=storage._rowptr, + col=storage._col, + value=storage._value, + sparse_sizes=storage._sparse_sizes, + is_sorted=True, + trust_data=True, + ) + out.storage._rowcount = storage._rowcount + out.storage._colptr = storage._colptr + out.storage._colcount = storage._colcount + out.storage._csr2csc = storage._csr2csc + out.storage._csc2csr = storage._csc2csr + return out + + @classmethod + def from_edge_index( + self, + edge_index: torch.Tensor, + edge_attr: Optional[torch.Tensor] = None, + sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None, + is_sorted: bool = False, + trust_data: bool = False, + ): + return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1], + value=edge_attr, sparse_sizes=sparse_sizes, + is_sorted=is_sorted, trust_data=trust_data) + + @classmethod + def from_dense(self, mat: torch.Tensor, has_value: bool = True): + if mat.dim() > 2: + index = mat.abs().sum([i for i in range(2, mat.dim())]).nonzero() + else: + index = mat.nonzero() + index = index.t() + + row = index[0] + col = index[1] + + 
value: Optional[torch.Tensor] = None + if has_value: + value = mat[row, col] + + return SparseTensor(row=row, rowptr=None, col=col, value=value, + sparse_sizes=(mat.size(0), mat.size(1)), + is_sorted=True, trust_data=True) + + @classmethod + def from_torch_sparse_coo_tensor(self, mat: torch.Tensor, + has_value: bool = True): + mat = mat.coalesce() + index = mat._indices() + row, col = index[0], index[1] + + value: Optional[torch.Tensor] = None + if has_value: + value = mat.values() + + return SparseTensor(row=row, rowptr=None, col=col, value=value, + sparse_sizes=(mat.size(0), mat.size(1)), + is_sorted=True, trust_data=True) + + @classmethod + def from_torch_sparse_csr_tensor(self, mat: torch.Tensor, + has_value: bool = True): + rowptr = mat.crow_indices() + col = mat.col_indices() + + value: Optional[torch.Tensor] = None + if has_value: + value = mat.values() + + return SparseTensor(row=None, rowptr=rowptr, col=col, value=value, + sparse_sizes=(mat.size(0), mat.size(1)), + is_sorted=True, trust_data=True) + + @classmethod + def eye(self, M: int, N: Optional[int] = None, has_value: bool = True, + dtype: Optional[int] = None, device: Optional[torch.device] = None, + fill_cache: bool = False): + + N = M if N is None else N + + row = torch.arange(min(M, N), device=device) + col = row + + rowptr = torch.arange(M + 1, device=row.device) + if M > N: + rowptr[N + 1:] = N + + value: Optional[torch.Tensor] = None + if has_value: + value = torch.ones(row.numel(), dtype=dtype, device=row.device) + + rowcount: Optional[torch.Tensor] = None + colptr: Optional[torch.Tensor] = None + colcount: Optional[torch.Tensor] = None + csr2csc: Optional[torch.Tensor] = None + csc2csr: Optional[torch.Tensor] = None + + if fill_cache: + rowcount = torch.ones(M, dtype=torch.long, device=row.device) + if M > N: + rowcount[N:] = 0 + + colptr = torch.arange(N + 1, dtype=torch.long, device=row.device) + colcount = torch.ones(N, dtype=torch.long, device=row.device) + if N > M: + colptr[M + 1:] = M + colcount[M:] = 0 + csr2csc = csc2csr = row + + out = SparseTensor( + row=row, + rowptr=rowptr, + col=col, + value=value, + sparse_sizes=(M, N), + is_sorted=True, + trust_data=True, + ) + out.storage._rowcount = rowcount + out.storage._colptr = colptr + out.storage._colcount = colcount + out.storage._csr2csc = csr2csc + out.storage._csc2csr = csc2csr + return out + + def copy(self): + return self.from_storage(self.storage) + + def clone(self): + return self.from_storage(self.storage.clone()) + + def type(self, dtype: torch.dtype, non_blocking: bool = False): + value = self.storage.value() + if value is None or dtype == value.dtype: + return self + return self.from_storage( + self.storage.type(dtype=dtype, non_blocking=non_blocking)) + + def type_as(self, tensor: torch.Tensor, non_blocking: bool = False): + return self.type(dtype=tensor.dtype, non_blocking=non_blocking) + + def to_device(self, device: torch.device, non_blocking: bool = False): + if device == self.device(): + return self + return self.from_storage( + self.storage.to_device(device=device, non_blocking=non_blocking)) + + def device_as(self, tensor: torch.Tensor, non_blocking: bool = False): + return self.to_device(device=tensor.device, non_blocking=non_blocking) + + # Formats ################################################################# + + def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return self.storage.row(), self.storage.col(), self.storage.value() + + def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: 
+ return self.storage.rowptr(), self.storage.col(), self.storage.value() + + def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + perm = self.storage.csr2csc() + value = self.storage.value() + if value is not None: + value = value[perm] + return self.storage.colptr(), self.storage.row()[perm], value + + # Storage inheritance ##################################################### + + def has_value(self) -> bool: + return self.storage.has_value() + + def set_value_(self, value: Optional[torch.Tensor], + layout: Optional[str] = None): + self.storage.set_value_(value, layout) + return self + + def set_value(self, value: Optional[torch.Tensor], + layout: Optional[str] = None): + return self.from_storage(self.storage.set_value(value, layout)) + + def sparse_sizes(self) -> Tuple[int, int]: + return self.storage.sparse_sizes() + + def sparse_size(self, dim: int) -> int: + return self.storage.sparse_sizes()[dim] + + def sparse_resize(self, sparse_sizes: Tuple[int, int]): + return self.from_storage(self.storage.sparse_resize(sparse_sizes)) + + def sparse_reshape(self, num_rows: int, num_cols: int): + return self.from_storage( + self.storage.sparse_reshape(num_rows, num_cols)) + + def is_coalesced(self) -> bool: + return self.storage.is_coalesced() + + def coalesce(self, reduce: str = "sum"): + return self.from_storage(self.storage.coalesce(reduce)) + + def fill_cache_(self): + self.storage.fill_cache_() + return self + + def clear_cache_(self): + self.storage.clear_cache_() + return self + + def __eq__(self, other) -> bool: + if not isinstance(other, self.__class__): + return False + + if self.sizes() != other.sizes(): + return False + + rowptrA, colA, valueA = self.csr() + rowptrB, colB, valueB = other.csr() + + if valueA is None and valueB is not None: + return False + if valueA is not None and valueB is None: + return False + if not torch.equal(rowptrA, rowptrB): + return False + if not torch.equal(colA, colB): + return False + if valueA is None and valueB is None: + return True + return torch.equal(valueA, valueB) + + # Utility functions ####################################################### + + def fill_value_(self, fill_value: float, dtype: Optional[int] = None): + value = torch.full((self.nnz(), ), fill_value, dtype=dtype, + device=self.device()) + return self.set_value_(value, layout='coo') + + def fill_value(self, fill_value: float, dtype: Optional[int] = None): + value = torch.full((self.nnz(), ), fill_value, dtype=dtype, + device=self.device()) + return self.set_value(value, layout='coo') + + def sizes(self) -> List[int]: + sparse_sizes = self.sparse_sizes() + value = self.storage.value() + if value is not None: + return list(sparse_sizes) + list(value.size())[1:] + else: + return list(sparse_sizes) + + def size(self, dim: int) -> int: + return self.sizes()[dim] + + def dim(self) -> int: + return len(self.sizes()) + + def nnz(self) -> int: + return self.storage.col().numel() + + def numel(self) -> int: + value = self.storage.value() + if value is not None: + return value.numel() + else: + return self.nnz() + + def density(self) -> float: + return self.nnz() / (self.sparse_size(0) * self.sparse_size(1)) + + def sparsity(self) -> float: + return 1 - self.density() + + def avg_row_length(self) -> float: + return self.nnz() / self.sparse_size(0) + + def avg_col_length(self) -> float: + return self.nnz() / self.sparse_size(1) + + def bandwidth(self) -> int: + row, col, _ = self.coo() + return int((row - col).abs_().max()) + + def avg_bandwidth(self) -> float: + row, 
col, _ = self.coo() + return float((row - col).abs_().to(torch.float).mean()) + + def bandwidth_proportion(self, bandwidth: int) -> float: + row, col, _ = self.coo() + tmp = (row - col).abs_() + return int((tmp <= bandwidth).sum()) / self.nnz() + + def is_quadratic(self) -> bool: + return self.sparse_size(0) == self.sparse_size(1) + + def is_symmetric(self) -> bool: + if not self.is_quadratic(): + return False + + rowptr, col, value1 = self.csr() + colptr, row, value2 = self.csc() + + if (rowptr != colptr).any() or (col != row).any(): + return False + + if value1 is None or value2 is None: + return True + else: + return bool((value1 == value2).all()) + + def to_symmetric(self, reduce: str = "sum"): + N = max(self.size(0), self.size(1)) + + row, col, value = self.coo() + idx = col.new_full((2 * col.numel() + 1, ), -1) + idx[1:row.numel() + 1] = row + idx[row.numel() + 1:] = col + idx[1:] *= N + idx[1:row.numel() + 1] += col + idx[row.numel() + 1:] += row + + idx, perm = idx.sort() + mask = idx[1:] > idx[:-1] + perm = perm[1:].sub_(1) + idx = perm[mask] + + if value is not None: + ptr = mask.nonzero().flatten() + ptr = torch.cat([ptr, ptr.new_full((1, ), perm.size(0))]) + value = torch.cat([value, value])[perm] + value = segment_csr(value, ptr, reduce=reduce) + + new_row = torch.cat([row, col], dim=0, out=perm)[idx] + new_col = torch.cat([col, row], dim=0, out=perm)[idx] + + out = SparseTensor( + row=new_row, + rowptr=None, + col=new_col, + value=value, + sparse_sizes=(N, N), + is_sorted=True, + trust_data=True, + ) + return out + + def detach_(self): + value = self.storage.value() + if value is not None: + value.detach_() + return self + + def detach(self): + value = self.storage.value() + if value is not None: + value = value.detach() + return self.set_value(value, layout='coo') + + def requires_grad(self) -> bool: + value = self.storage.value() + if value is not None: + return value.requires_grad + else: + return False + + def requires_grad_(self, requires_grad: bool = True, + dtype: Optional[int] = None): + if requires_grad and not self.has_value(): + self.fill_value_(1., dtype) + + value = self.storage.value() + if value is not None: + value.requires_grad_(requires_grad) + return self + + def pin_memory(self): + return self.from_storage(self.storage.pin_memory()) + + def is_pinned(self) -> bool: + return self.storage.is_pinned() + + def device(self): + return self.storage.col().device + + def cpu(self): + return self.to_device(device=torch.device('cpu'), non_blocking=False) + + def cuda(self): + return self.from_storage(self.storage.cuda()) + + def is_cuda(self) -> bool: + return self.storage.col().is_cuda + + def dtype(self): + value = self.storage.value() + return value.dtype if value is not None else torch.float + + def is_floating_point(self) -> bool: + value = self.storage.value() + return torch.is_floating_point(value) if value is not None else True + + def bfloat16(self): + return self.type(dtype=torch.bfloat16, non_blocking=False) + + def bool(self): + return self.type(dtype=torch.bool, non_blocking=False) + + def byte(self): + return self.type(dtype=torch.uint8, non_blocking=False) + + def char(self): + return self.type(dtype=torch.int8, non_blocking=False) + + def half(self): + return self.type(dtype=torch.half, non_blocking=False) + + def float(self): + return self.type(dtype=torch.float, non_blocking=False) + + def double(self): + return self.type(dtype=torch.double, non_blocking=False) + + def short(self): + return self.type(dtype=torch.short, non_blocking=False) + + def 
int(self): + return self.type(dtype=torch.int, non_blocking=False) + + def long(self): + return self.type(dtype=torch.long, non_blocking=False) + + # Conversions ############################################################# + + def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor: + row, col, value = self.coo() + + if value is not None: + mat = torch.zeros(self.sizes(), dtype=value.dtype, + device=self.device()) + else: + mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device()) + + if value is not None: + mat[row, col] = value + else: + mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype, + device=mat.device) + + return mat + + def to_torch_sparse_coo_tensor( + self, dtype: Optional[int] = None) -> torch.Tensor: + row, col, value = self.coo() + index = torch.stack([row, col], dim=0) + + if value is None: + value = torch.ones(self.nnz(), dtype=dtype, device=self.device()) + + return torch.sparse_coo_tensor(index, value, self.sizes()) + + def to_torch_sparse_csr_tensor( + self, dtype: Optional[int] = None) -> torch.Tensor: + rowptr, col, value = self.csr() + + if value is None: + value = torch.ones(self.nnz(), dtype=dtype, device=self.device()) + + return torch.sparse_csr_tensor(rowptr, col, value, self.sizes()) + + def to_torch_sparse_csc_tensor( + self, dtype: Optional[int] = None) -> torch.Tensor: + colptr, row, value = self.csc() + + if value is None: + value = torch.ones(self.nnz(), dtype=dtype, device=self.device()) + + return torch.sparse_csc_tensor(colptr, row, value, self.sizes()) + +# Python Bindings ############################################################# + +def share_memory_(self: SparseTensor) -> SparseTensor: + self.storage.share_memory_() + return self + + +def is_shared(self: SparseTensor) -> bool: + return self.storage.is_shared() + + +def to(self, *args: Optional[List[Any]], + **kwargs: Optional[Dict[str, Any]]) -> SparseTensor: + device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3] + + if dtype is not None: + self = self.type(dtype=dtype, non_blocking=non_blocking) + if device is not None: + self = self.to_device(device=device, non_blocking=non_blocking) + + return self + + +def cpu(self) -> SparseTensor: + return self.device_as(torch.tensor(0., device='cpu')) + + +def cuda(self, device: Optional[Union[int, str]] = None, + non_blocking: bool = False): + return self.device_as(torch.tensor(0., device=device or 'cuda')) + + +def __getitem__(self: SparseTensor, index: Any) -> SparseTensor: + index = list(index) if isinstance(index, tuple) else [index] + # More than one `Ellipsis` is not allowed... + if len([ + i for i in index + if not isinstance(i, (torch.Tensor, np.ndarray)) and i == ... 
+ ]) > 1: + raise SyntaxError + + dim = 0 + out = self + while len(index) > 0: + item = index.pop(0) + if isinstance(item, (list, tuple)): + item = torch.tensor(item, device=self.device()) + if isinstance(item, np.ndarray): + item = torch.from_numpy(item).to(self.device()) + + if isinstance(item, int): + out = out.select(dim, item) + dim += 1 + elif isinstance(item, slice): + if item.step is not None: + raise ValueError('Step parameter not yet supported.') + + start = 0 if item.start is None else item.start + start = self.size(dim) + start if start < 0 else start + + stop = self.size(dim) if item.stop is None else item.stop + stop = self.size(dim) + stop if stop < 0 else stop + + out = out.narrow(dim, start, max(stop - start, 0)) + dim += 1 + elif torch.is_tensor(item): + if item.dtype == torch.bool: + out = out.masked_select(dim, item) + dim += 1 + elif item.dtype == torch.long: + out = out.index_select(dim, item) + dim += 1 + elif item == Ellipsis: + if self.dim() - len(index) < dim: + raise SyntaxError + dim = self.dim() - len(index) + else: + raise SyntaxError + + return out + + +def __repr__(self: SparseTensor) -> str: + i = ' ' * 6 + row, col, value = self.coo() + infos = [] + infos += [f'row={indent(row.__repr__(), i)[len(i):]}'] + infos += [f'col={indent(col.__repr__(), i)[len(i):]}'] + + if value is not None: + infos += [f'val={indent(value.__repr__(), i)[len(i):]}'] + + infos += [ + f'size={tuple(self.sizes())}, nnz={self.nnz()}, ' + f'density={100 * self.density():.02f}%' + ] + + infos = ',\n'.join(infos) + + i = ' ' * (len(self.__class__.__name__) + 1) + return f'{self.__class__.__name__}({indent(infos, i)[len(i):]})' + + +SparseTensor.share_memory_ = share_memory_ +SparseTensor.is_shared = is_shared +SparseTensor.to = to +SparseTensor.cpu = cpu +SparseTensor.cuda = cuda +SparseTensor.__getitem__ = __getitem__ +SparseTensor.__repr__ = __repr__ + +# Scipy Conversions ########################################################### + +ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix, + scipy.sparse.csc_matrix] + + +@torch.jit.ignore +def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor: + colptr = None + if isinstance(mat, scipy.sparse.csc_matrix): + colptr = torch.from_numpy(mat.indptr).to(torch.long) + + mat = mat.tocsr() + rowptr = torch.from_numpy(mat.indptr).to(torch.long) + mat = mat.tocoo() + row = torch.from_numpy(mat.row).to(torch.long) + col = torch.from_numpy(mat.col).to(torch.long) + value = None + if has_value: + value = torch.from_numpy(mat.data) + sparse_sizes = mat.shape[:2] + + storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value, + sparse_sizes=sparse_sizes, rowcount=None, + colptr=colptr, colcount=None, csr2csc=None, + csc2csr=None, is_sorted=True) + + return SparseTensor.from_storage(storage) + + +@torch.jit.ignore +def to_scipy(self: SparseTensor, layout: Optional[str] = None, + dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix: + assert self.dim() == 2 + layout = get_layout(layout) + + if not self.has_value(): + ones = torch.ones(self.nnz(), dtype=dtype).numpy() + + if layout == 'coo': + row, col, value = self.coo() + row = row.detach().cpu().numpy() + col = col.detach().cpu().numpy() + value = value.detach().cpu().numpy() if self.has_value() else ones + return scipy.sparse.coo_matrix((value, (row, col)), self.sizes()) + elif layout == 'csr': + rowptr, col, value = self.csr() + rowptr = rowptr.detach().cpu().numpy() + col = col.detach().cpu().numpy() + value = value.detach().cpu().numpy() 
if self.has_value() else ones + return scipy.sparse.csr_matrix((value, col, rowptr), self.sizes()) + elif layout == 'csc': + colptr, row, value = self.csc() + colptr = colptr.detach().cpu().numpy() + row = row.detach().cpu().numpy() + value = value.detach().cpu().numpy() if self.has_value() else ones + return scipy.sparse.csc_matrix((value, row, colptr), self.sizes()) + +SparseTensor.from_scipy = from_scipy +SparseTensor.to_scipy = to_scipy + + + +def softmax( + src: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + num_nodes: Optional[int] = None, + dim: int = 0, +) -> Tensor: + r"""Computes a sparsely evaluated softmax. + Given a value tensor :attr:`src`, this function first groups the values + along the first dimension based on the indices specified in :attr:`index`, + and then proceeds to compute the softmax individually for each group. + + Args: + src (Tensor): The source tensor. + index (LongTensor, optional): The indices of elements for applying the + softmax. (default: :obj:`None`) + ptr (LongTensor, optional): If given, computes the softmax based on + sorted inputs in CSR representation. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) + dim (int, optional): The dimension in which to normalize. + (default: :obj:`0`) + + :rtype: :class:`Tensor` + + Examples: + + >>> src = torch.tensor([1., 1., 1., 1.]) + >>> index = torch.tensor([0, 0, 1, 2]) + >>> ptr = torch.tensor([0, 2, 3, 4]) + >>> softmax(src, index) + tensor([0.5000, 0.5000, 1.0000, 1.0000]) + + >>> softmax(src, None, ptr) + tensor([0.5000, 0.5000, 1.0000, 1.0000]) + + >>> src = torch.randn(4, 4) + >>> ptr = torch.tensor([0, 4]) + >>> softmax(src, index, dim=-1) + tensor([[0.7404, 0.2596, 1.0000, 1.0000], + [0.1702, 0.8298, 1.0000, 1.0000], + [0.7607, 0.2393, 1.0000, 1.0000], + [0.8062, 0.1938, 1.0000, 1.0000]]) + """ + if ptr is not None: + dim = dim + src.dim() if dim < 0 else dim + size = ([1] * dim) + [-1] + count = ptr[1:] - ptr[:-1] + ptr = ptr.view(size) + src_max = segment_csr(src.detach(), ptr, reduce='max') + src_max = src_max.repeat_interleave(count, dim=dim) + out = (src - src_max).exp() + out_sum = segment_csr(out, ptr, reduce='sum') + 1e-16 + out_sum = out_sum.repeat_interleave(count, dim=dim) + elif index is not None: + N = maybe_num_nodes(index, num_nodes) + src_max = scatter(src.detach(), index, dim, dim_size=N, reduce='max') + out = src - src_max.index_select(dim, index) + out = out.exp() + out_sum = scatter(out, index, dim, dim_size=N, reduce='sum') + 1e-16 + out_sum = out_sum.index_select(dim, index) + else: + raise NotImplementedError + + return out / out_sum + + +@torch.jit._overload +def maybe_num_nodes(edge_index, num_nodes=None): + + pass + + +@torch.jit._overload +def maybe_num_nodes(edge_index, num_nodes=None): + + pass + + +def maybe_num_nodes(edge_index, num_nodes=None): + if num_nodes is not None: + return num_nodes + elif isinstance(edge_index, Tensor): + if is_torch_sparse_tensor(edge_index): + return max(edge_index.size(0), edge_index.size(1)) + return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0 + else: + return max(edge_index.size(0), edge_index.size(1)) + + + +class DisenLayer(nn.Module): + def __init__(self, edge_index, edge_type, in_channels, out_channels, num_rels, + act=lambda x: x, params=None, head_num=1): + #super(self.__class__, self).__init__(aggr='add', flow='target_to_source', node_dim=0) + ######################################## + 
        super(DisenLayer, self).__init__()
+        self.node_dim = 0
+        ###################################
+        self.edge_index = edge_index
+        self.edge_type = edge_type
+        self.p = params
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.act = act
+        self.device = None
+        self.head_num = head_num
+        self.num_rels = num_rels
+
+        # params for init
+        #######################
+        self.drop = torch.nn.Dropout(self.p.gcn_drop)
+        self.dropout = torch.nn.Dropout(0.3)
+        self.bn = torch.nn.BatchNorm1d(self.p.num_factors * out_channels)
+        if self.p.bias:
+            self.register_parameter('bias', Parameter(torch.zeros(out_channels)))
+
+        num_edges = self.edge_index.size(1) // 2
+        if self.device is None:
+            self.device = self.edge_index.device
+        self.in_index, self.out_index = self.edge_index[:, :num_edges], self.edge_index[:, num_edges:]
+        self.in_type, self.out_type = self.edge_type[:num_edges], self.edge_type[num_edges:]
+        self.loop_index = torch.stack([torch.arange(self.p.num_ent), torch.arange(self.p.num_ent)]).to(self.device)
+        self.loop_type = torch.full((self.p.num_ent,), 2 * self.num_rels, dtype=torch.long).to(self.device)
+        num_ent = self.p.num_ent
+
+        self.leakyrelu = nn.LeakyReLU(0.2)
+        if self.p.att_mode == 'cat_emb' or self.p.att_mode == 'cat_weight':
+            self.att_weight = get_param((1, self.p.num_factors, 2 * out_channels))
+        else:
+            self.att_weight = get_param((1, self.p.num_factors, out_channels))
+        self.rel_weight = get_param((2 * self.num_rels + 1, self.p.num_factors, out_channels))
+        self.loop_rel = get_param((1, out_channels))
+        self.w_rel = get_param((out_channels, out_channels))
+
+    def forward(self, x, rel_embed, mode):
+        # message, aggregate and update are invoked manually below
+        rel_embed = torch.cat([rel_embed, self.loop_rel], dim=0)
+        edge_index = torch.cat([self.edge_index, self.loop_index], dim=1)
+        edge_type = torch.cat([self.edge_type, self.loop_type])
+
+        # x.shape == [14541,3,200], edge_index.shape == [2,558771], rel_embed.shape == [475,200]
+        # rel_weight.shape == [475,3,200]
+
+        # original code: out actually has shape [14541,3,200]
+        # out = self.propagate(edge_index, size=None, x=x, edge_type=edge_type, rel_embed=rel_embed, rel_weight=self.rel_weight)
+
+        ############################# modified code ###########################################
+        # The flow is target-to-source: j denotes the source node, but edge_index[1]
+        # (the actual target nodes) is the index that gets used here.
+        edge_index_i = edge_index[0]
+        edge_index_j = edge_index[1]
+        x_i = torch.index_select(x, dim=0, index=edge_index_i)
+        x_j = torch.index_select(x, dim=0, index=edge_index_j)
+
+        message_res = self.message(edge_index_i=edge_index_i, edge_index_j=edge_index_j, x_i=x_i, x_j=x_j,
+                                   edge_type=edge_type, rel_embed=rel_embed, rel_weight=self.rel_weight)
+        # message_res.shape == [558771,3,200]
+        aggr_res = self.aggregate(input=message_res, edge_index_i=edge_index_i)
+        out = self.update(aggr_res)  # out.shape should be [14541,3,200]
+        #######################################################################
+        if self.p.bias:
+            out = out + self.bias
+        out = self.bn(out.view(-1, self.p.num_factors * self.p.gcn_dim)).view(-1, self.p.num_factors, self.p.gcn_dim)
+        # out.shape == [14541,3,200]
+        entity1 = out if self.p.no_act else self.act(out)
+        return entity1, torch.matmul(rel_embed, self.w_rel)[:-1]
+
+    def message(self, edge_index_i, edge_index_j, x_i, x_j, edge_type, rel_embed, rel_weight):
+        '''
+        edge_index_i : [E]
+        x_i: [E, F]
+        x_j: [E, F]
+        '''
+        rel_embed = torch.index_select(rel_embed, 0, edge_type)
+        rel_weight = torch.index_select(rel_weight, 0, edge_type)
+        xj_rel = self.rel_transform(x_j, rel_embed, rel_weight)
+        # start to compute the attention
+        alpha = self._get_attention(edge_index_i, edge_index_j, x_i, x_j, rel_embed, rel_weight, xj_rel)
+        alpha = self.drop(alpha)
+
+        # xj_rel == [558771,3,200], alpha == [558771,3,1]; their product has shape [558771,3,200]
+        return xj_rel * alpha  # attention-weighted source-node features for every edge
+
+    def aggregate(self, input, edge_index_i):  # input: per-edge source-node messages; edge_index_i: per-edge target-node ids
+        return scatter_sum(input, edge_index_i, dim=0)
+
+    def update(self, aggr_out):  # aggr_out == [14541,3,200]
+        return aggr_out
+
+    def _get_attention(self, edge_index_i, edge_index_j, x_i, x_j, rel_embed, rel_weight, mes_xj):
+        if self.p.att_mode == 'learn':
+            alpha = self.leakyrelu(torch.einsum('ekf, xkf->ek', [mes_xj, self.att_weight]))  # [E K]
+            alpha = softmax(alpha, edge_index_i, num_nodes=self.p.num_ent)
+
+        elif self.p.att_mode == 'dot_weight':
+            sub_rel_emb = x_i * rel_weight
+            obj_rel_emb = x_j * rel_weight
+
+            alpha = self.leakyrelu(torch.einsum('ekf,ekf->ek', [sub_rel_emb, obj_rel_emb]))
+            alpha = softmax(alpha, edge_index_i, num_nodes=self.p.num_ent)
+
+        elif self.p.att_mode == 'dot_emb':
+            sub_rel_emb = x_i * rel_embed.unsqueeze(1)
+            obj_rel_emb = x_j * rel_embed.unsqueeze(1)
+
+            alpha = self.leakyrelu(torch.einsum('ekf,ekf->ek', [sub_rel_emb, obj_rel_emb]))
+            alpha = softmax(alpha, edge_index_i, num_nodes=self.p.num_ent)
+
+        elif self.p.att_mode == 'cat_weight':
+            sub_rel_emb = x_i * rel_weight
+            obj_rel_emb = x_j * rel_weight
+
+            alpha = self.leakyrelu(torch.einsum('ekf,xkf->ek', torch.cat([sub_rel_emb, obj_rel_emb], dim=2), self.att_weight))
+            alpha = softmax(alpha, edge_index_i, num_nodes=self.p.num_ent)
+
+        elif self.p.att_mode == 'cat_emb':
+            sub_rel_emb = x_i * rel_embed.unsqueeze(1)
+            obj_rel_emb = x_j * rel_embed.unsqueeze(1)
+
+            alpha = self.leakyrelu(torch.einsum('ekf,xkf->ek', torch.cat([sub_rel_emb, obj_rel_emb], dim=2), self.att_weight))
+            alpha = softmax(alpha, edge_index_i, num_nodes=self.p.num_ent)
+        else:
+            raise NotImplementedError
+
+        return alpha.unsqueeze(2)
+
+    def rel_transform(self, ent_embed, rel_embed, rel_weight, opn=None):
+        if opn is None:
+            opn = self.p.opn
+        if opn == 'corr':
+            trans_embed = ccorr(ent_embed * rel_weight, rel_embed.unsqueeze(1))
+        elif opn == 'corr_ra':
+            trans_embed = ccorr(ent_embed * rel_weight, rel_embed)
+        elif opn == 'sub':
+            trans_embed = ent_embed * rel_weight - rel_embed.unsqueeze(1)
+        elif opn == 'es':
+            trans_embed = ent_embed
+        elif opn == 'sub_ra':
+            trans_embed = ent_embed * rel_weight - rel_embed.unsqueeze(1)
+        elif opn == 'mult':
+            trans_embed = (ent_embed * rel_embed.unsqueeze(1)) * rel_weight
+        elif opn == 'mult_ra':
+            trans_embed = (ent_embed * rel_embed) * rel_weight
+        elif opn == 'cross':
+            trans_embed = ent_embed * rel_embed.unsqueeze(1) * rel_weight + ent_embed * rel_weight
+        elif opn == 'cross_wo_rel':
+            trans_embed = ent_embed * rel_weight
+        elif opn == 'cross_simplfy':
+            trans_embed = ent_embed * rel_embed + ent_embed
+        elif opn == 'concat':
+            trans_embed = torch.cat([ent_embed, rel_embed], dim=1)
+        elif opn == 'concat_ra':
+            trans_embed = torch.cat([ent_embed, rel_embed], dim=1) * rel_weight
+        elif opn == 'ent_ra':
+            trans_embed = ent_embed * rel_weight + rel_embed
+        else:
+            raise NotImplementedError
+
+        return trans_embed
+
+    def __repr__(self):
+        return '{}({}, {}, num_rels={})'.format(
+            self.__class__.__name__, self.in_channels, self.out_channels, self.num_rels)
+
+class CLUBSample(nn.Module):  # Sampled version of the CLUB estimator
+    def __init__(self, x_dim, y_dim, hidden_size):
+        super(CLUBSample, self).__init__()
+        self.p_mu = nn.Sequential(nn.Linear(x_dim, hidden_size//2),
+ nn.ReLU(), + nn.Linear(hidden_size//2, y_dim)) + + self.p_logvar = nn.Sequential(nn.Linear(x_dim, hidden_size//2), + nn.ReLU(), + nn.Linear(hidden_size//2, y_dim), + nn.Tanh()) + + def get_mu_logvar(self, x_samples): + mu = self.p_mu(x_samples) + logvar = self.p_logvar(x_samples) + return mu, logvar + + + def loglikeli(self, x_samples, y_samples): +# print(x_samples.size()) +# print(y_samples.size()) + mu, logvar = self.get_mu_logvar(x_samples) + + return (-(mu - y_samples)**2 /2./logvar.exp()).sum(dim=1).mean(dim=0) + + + def forward(self, x_samples, y_samples): + mu, logvar = self.get_mu_logvar(x_samples) + + sample_size = x_samples.shape[0] + #random_index = torch.randint(sample_size, (sample_size,)).long() + random_index = torch.randperm(sample_size).long() + + positive = - (mu - y_samples)**2 / logvar.exp() + negative = - (mu - y_samples[random_index])**2 / logvar.exp() + upper_bound = (positive.sum(dim = -1) - negative.sum(dim = -1)).mean() + return upper_bound/2. + + def learning_loss(self, x_samples, y_samples): + return - self.loglikeli(x_samples, y_samples) + +class BaseModel(torch.nn.Module): + def __init__(self, params): + super(BaseModel, self).__init__() + + self.p = params + self.act = torch.tanh + self.bceloss = torch.nn.BCELoss() + + def loss(self, pred, true_label): + return self.bceloss(pred, true_label) + +class SparseInputLinear(nn.Module): + def __init__(self, inp_dim, out_dim): + super(SparseInputLinear, self).__init__() + weight = np.zeros((inp_dim, out_dim), dtype=np.float32) + weight = nn.Parameter(torch.from_numpy(weight)) + bias = np.zeros(out_dim, dtype=np.float32) + bias = nn.Parameter(torch.from_numpy(bias)) + self.inp_dim, self.out_dim = inp_dim, out_dim + self.weight, self.bias = weight, bias + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. / np.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + self.bias.data.uniform_(-stdv, stdv) + + def forward(self, x): # *nn.Linear* does not accept sparse *x*. 
+ return torch.mm(x, self.weight) + self.bias + +class CapsuleBase(BaseModel): + def __init__(self, edge_index, edge_type, num_rel, params=None): + super(CapsuleBase, self).__init__(params) + self.edge_index = edge_index + self.edge_type = edge_type + self.device = self.edge_index.device + self.init_embed = get_param((self.p.num_ent, self.p.init_dim)) + self.init_rel = get_param((num_rel * 2, self.p.gcn_dim)) + self.pca = SparseInputLinear(self.p.init_dim, self.p.num_factors * self.p.gcn_dim) + conv_ls = [] + for i in range(self.p.gcn_layer): + conv = DisenLayer(self.edge_index, self.edge_type, self.p.init_dim, self.p.gcn_dim, num_rel, + act=self.act, params=self.p, head_num=self.p.head_num) + self.add_module('conv_{}'.format(i), conv) + conv_ls.append(conv) + self.conv_ls = conv_ls + if self.p.mi_train: + if self.p.mi_method == 'club_b': + num_dis = int((self.p.num_factors) * (self.p.num_factors - 1) / 2) + # print(num_dis) + self.mi_Discs = nn.ModuleList([CLUBSample(self.p.gcn_dim, self.p.gcn_dim, self.p.gcn_dim) for fac in range(num_dis)]) + elif self.p.mi_method == 'club_s': + self.mi_Discs = nn.ModuleList([CLUBSample((fac + 1 ) * self.p.gcn_dim, self.p.gcn_dim, (fac + 1 ) * self.p.gcn_dim) for fac in range(self.p.num_factors - 1)]) + + self.register_parameter('bias', Parameter(torch.zeros(self.p.num_ent))) + self.rel_drop = nn.Dropout(0.1) + self.leakyrelu = nn.LeakyReLU(0.2) + + def lld_bst(self, sub, rel, drop1, mode='train'): + x = self.act(self.pca(self.init_embed)).view(-1, self.p.num_factors, self.p.gcn_dim) # [N K F] + r = self.init_rel + for conv in self.conv_ls: + x, r = conv(x, r, mode) # N K F + if self.p.mi_drop: + x = drop1(x) + else: + continue + + sub_emb = torch.index_select(x, 0, sub) + lld_loss = 0. + sub_emb = sub_emb.view(-1, self.p.gcn_dim * self.p.num_factors) + if self.p.mi_method == 'club_s': + for i in range(self.p.num_factors - 1): + bnd = i + 1 + lld_loss += self.mi_Discs[i].learning_loss(sub_emb[:, :bnd * self.p.gcn_dim], sub_emb[:, bnd * self.p.gcn_dim : (bnd + 1) * self.p.gcn_dim]) + + elif self.p.mi_method == 'club_b': + cnt = 0 + for i in range(self.p.num_factors): + for j in range( i + 1, self.p.num_factors): + lld_loss += self.mi_Discs[cnt].learning_loss(sub_emb[:, i * self.p.gcn_dim: (i + 1) * self.p.gcn_dim], sub_emb[:, j * self.p.gcn_dim: (j + 1) * self.p.gcn_dim]) + cnt += 1 + return lld_loss + + def mi_cal(self, sub_emb): + def loss_dependence_hisc(zdata_trn, ncaps, nhidden): + loss_dep = torch.zeros(1).cuda() + hH = (-1/nhidden)*torch.ones(nhidden, nhidden).cuda() + torch.eye(nhidden).cuda() + kfactor = torch.zeros(ncaps, nhidden, nhidden).cuda() + + for mm in range(ncaps): + data_temp = zdata_trn[:, mm * nhidden:(mm + 1) * nhidden] + kfactor[mm, :, :] = torch.mm(data_temp.t(), data_temp) + + for mm in range(ncaps): + for mn in range(mm + 1, ncaps): + mat1 = torch.mm(hH, kfactor[mm, :, :]) + mat2 = torch.mm(hH, kfactor[mn, :, :]) + mat3 = torch.mm(mat1, mat2) + teststat = torch.trace(mat3) + + loss_dep = loss_dep + teststat + return loss_dep + + def loss_dependence_club_s(sub_emb): + mi_loss = 0. + for i in range(self.p.num_factors - 1): + bnd = i + 1 + mi_loss += self.mi_Discs[i](sub_emb[:, :bnd * self.p.gcn_dim], sub_emb[:, bnd * self.p.gcn_dim : (bnd + 1) * self.p.gcn_dim]) + return mi_loss + + def loss_dependence_club_b(sub_emb): + mi_loss = 0. 
+ cnt = 0 + for i in range(self.p.num_factors): + for j in range( i + 1, self.p.num_factors): + mi_loss += self.mi_Discs[cnt](sub_emb[:, i * self.p.gcn_dim: (i + 1) * self.p.gcn_dim], sub_emb[:, j * self.p.gcn_dim: (j + 1) * self.p.gcn_dim]) + cnt += 1 + return mi_loss + def DistanceCorrelation(tensor_1, tensor_2): + # tensor_1, tensor_2: [channel] + # ref: https://en.wikipedia.org/wiki/Distance_correlation + channel = tensor_1.shape[0] + zeros = torch.zeros(channel, channel).to(tensor_1.device) + zero = torch.zeros(1).to(tensor_1.device) + tensor_1, tensor_2 = tensor_1.unsqueeze(-1), tensor_2.unsqueeze(-1) + """cul distance matrix""" + a_, b_ = torch.matmul(tensor_1, tensor_1.t()) * 2, \ + torch.matmul(tensor_2, tensor_2.t()) * 2 # [channel, channel] + tensor_1_square, tensor_2_square = tensor_1 ** 2, tensor_2 ** 2 + a, b = torch.sqrt(torch.max(tensor_1_square - a_ + tensor_1_square.t(), zeros) + 1e-8), \ + torch.sqrt(torch.max(tensor_2_square - b_ + tensor_2_square.t(), zeros) + 1e-8) # [channel, channel] + """cul distance correlation""" + A = a - a.mean(dim=0, keepdim=True) - a.mean(dim=1, keepdim=True) + a.mean() + B = b - b.mean(dim=0, keepdim=True) - b.mean(dim=1, keepdim=True) + b.mean() + dcov_AB = torch.sqrt(torch.max((A * B).sum() / channel ** 2, zero) + 1e-8) + dcov_AA = torch.sqrt(torch.max((A * A).sum() / channel ** 2, zero) + 1e-8) + dcov_BB = torch.sqrt(torch.max((B * B).sum() / channel ** 2, zero) + 1e-8) + return dcov_AB / torch.sqrt(dcov_AA * dcov_BB + 1e-8) + + if self.p.mi_method == 'club_s': + mi_loss = loss_dependence_club_s(sub_emb) + elif self.p.mi_method == 'club_b': + mi_loss = loss_dependence_club_b(sub_emb) + elif self.p.mi_method == 'hisc': + mi_loss = loss_dependence_hisc(sub_emb, self.p.num_factors, self.p.gcn_dim) + elif self.p.mi_method == "dist": + cor = 0. + for i in range(self.p.num_factors): + for j in range(i + 1, self.p.num_factors): + cor += DistanceCorrelation(sub_emb[:, i * self.p.gcn_dim: (i + 1) * self.p.gcn_dim], sub_emb[:, j * self.p.gcn_dim: (j + 1) * self.p.gcn_dim]) + return cor + else: + raise NotImplementedError + + return mi_loss + + def forward_base(self, sub, rel, drop1, drop2, mode): + if not self.p.no_enc: + x = self.act(self.pca(self.init_embed)).view(-1, self.p.num_factors, self.p.gcn_dim) # [N K F] + r = self.init_rel + for conv in self.conv_ls: + x, r = conv(x, r, mode) # N K F + x = drop1(x) + else: + x = self.init_embed + r = self.init_rel + x = drop1(x) + sub_emb = torch.index_select(x, 0, sub) + rel_emb = torch.index_select(self.init_rel, 0, rel).repeat(1, self.p.num_factors) + mi_loss = 0. + sub_emb = sub_emb.view(-1, self.p.gcn_dim * self.p.num_factors) + mi_loss = self.mi_cal(sub_emb) + + return sub_emb, rel_emb, x, mi_loss + + def test_base(self, sub, rel, drop1, drop2, mode): + if not self.p.no_enc: + x = self.act(self.pca(self.init_embed)).view(-1, self.p.num_factors, self.p.gcn_dim) # [N K F] + r = self.init_rel + for conv in self.conv_ls: + x, r = conv(x, r, mode) # N K F + x = drop1(x) + else: + x = self.init_embed.view(-1, self.p.num_factors, self.p.gcn_dim) + r = self.init_rel + x = drop1(x) + sub_emb = torch.index_select(x, 0, sub) + rel_emb = torch.index_select(self.init_rel, 0, rel).repeat(1, self.p.num_factors) + + return sub_emb, rel_emb, x, 0. 
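+
+# The decoders below (DisenKGAT_TransE / DistMult / ConvE / InteractE) follow roughly the
+# same flow: (1) CapsuleBase.forward_base / test_base runs the DisenLayer stack and yields
+# factor-wise subject embeddings, relation embeddings, all entity embeddings [N, K, F] and
+# a mutual-information penalty; (2) a per-factor attention over the K components is built
+# from p.score_method; (3) each factor is scored against every entity with the
+# decoder-specific score function; (4) the factor scores are mixed with the attention
+# either before or after the sigmoid, depending on p.score_order.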
+
+
+
+
+class DisenKGAT_TransE(CapsuleBase):
+
+    def __init__(self, edge_index, edge_type, params=None):
+        super(self.__class__, self).__init__(edge_index, edge_type, params.num_rel, params)
+        self.drop = torch.nn.Dropout(self.p.hid_drop)
+        self.rel_weight = self.conv_ls[-1].rel_weight  # from the last DisenLayer
+        gamma_init = torch.FloatTensor([self.p.init_gamma])
+        if not self.p.fix_gamma:
+            self.register_parameter('gamma', Parameter(gamma_init))
+
+    def lld_best(self, sub, rel):
+        return self.lld_bst(sub, rel, self.drop)
+
+    def forward(self, sub, rel, neg_ents=None, mode='train'):
+        if mode == 'train' and self.p.mi_train:
+            sub_emb, rel_emb, all_ent, corr = self.forward_base(sub, rel, self.drop, self.drop, mode)  # all_ent is about memory
+            sub_emb = sub_emb.view(-1, self.p.num_factors, self.p.gcn_dim)
+        else:
+            sub_emb, rel_emb, all_ent, corr = self.test_base(sub, rel, self.drop, self.drop, mode)
+
+        rel_weight = torch.index_select(self.rel_weight, 0, rel)
+        rel_emb = rel_emb.view(-1, self.p.num_factors, self.p.gcn_dim)
+        sub_emb = sub_emb.view(-1, self.p.num_factors, self.p.gcn_dim)
+        if self.p.score_method == 'dot_rel':
+            sub_rel_emb = sub_emb * rel_weight
+            rel_emb = rel_emb
+            attention = self.leakyrelu(torch.einsum('bkf,bkf->bk', [sub_rel_emb, rel_emb]))  # B K
+        elif self.p.score_method == 'dot_sub':
+            sub_rel_emb = sub_emb
+            attention = self.leakyrelu(torch.einsum('bkf,bkf->bk', [sub_rel_emb, rel_emb]))  # B K
+        elif self.p.score_method == 'cat_rel':
+            sub_rel_emb = sub_emb * rel_weight
+            attention = self.leakyrelu(self.fc_a(torch.cat([sub_rel_emb, rel_emb], dim=2)).squeeze())  # B K
+        elif self.p.score_method == 'cat_sub':
+            sub_rel_emb = sub_emb
+            attention = self.leakyrelu(self.fc_a(torch.cat([sub_rel_emb, rel_emb], dim=2)).squeeze())  # B K
+        elif self.p.score_method == 'learn':
+            att_rel = torch.index_select(self.fc_att, 0, rel)
+            attention = self.leakyrelu(att_rel)  # [B K]
+        attention = nn.Softmax(dim=-1)(attention)
+        # calculate the score
+        obj_emb = sub_emb + rel_emb
+        if self.p.gamma_method == 'ada':
+            x = self.gamma - torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=3).transpose(1, 2)
+        elif self.p.gamma_method == 'norm':
+            x2 = torch.sum(obj_emb * obj_emb, dim=-1)  # x2.shape == [2048,3]
+            y2 = torch.sum(all_ent * all_ent, dim=-1)  # y2 should be [14541,3]
+            xy = torch.einsum('bkf,nkf->bkn', [obj_emb, all_ent])
+            x = self.gamma - (x2.unsqueeze(2) + y2.t() - 2 * xy)
+
+        elif self.p.gamma_method == 'fix':
+            x = self.p.gamma - torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=3).transpose(1, 2)
+        # start to attention on prediction
+        if self.p.score_order == 'before':
+            x = torch.einsum('bk,bkn->bn', [attention, x])
+            pred = torch.sigmoid(x)
+        elif self.p.score_order == 'after':
+            x = torch.sigmoid(x)
+            pred = torch.einsum('bk,bkn->bn', [attention, x])
+            pred = torch.clamp(pred, min=0., max=1.0)
+        return pred, corr
+
+
+class DisenKGAT_DistMult(CapsuleBase):
+    def __init__(self, edge_index, edge_type, params=None):
+        super(self.__class__, self).__init__(edge_index, edge_type, params.num_rel, params)
+        self.drop = torch.nn.Dropout(self.p.hid_drop)
+        self.rel_weight = self.conv_ls[-1].rel_weight
+
+    def lld_best(self, sub, rel):
+        return self.lld_bst(sub, rel, self.drop)
+
+    def forward(self, sub, rel, neg_ents=None, mode='train'):
+        if mode == 'train' and self.p.mi_train:
+            sub_emb, rel_emb, all_ent, corr = self.forward_base(sub, rel, self.drop, self.drop, mode)
+            sub_emb = sub_emb.view(-1, self.p.num_factors, self.p.gcn_dim)
+        else:
+            sub_emb, rel_emb, all_ent, corr = self.test_base(sub,
rel, self.drop, self.drop, mode) + rel_weight = torch.index_select(self.rel_weight, 0, rel) + rel_emb = rel_emb.view(-1, self.p.num_factors, self.p.gcn_dim) + sub_emb = sub_emb.view(-1, self.p.num_factors, self.p.gcn_dim) + if self.p.score_method == 'dot_rel': + sub_rel_emb = sub_emb * rel_weight + rel_emb = rel_emb + attention = self.leakyrelu(torch.einsum('bkf,bkf->bk', [sub_rel_emb, rel_emb])) # B K + elif self.p.score_method == 'dot_sub': + sub_rel_emb = sub_emb + attention = self.leakyrelu(torch.einsum('bkf,bkf->bk', [sub_rel_emb, rel_emb])) # B K + elif self.p.score_method == 'cat_rel': + sub_rel_emb = sub_emb * rel_weight + attention = self.leakyrelu(self.fc_a(torch.cat([sub_rel_emb, rel_emb], dim=2)).squeeze()) # B K + elif self.p.score_method == 'cat_sub': + sub_rel_emb = sub_emb + attention = self.leakyrelu(self.fc_a(torch.cat([sub_rel_emb, rel_emb], dim=2)).squeeze()) # B K + elif self.p.score_method == 'learn': + att_rel = torch.index_select(self.fc_att, 0, rel) + attention = self.leakyrelu(att_rel) # [B K] + attention = nn.Softmax(dim=-1)(attention) + # calculate the score + obj_emb = sub_emb * rel_emb + x = torch.einsum('bkf,nkf->bkn', [obj_emb, all_ent]) + x += self.bias.expand_as(x) + # start to attention on prediction + if self.p.score_order == 'before': + x = torch.einsum('bk,bkn->bn', [attention, x]) + pred = torch.sigmoid(x) + elif self.p.score_order == 'after': + x = torch.sigmoid(x) + pred = torch.einsum('bk,bkn->bn', [attention, x]) + pred = torch.clamp(pred, min=0., max=1.0) + + return pred, corr + + +class DisenKGAT_ConvE(CapsuleBase): + def __init__(self, edge_index, edge_type, params=None): + super(self.__class__, self).__init__(edge_index, edge_type, params.num_rel, params) + self.embed_dim = self.p.embed_dim + + self.bn0 = torch.nn.BatchNorm2d(1) + self.bn1 = torch.nn.BatchNorm2d(self.p.num_filt) + self.bn2 = torch.nn.BatchNorm1d(self.embed_dim) + + self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) + self.hidden_drop2 = torch.nn.Dropout(self.p.hid_drop2) + self.feature_drop = torch.nn.Dropout(self.p.feat_drop) + self.m_conv1 = torch.nn.Conv2d(1, out_channels=self.p.num_filt, kernel_size=(self.p.ker_sz, self.p.ker_sz), + stride=1, padding=0, bias=self.p.bias) + + flat_sz_h = int(2 * self.p.k_w) - self.p.ker_sz + 1 + flat_sz_w = self.p.k_h - self.p.ker_sz + 1 + self.flat_sz = flat_sz_h * flat_sz_w * self.p.num_filt + self.fc = torch.nn.Linear(self.flat_sz, self.embed_dim) + if self.p.score_method.startswith('cat'): + self.fc_a = nn.Linear(2 * self.p.gcn_dim, 1) + elif self.p.score_method == 'learn': + self.fc_att = get_param((2 * self.p.num_rel, self.p.num_factors)) + self.rel_weight = self.conv_ls[-1].rel_weight + + def concat(self, e1_embed, rel_embed): + e1_embed = e1_embed.view(-1, 1, self.embed_dim) + rel_embed = rel_embed.view(-1, 1, self.embed_dim) + stack_inp = torch.cat([e1_embed, rel_embed], 1) + stack_inp = torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2 * self.p.k_w, self.p.k_h)) + return stack_inp + + def lld_best(self, sub, rel): + return self.lld_bst(sub, rel, self.hidden_drop) + + def forward(self, sub, rel, neg_ents=None, mode='train'): + if mode == 'train' and self.p.mi_train: + sub_emb, rel_emb, all_ent, corr = self.forward_base(sub, rel, self.hidden_drop, self.feature_drop, mode) + # sub_emb = sub_emb.view(-1, self.p.num_factors, self.p.gcn_dim) + else: + sub_emb, rel_emb, all_ent, corr = self.test_base(sub, rel, self.hidden_drop, self.feature_drop, mode) + sub_emb = sub_emb.view(-1, self.p.gcn_dim) + rel_emb = rel_emb.view(-1, 
self.p.gcn_dim) + all_ent = all_ent.view(-1, self.p.num_factors, self.p.gcn_dim) + + stk_inp = self.concat(sub_emb, rel_emb) + x = self.bn0(stk_inp) + x = self.m_conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = self.feature_drop(x) + x = x.view(-1, self.flat_sz) + x = self.fc(x) + x = self.hidden_drop2(x) + x = self.bn2(x) + x = F.relu(x) + x = x.view(-1, self.p.num_factors, self.p.gcn_dim) + # start to calculate the attention + rel_weight = torch.index_select(self.rel_weight, 0, rel) # B K F + rel_emb = rel_emb.view(-1, self.p.num_factors, self.p.gcn_dim) + sub_emb = sub_emb.view(-1, self.p.num_factors, self.p.gcn_dim) # B K F + if self.p.score_method == 'dot_rel': + sub_rel_emb = sub_emb * rel_weight + attention = self.leakyrelu(torch.einsum('bkf,bkf->bk', [sub_rel_emb, rel_emb])) # B K + elif self.p.score_method == 'dot_sub': + sub_rel_emb = sub_emb + attention = self.leakyrelu(torch.einsum('bkf,bkf->bk', [sub_rel_emb, rel_emb])) # B K + elif self.p.score_method == 'cat_rel': + sub_rel_emb = sub_emb * rel_weight + attention = self.leakyrelu(self.fc_a(torch.cat([sub_rel_emb, rel_emb], dim=2)).squeeze()) # B K + elif self.p.score_method == 'cat_sub': + sub_rel_emb = sub_emb + attention = self.leakyrelu(self.fc_a(torch.cat([sub_rel_emb, rel_emb], dim=2)).squeeze()) # B K + elif self.p.score_method == 'learn': + att_rel = torch.index_select(self.fc_att, 0, rel) + attention = self.leakyrelu(att_rel) # [B K] + attention = nn.Softmax(dim=-1)(attention) + x = torch.einsum('bkf,nkf->bkn', [x, all_ent]) + x += self.bias.expand_as(x) + + if self.p.score_order == 'before': + x = torch.einsum('bk,bkn->bn', [attention, x]) + pred = torch.sigmoid(x) + elif self.p.score_order == 'after': + x = torch.sigmoid(x) + pred = torch.einsum('bk,bkn->bn', [attention, x]) + pred = torch.clamp(pred, min=0., max=1.0) + return pred, corr + + +class DisenKGAT_InteractE(CapsuleBase): + def __init__(self, edge_index, edge_type, params=None): + super(self.__class__, self).__init__(edge_index, edge_type, params.num_rel, params) + self.inp_drop = torch.nn.Dropout(self.p.iinp_drop) + self.feature_map_drop = torch.nn.Dropout2d(self.p.ifeat_drop) + self.hidden_drop = torch.nn.Dropout(self.p.ihid_drop) + + self.hidden_drop_gcn = torch.nn.Dropout(0) + + self.bn0 = torch.nn.BatchNorm2d(self.p.iperm) + + flat_sz_h = self.p.ik_h + flat_sz_w = 2 * self.p.ik_w + self.padding = 0 + + self.bn1 = torch.nn.BatchNorm2d(self.p.inum_filt * self.p.iperm) + self.flat_sz = flat_sz_h * flat_sz_w * self.p.inum_filt * self.p.iperm + + self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim) + self.fc = torch.nn.Linear(self.flat_sz, self.p.embed_dim) + self.chequer_perm = self.get_chequer_perm() + if self.p.score_method.startswith('cat'): + self.fc_a = nn.Linear(2 * self.p.gcn_dim, 1) + elif self.p.score_method == 'learn': + self.fc_att = get_param((2 * self.p.num_rel, self.p.num_factors)) + self.rel_weight = self.conv_ls[-1].rel_weight + self.register_parameter('bias', Parameter(torch.zeros(self.p.num_ent))) + self.register_parameter('conv_filt', + Parameter(torch.zeros(self.p.inum_filt, 1, self.p.iker_sz, self.p.iker_sz))) + xavier_normal_(self.conv_filt) + + def circular_padding_chw(self, batch, padding): + upper_pad = batch[..., -padding:, :] + lower_pad = batch[..., :padding, :] + temp = torch.cat([upper_pad, batch, lower_pad], dim=2) + + left_pad = temp[..., -padding:] + right_pad = temp[..., :padding] + padded = torch.cat([left_pad, temp, right_pad], dim=3) + return padded + + def lld_best(self, sub, rel): + return self.lld_bst(sub, rel, 
self.inp_drop) + + def forward(self, sub, rel, neg_ents=None, mode='train'): + if mode == 'train' and self.p.mi_train: + sub_emb, rel_emb, all_ent, corr = self.forward_base(sub, rel, self.inp_drop, self.hidden_drop_gcn, mode) + # sub_emb = sub_emb.view(-1, self.p.num_factors, self.p.gcn_dim) + else: + sub_emb, rel_emb, all_ent, corr = self.test_base(sub, rel, self.inp_drop, self.hidden_drop_gcn, mode) + sub_emb = sub_emb.view(-1, self.p.gcn_dim) + rel_emb = rel_emb.view(-1, self.p.gcn_dim) + all_ent = all_ent.view(-1, self.p.num_factors, self.p.gcn_dim) + # sub: [B K F] + # rel: [B K F] + # all_ent: [N K F] + comb_emb = torch.cat([sub_emb, rel_emb], dim=1) + chequer_perm = comb_emb[:, self.chequer_perm] + stack_inp = chequer_perm.reshape((-1, self.p.iperm, 2 * self.p.ik_w, self.p.ik_h)) + stack_inp = self.bn0(stack_inp) + x = stack_inp + x = self.circular_padding_chw(x, self.p.iker_sz // 2) + x = F.conv2d(x, self.conv_filt.repeat(self.p.iperm, 1, 1, 1), padding=self.padding, groups=self.p.iperm) + x = self.bn1(x) + x = F.relu(x) + x = self.feature_map_drop(x) + x = x.view(-1, self.flat_sz) + x = self.fc(x) + x = self.hidden_drop(x) + x = self.bn2(x) + x = F.relu(x) # [B*K F] + x = x.view(-1, self.p.num_factors, self.p.gcn_dim) + # start to calculate the attention + rel_weight = torch.index_select(self.rel_weight, 0, rel) # B K F + rel_emb = rel_emb.view(-1, self.p.num_factors, self.p.gcn_dim) + sub_emb = sub_emb.view(-1, self.p.num_factors, self.p.gcn_dim) # B K F + if self.p.score_method == 'dot_rel': + sub_rel_emb = sub_emb * rel_weight + attention = self.leakyrelu(torch.einsum('bkf,bkf->bk', [sub_rel_emb, rel_emb])) # B K + elif self.p.score_method == 'dot_sub': + sub_rel_emb = sub_emb + attention = self.leakyrelu(torch.einsum('bkf,bkf->bk', [sub_rel_emb, rel_emb])) # B K + elif self.p.score_method == 'cat_rel': + sub_rel_emb = sub_emb * rel_weight + attention = self.leakyrelu(self.fc_a(torch.cat([sub_rel_emb, rel_emb], dim=2)).squeeze()) # B K + elif self.p.score_method == 'cat_sub': + sub_rel_emb = sub_emb + attention = self.leakyrelu(self.fc_a(torch.cat([sub_rel_emb, rel_emb], dim=2)).squeeze()) # B K + elif self.p.score_method == 'learn': + att_rel = torch.index_select(self.fc_att, 0, rel) + attention = self.leakyrelu(att_rel) # [B K] + attention = nn.Softmax(dim=-1)(attention) + if self.p.strategy == 'one_to_n' or neg_ents is None: + x = torch.einsum('bkf,nkf->bkn', [x, all_ent]) + x += self.bias.expand_as(x) + else: + x = torch.mul(x.unsqueeze(1), all_ent[neg_ents]).sum(dim=-1) + x += self.bias[neg_ents] + if self.p.score_order == 'before': + x = torch.einsum('bk,bkn->bn', [attention, x]) + pred = torch.sigmoid(x) + elif self.p.score_order == 'after': + x = torch.sigmoid(x) + pred = torch.einsum('bk,bkn->bn', [attention, x]) + pred = torch.clamp(pred, min=0., max=1.0) + return pred, corr + + def get_chequer_perm(self): + """ + Function to generate the chequer permutation required for InteractE model + + Parameters + ---------- + + Returns + ------- + + """ + ent_perm = np.int32([np.random.permutation(self.p.embed_dim) for _ in range(self.p.iperm)]) + rel_perm = np.int32([np.random.permutation(self.p.embed_dim) for _ in range(self.p.iperm)]) + + comb_idx = [] + for k in range(self.p.iperm): + temp = [] + ent_idx, rel_idx = 0, 0 + + for i in range(self.p.ik_h): + for j in range(self.p.ik_w): + if k % 2 == 0: + if i % 2 == 0: + temp.append(ent_perm[k, ent_idx]) + ent_idx += 1 + temp.append(rel_perm[k, rel_idx] + self.p.embed_dim) + rel_idx += 1 + else: + temp.append(rel_perm[k, 
rel_idx] + self.p.embed_dim) + rel_idx += 1 + temp.append(ent_perm[k, ent_idx]) + ent_idx += 1 + else: + if i % 2 == 0: + temp.append(rel_perm[k, rel_idx] + self.p.embed_dim) + rel_idx += 1 + temp.append(ent_perm[k, ent_idx]) + ent_idx += 1 + else: + temp.append(ent_perm[k, ent_idx]) + ent_idx += 1 + temp.append(rel_perm[k, rel_idx] + self.p.embed_dim) + rel_idx += 1 + + comb_idx.append(temp) + + chequer_perm = torch.LongTensor(np.int32(comb_idx)).to(self.device) + return chequer_perm + + + + + +def get_combined_results(left_results, right_results): + results = {} + count = float(left_results['count']) + + results['left_mr'] = round(left_results['mr'] / count, 5) + results['left_mrr'] = round(left_results['mrr'] / count, 5) + results['right_mr'] = round(right_results['mr'] / count, 5) + results['right_mrr'] = round(right_results['mrr'] / count, 5) + results['mr'] = round((left_results['mr'] + right_results['mr']) / (2 * count), 5) + results['mrr'] = round((left_results['mrr'] + right_results['mrr']) / (2 * count), 5) + + for k in range(10): + results['left_hits@{}'.format(k + 1)] = round(left_results['hits@{}'.format(k + 1)] / count, 5) + results['right_hits@{}'.format(k + 1)] = round(right_results['hits@{}'.format(k + 1)] / count, 5) + results['hits@{}'.format(k + 1)] = round( + (left_results['hits@{}'.format(k + 1)] + right_results['hits@{}'.format(k + 1)]) / (2 * count), 5) + return results + +def get_param(shape): + param = Parameter(torch.Tensor(*shape)); + xavier_normal_(param.data) + return param + +def com_mult(a, b): + r1, i1 = a[..., 0], a[..., 1] + r2, i2 = b[..., 0], b[..., 1] + return torch.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim=-1) + +def conj(a): + a[..., 1] = -a[..., 1] + return a + +def cconv(a, b): + return torch.irfft(com_mult(torch.rfft(a, 1), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],)) + +def ccorr(a, b): + return torch.irfft(com_mult(conj(torch.rfft(a, 1)), torch.rfft(b, 1)), 1, signal_sizes=(a.shape[-1],)) + +def frozen_params(module: nn.Module): + for p in module.parameters(): + p.requires_grad = False + +def free_params(module: nn.Module): + for p in module.parameters(): + p.requires_grad = True +# sys.path.append('./') + + diff --git a/openhgnn/models/ExpressGNN.py b/openhgnn/models/ExpressGNN.py new file mode 100644 index 00000000..cce882af --- /dev/null +++ b/openhgnn/models/ExpressGNN.py @@ -0,0 +1,680 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import networkx as nx +from itertools import product +from . 
import BaseModel, register_model + + +@register_model('ExpressGNN') +class ExpressGNN(BaseModel): + @classmethod + def build_model_from_args(cls, args, hg): + return cls(args=args, + latent_dim=args.embedding_size - args.gcn_free_size, + free_dim=args.gcn_free_size, + device=args.device, + load_method=args.load_method, + rule_list=args.rule_list, + rule_weights_learning=args.rule_weights_learning, + graph=hg, + PRED_DICT=args.PRED_DICT, + slice_dim=args.slice_dim, + transductive=(args.trans == 1)) + + def __init__(self, args, graph, latent_dim, free_dim, device, load_method, rule_list, rule_weights_learning, + PRED_DICT, + num_hops=5, num_layers=2, slice_dim=5, transductive=True): + """ + + Parameters + ---------- + graph: knowledge graph + latent_dim: embedding_size - gcn_free_size + free_dim: gcn_free_size + device: device + load_method: Factorized Posterior's load method, use args to get + rule_list: MLN's rules, should come from dataset + rule_weights_learning: MLN's args, should come from args + num_hops: number of hops of GCN + num_layers: number of layers of GCN + slice_dim: Used by Factorized Posterior + transductive: Used by GCN + """ + # GCN's setting + super(ExpressGNN, self).__init__() + self.graph = graph + self.latent_dim = latent_dim + self.free_dim = free_dim + self.num_hops = num_hops + self.num_layers = num_layers + self.PRED_DICT = PRED_DICT + self.args = args + self.num_ents = graph.num_ents + self.num_rels = graph.num_rels + self.num_nodes = graph.num_nodes + self.num_edges = graph.num_edges + self.num_edge_types = len(graph.edge_type2idx) + + # Factorized Posterior's loss function + self.xent_loss = F.binary_cross_entropy_with_logits + + # Factorized Posterior's + self.load_method = load_method + self.num_rels = graph.num_rels + self.ent2idx = graph.ent2idx + self.rel2idx = graph.rel2idx + self.idx2rel = graph.idx2rel + + # Trainable Embedding + self.num_ents = self.graph.num_ents + self.ent_embeds = nn.Embedding(self.num_ents, self.args.embedding_size) + self.ents = torch.arange(self.num_ents).to(args.device) + + self.edge2node_in, self.edge2node_out, self.node_degree, \ + self.edge_type_masks, self.edge_direction_masks = self.gen_edge2node_mapping() + + self.node_feat, self.const_nodes = self.prepare_node_feature(graph, transductive=transductive) + + if not transductive: + self.node_feat_dim = 1 + self.num_rels + else: + self.node_feat_dim = self.num_ents + self.num_rels + + self.init_node_linear = nn.Linear(self.node_feat_dim, latent_dim, bias=False) + + for param in self.init_node_linear.parameters(): + param.requires_grad = False + + self.node_feat = self.node_feat.to(device) + self.const_nodes = self.const_nodes.to(device) + self.edge2node_in = self.edge2node_in.to(device) + self.edge2node_out = self.edge2node_out.to(device) + self.edge_type_masks = [mask.to(device) for mask in self.edge_type_masks] + self.edge_direction_masks = [mask.to(device) for mask in self.edge_direction_masks] + self.MLPs = nn.ModuleList() + for _ in range(self.num_hops): + self.MLPs.append(MLP(input_size=self.latent_dim, num_layers=self.num_layers, + hidden_size=self.latent_dim, output_size=self.latent_dim)) + + self.edge_type_W = nn.ModuleList() + for _ in range(self.num_edge_types): + ml_edge_type = nn.ModuleList() + for _ in range(self.num_hops): + ml_hop = nn.ModuleList() + for _ in range(2): # 2 directions of edges + ml_hop.append(nn.Linear(latent_dim, latent_dim, bias=False)) + ml_edge_type.append(ml_hop) + self.edge_type_W.append(ml_edge_type) + self.const_nodes_free_params = 
nn.Parameter(nn.init.kaiming_uniform_(torch.zeros(self.num_ents, free_dim))) + # load Factorized Posterior + if load_method == 1: + self.params_u_R = nn.ModuleList() + self.params_W_R = nn.ModuleList() + self.params_V_R = nn.ModuleList() + for idx in range(self.num_rels): + rel = self.idx2rel[idx] + num_args = self.PRED_DICT[rel].num_args + self.params_W_R.append( + nn.Bilinear(num_args * args.embedding_size, num_args * args.embedding_size, slice_dim, bias=False)) + self.params_V_R.append(nn.Linear(num_args * args.embedding_size, slice_dim, bias=True)) + self.params_u_R.append(nn.Linear(slice_dim, 1, bias=False)) + elif load_method == 0: + self.params_u_R = nn.ParameterList() + self.params_W_R = nn.ModuleList() + self.params_V_R = nn.ModuleList() + self.params_b_R = nn.ParameterList() + for idx in range(self.num_rels): + rel = self.idx2rel[idx] + num_args = self.PRED_DICT[rel].num_args + self.params_u_R.append(nn.Parameter(nn.init.kaiming_uniform_(torch.zeros(slice_dim, 1)).view(-1))) + self.params_W_R.append( + nn.Bilinear(num_args * args.embedding_size, num_args * args.embedding_size, slice_dim, bias=False)) + self.params_V_R.append(nn.Linear(num_args * args.embedding_size, slice_dim, bias=False)) + self.params_b_R.append(nn.Parameter(nn.init.kaiming_uniform_(torch.zeros(slice_dim, 1)).view(-1))) + + # --- MLN --- + + self.rule_weights_lin = nn.Linear(len(rule_list), 1, bias=False) + self.num_rules = len(rule_list) + self.soft_logic = False + + self.alpha_table = nn.Parameter(torch.tensor([10.0 for _ in range(len(self.PRED_DICT))], requires_grad=True)) + + self.predname2ind = dict(e for e in zip(self.PRED_DICT.keys(), range(len(self.PRED_DICT)))) + + if rule_weights_learning == 0: + self.rule_weights_lin.weight.data = torch.tensor([[rule.weight for rule in rule_list]], + dtype=torch.float) + print('rule weights fixed as pre-defined values\n') + else: + self.rule_weights_lin.weight = nn.Parameter( + torch.tensor([[rule.weight for rule in rule_list]], dtype=torch.float)) + print('rule weights set to pre-defined values, learning weights\n') + + def gcn_forward(self, batch_data): + if self.args.use_gcn == 0: + node_embeds = self.ent_embeds(self.ents) + return node_embeds + else: + node_embeds = self.init_node_linear(self.node_feat) + + hop = 0 + hidden = node_embeds + while hop < self.num_hops: + node_aggregate = torch.zeros_like(hidden) + for edge_type in set(self.graph.edge_types): + for direction in range(2): + W = self.edge_type_W[edge_type][hop][direction] + W_nodes = W(hidden) + nodes_attached_on_edges_out = torch.gather(W_nodes, 0, self.edge2node_out) + nodes_attached_on_edges_out *= self.edge_type_masks[edge_type].view(-1, 1) + nodes_attached_on_edges_out *= self.edge_direction_masks[direction].view(-1, 1) + node_aggregate.scatter_add_(0, self.edge2node_in, nodes_attached_on_edges_out) + + hidden = self.MLPs[hop](hidden + node_aggregate) + hop += 1 + + read_out_const_nodes_embed = torch.cat((hidden[self.const_nodes], self.const_nodes_free_params), dim=1) + + return read_out_const_nodes_embed + + def posterior_forward(self, latent_vars, node_embeds, batch_mode=False, fast_mode=False, fast_inference_mode=False): + """ + compute posterior probabilities of specified latent variables + + :param latent_vars: + list of latent variables (i.e. 
unobserved facts) + :param node_embeds: + node embeddings + :return: + n-dim vector, probability of corresponding latent variable being True + + Parameters + ---------- + fast_inference_mode + fast_mode + batch_mode + """ + + # this mode is only for fast inference on Freebase data + if fast_inference_mode: + assert self.load_method == 1 + + samples = latent_vars + scores = [] + + for ind in range(len(samples)): + pred_name, pred_sample = samples[ind] + + rel_idx = self.rel2idx[pred_name] + + sample_mat = torch.tensor(pred_sample, dtype=torch.long).to(self.args.device) # (bsize, 2) + + sample_query = torch.cat([node_embeds[sample_mat[:, 0]], node_embeds[sample_mat[:, 1]]], dim=1) + + sample_score = self.params_u_R[rel_idx]( + torch.tanh(self.params_W_R[rel_idx](sample_query, sample_query) + + self.params_V_R[rel_idx](sample_query))).view(-1) # (bsize) + scores.append(torch.sigmoid(sample_score)) + return scores + + # this mode is only for fast training on Freebase data + elif fast_mode: + + assert self.load_method == 1 + + samples, neg_mask, latent_mask, obs_var, neg_var = latent_vars + scores = [] + obs_probs = [] + neg_probs = [] + a = [] + for pred_mask in neg_mask: + a.append(pred_mask[1]) + pos_mask_mat = torch.tensor(a) + pos_mask_mat = pos_mask_mat.to(self.args.device) + neg_mask_mat = (pos_mask_mat == 0).type(torch.float) + latent_mask_mat = torch.tensor([pred_mask[1] for pred_mask in latent_mask], dtype=torch.float).to( + self.args.device) + obs_mask_mat = (latent_mask_mat == 0).type(torch.float) + for ind in range(len(samples)): + pred_name, pred_sample = samples[ind] + _, obs_sample = obs_var[ind] + _, neg_sample = neg_var[ind] + + rel_idx = self.rel2idx[pred_name] + + sample_mat = torch.tensor(pred_sample, dtype=torch.long).to(self.args.device) + obs_mat = torch.tensor(obs_sample, dtype=torch.long).to(self.args.device) + neg_mat = torch.tensor(neg_sample, dtype=torch.long).to(self.args.device) + + sample_mat = torch.cat([sample_mat, obs_mat, neg_mat], dim=0) + + sample_query = torch.cat([node_embeds[sample_mat[:, 0]], node_embeds[sample_mat[:, 1]]], dim=1) + + sample_score = self.params_u_R[rel_idx]( + torch.tanh(self.params_W_R[rel_idx](sample_query, sample_query) + + self.params_V_R[rel_idx](sample_query))).view(-1) + var_prob = sample_score[len(pred_sample):] + obs_prob = var_prob[:len(obs_sample)] + neg_prob = var_prob[len(obs_sample):] + sample_score = sample_score[:len(pred_sample)] + + scores.append(sample_score) + obs_probs.append(obs_prob) + neg_probs.append(neg_prob) + score_mat = torch.stack(scores, dim=0) + score_mat = torch.sigmoid(score_mat) + + pos_score = (1 - score_mat) * pos_mask_mat + neg_score = score_mat * neg_mask_mat + + potential = 1 - ((pos_score + neg_score) * latent_mask_mat + obs_mask_mat).prod(dim=0) + + obs_mat = torch.cat(obs_probs, dim=0) + + if obs_mat.size(0) == 0: + obs_loss = 0.0 + else: + obs_loss = self.xent_loss(obs_mat, torch.ones_like(obs_mat), reduction='sum') + + neg_mat = torch.cat(neg_probs, dim=0) + if neg_mat.size(0) != 0: + obs_loss += self.xent_loss(obs_mat, torch.zeros_like(neg_mat), reduction='sum') + + obs_loss /= (obs_mat.size(0) + neg_mat.size(0) + 1e-6) + return potential, (score_mat * latent_mask_mat).view(-1), obs_loss + + elif batch_mode: + assert self.load_method == 1 + + pred_name, x_mat, invx_mat, sample_mat = latent_vars + + rel_idx = self.rel2idx[pred_name] + + x_mat = torch.tensor(x_mat, dtype=torch.long).to(self.args.device) + invx_mat = torch.tensor(invx_mat, dtype=torch.long).to(self.args.device) + sample_mat = 
torch.tensor(sample_mat, dtype=torch.long).to(self.args.device) + + tail_query = torch.cat([node_embeds[x_mat[:, 0]], node_embeds[x_mat[:, 1]]], dim=1) + head_query = torch.cat([node_embeds[invx_mat[:, 0]], node_embeds[invx_mat[:, 1]]], dim=1) + true_query = torch.cat([node_embeds[sample_mat[:, 0]], node_embeds[sample_mat[:, 1]]], dim=1) + + tail_score = self.params_u_R[rel_idx](torch.tanh(self.params_W_R[rel_idx](tail_query, tail_query) + + self.params_V_R[rel_idx](tail_query))).view(-1) + + head_score = self.params_u_R[rel_idx](torch.tanh(self.params_W_R[rel_idx](head_query, head_query) + + self.params_V_R[rel_idx](head_query))).view(-1) + + true_score = self.params_u_R[rel_idx](torch.tanh(self.params_W_R[rel_idx](true_query, true_query) + + self.params_V_R[rel_idx](true_query))).view(-1) + + probas_tail = torch.sigmoid(tail_score) + probas_head = torch.sigmoid(head_score) + probas_true = torch.sigmoid(true_score) + return probas_tail, probas_head, probas_true + + else: + assert self.load_method == 0 + + probas = torch.zeros(len(latent_vars)).to(self.args.device) + for i in range(len(latent_vars)): + rel, args = latent_vars[i] + args_embed = torch.cat([node_embeds[self.ent2idx[arg]] for arg in args], 0) + rel_idx = self.rel2idx[rel] + + score = self.params_u_R[rel_idx].dot( + torch.tanh(self.params_W_R[rel_idx](args_embed, args_embed) + + self.params_V_R[rel_idx](args_embed) + + self.params_b_R[rel_idx]) + ) + proba = torch.sigmoid(score) + probas[i] = proba + return probas + + def mln_forward(self, neg_mask_ls_ls, latent_var_inds_ls_ls, observed_rule_cnts, posterior_prob, flat_list, + observed_vars_ls_ls): + """ + compute the MLN potential given the posterior probability of latent variables + :param neg_mask_ls_ls: + + :return: + + Parameters + ---------- + flat_list + posterior_prob + observed_vars_ls_ls + latent_var_inds_ls_ls + observed_rule_cnts + """ + + scores = torch.zeros(self.num_rules, dtype=torch.float, device=self.args.device) + pred_ind_flat_list = [] + if self.soft_logic: + pred_name_ls = [e[0] for e in flat_list] + pred_ind_flat_list = [self.predname2ind[pred_name] for pred_name in pred_name_ls] + + for i in range(len(neg_mask_ls_ls)): + neg_mask_ls = neg_mask_ls_ls[i] + latent_var_inds_ls = latent_var_inds_ls_ls[i] + observed_vars_ls = observed_vars_ls_ls[i] + + # sum of scores from gnd rules with latent vars + for j in range(len(neg_mask_ls)): + + latent_neg_mask, observed_neg_mask = neg_mask_ls[j] + latent_var_inds = latent_var_inds_ls[j] + observed_vars = observed_vars_ls[j] + + z_probs = posterior_prob[latent_var_inds].unsqueeze(0) + + z_probs = torch.cat([1 - z_probs, z_probs], dim=0) + + cartesian_prod = z_probs[:, 0] + for j in range(1, z_probs.shape[1]): + cartesian_prod = torch.ger(cartesian_prod, z_probs[:, j]) + cartesian_prod = cartesian_prod.view(-1) + + view_ls = [2 for _ in range(len(latent_neg_mask))] + cartesian_prod = cartesian_prod.view(*[view_ls]) + + if self.soft_logic: + + # observed alpha + obs_vals = [e[0] for e in observed_vars] + pred_names = [e[1] for e in observed_vars] + pred_inds = [self.predname2ind[pn] for pn in pred_names] + alpha = self.alpha_table[pred_inds] # alphas in this formula + act_alpha = torch.sigmoid(alpha) + obs_neg_flag = [(1 if observed_vars[i] != observed_neg_mask[i] else 0) + for i in range(len(observed_vars))] + tn_obs_neg_flag = torch.tensor(obs_neg_flag, dtype=torch.float) + + val = torch.abs(1 - torch.tensor(obs_vals, dtype=torch.float) - act_alpha) + obs_score = torch.abs(tn_obs_neg_flag - val) + + # latent alpha + 
inds = product(*[[0, 1] for _ in range(len(latent_neg_mask))]) + pred_inds = [pred_ind_flat_list[i] for i in latent_var_inds] + alpha = self.alpha_table[pred_inds] # alphas in this formula + act_alpha = torch.sigmoid(alpha) + tn_latent_neg_mask = torch.tensor(latent_neg_mask, dtype=torch.float) + + for ind in inds: + val = torch.abs(1 - torch.tensor(ind, dtype=torch.float) - act_alpha) + val = torch.abs(tn_latent_neg_mask - val) + cartesian_prod[tuple(ind)] *= torch.max(torch.cat([val, obs_score], dim=0)) + + else: + + if sum(observed_neg_mask) == 0: + cartesian_prod[tuple(latent_neg_mask)] = 0.0 + + scores[i] += cartesian_prod.sum() + + # sum of scores from gnd rule with only observed vars + scores[i] += observed_rule_cnts[i] + + return self.rule_weights_lin(scores) + + def gen_edge2node_mapping(self): + """ + A GCN's function + Returns + ------- + + """ + ei = 0 # edge index with direction + edge_idx = 0 # edge index without direction + edge2node_in = torch.zeros(self.num_edges * 2, dtype=torch.long) + edge2node_out = torch.zeros(self.num_edges * 2, dtype=torch.long) + node_degree = torch.zeros(self.num_nodes) + + edge_type_masks = [] + for _ in range(self.num_edge_types): + edge_type_masks.append(torch.zeros(self.num_edges * 2)) + edge_direction_masks = [] + for _ in range(2): # 2 directions of edges + edge_direction_masks.append(torch.zeros(self.num_edges * 2)) + + for ni, nj in torch.as_tensor(self.graph.edge_pairs): + edge_type = self.graph.edge_types[edge_idx] + edge_idx += 1 + + edge2node_in[ei] = nj + edge2node_out[ei] = ni + node_degree[ni] += 1 + edge_type_masks[edge_type][ei] = 1 + edge_direction_masks[0][ei] = 1 + ei += 1 + + edge2node_in[ei] = ni + edge2node_out[ei] = nj + node_degree[nj] += 1 + edge_type_masks[edge_type][ei] = 1 + edge_direction_masks[1][ei] = 1 + ei += 1 + + edge2node_in = edge2node_in.view(-1, 1).expand(-1, self.latent_dim) + edge2node_out = edge2node_out.view(-1, 1).expand(-1, self.latent_dim) + node_degree = node_degree.view(-1, 1) + return edge2node_in, edge2node_out, node_degree, edge_type_masks, edge_direction_masks + + def weight_update(self, neg_mask_ls_ls, latent_var_inds_ls_ls, observed_rule_cnts, posterior_prob, flat_list, + observed_vars_ls_ls): + """ + A MLN's Function + Parameters + ---------- + neg_mask_ls_ls + latent_var_inds_ls_ls + observed_rule_cnts + posterior_prob + flat_list + observed_vars_ls_ls + + Returns + ------- + + """ + closed_wolrd_potentials = torch.zeros(self.num_rules, dtype=torch.float) + pred_ind_flat_list = [] + if self.soft_logic: + pred_name_ls = [e[0] for e in flat_list] + pred_ind_flat_list = [self.predname2ind[pred_name] for pred_name in pred_name_ls] + + for i in range(len(neg_mask_ls_ls)): + neg_mask_ls = neg_mask_ls_ls[i] + latent_var_inds_ls = latent_var_inds_ls_ls[i] + observed_vars_ls = observed_vars_ls_ls[i] + + # sum of scores from gnd rules with latent vars + for j in range(len(neg_mask_ls)): + + latent_neg_mask, observed_neg_mask = neg_mask_ls[j] + latent_var_inds = latent_var_inds_ls[j] + observed_vars = observed_vars_ls[j] + + has_pos_atom = False + for val in observed_neg_mask + latent_neg_mask: + if val == 1: + has_pos_atom = True + break + + if has_pos_atom: + closed_wolrd_potentials[i] += 1 + + z_probs = posterior_prob[latent_var_inds].unsqueeze(0) + + z_probs = torch.cat([1 - z_probs, z_probs], dim=0) + + cartesian_prod = z_probs[:, 0] + for j in range(1, z_probs.shape[1]): + cartesian_prod = torch.ger(cartesian_prod, z_probs[:, j]) + cartesian_prod = cartesian_prod.view(-1) + + view_ls = [2 for _ 
in range(len(latent_neg_mask))] + cartesian_prod = cartesian_prod.view(*[view_ls]) + + if self.soft_logic: + + # observed alpha + obs_vals = [e[0] for e in observed_vars] + pred_names = [e[1] for e in observed_vars] + pred_inds = [self.predname2ind[pn] for pn in pred_names] + alpha = self.alpha_table[pred_inds] # alphas in this formula + act_alpha = torch.sigmoid(alpha) + obs_neg_flag = [(1 if observed_vars[i] != observed_neg_mask[i] else 0) + for i in range(len(observed_vars))] + tn_obs_neg_flag = torch.tensor(obs_neg_flag, dtype=torch.float) + + val = torch.abs(1 - torch.tensor(obs_vals, dtype=torch.float) - act_alpha) + obs_score = torch.abs(tn_obs_neg_flag - val) + + # latent alpha + inds = product(*[[0, 1] for _ in range(len(latent_neg_mask))]) + pred_inds = [pred_ind_flat_list[i] for i in latent_var_inds] + alpha = self.alpha_table[pred_inds] # alphas in this formula + act_alpha = torch.sigmoid(alpha) + tn_latent_neg_mask = torch.tensor(latent_neg_mask, dtype=torch.float) + + for ind in inds: + val = torch.abs(1 - torch.tensor(ind, dtype=torch.float) - act_alpha) + val = torch.abs(tn_latent_neg_mask - val) + cartesian_prod[tuple(ind)] *= torch.max(torch.cat([val, obs_score], dim=0)) + + else: + + if sum(observed_neg_mask) == 0: + cartesian_prod[tuple(latent_neg_mask)] = 0.0 + + weight_grad = closed_wolrd_potentials + + return weight_grad + + def gen_index(self, facts, predicates, dataset): + rel2idx = dict() + idx_rel = 0 + for rel in sorted(predicates.keys()): + if rel not in rel2idx: + rel2idx[rel] = idx_rel + idx_rel += 1 + idx2rel = dict(zip(rel2idx.values(), rel2idx.keys())) + + ent2idx = dict() + idx_ent = 0 + for type_name in sorted(dataset.const_sort_dict.keys()): + for const in dataset.const_sort_dict[type_name]: + ent2idx[const] = idx_ent + idx_ent += 1 + idx2ent = dict(zip(ent2idx.values(), ent2idx.keys())) + + node2idx = ent2idx.copy() + idx_node = len(node2idx) + for rel in sorted(facts.keys()): + for fact in sorted(list(facts[rel])): + val, args = fact + if (rel, args) not in node2idx: + node2idx[(rel, args)] = idx_node + idx_node += 1 + idx2node = dict(zip(node2idx.values(), node2idx.keys())) + + return ent2idx, idx2ent, rel2idx, idx2rel, node2idx, idx2node + + def gen_edge_type(self): + edge_type2idx = dict() + num_args_set = set() + for rel in self.PRED_DICT: + num_args = self.PRED_DICT[rel].num_args + num_args_set.add(num_args) + idx = 0 + for num_args in sorted(list(num_args_set)): + for pos_code in product(['0', '1'], repeat=num_args): + if '1' in pos_code: + edge_type2idx[(0, ''.join(pos_code))] = idx + idx += 1 + edge_type2idx[(1, ''.join(pos_code))] = idx + idx += 1 + return edge_type2idx + + def gen_graph(self, facts, predicates, dataset): + """ + generate directed knowledge graph, where each edge is from subject to object + :param facts: + dictionary of facts + :param predicates: + dictionary of predicates + :param dataset: + dataset object + :return: + graph object, entity to index, index to entity, relation to index, index to relation + """ + + # build bipartite graph (constant nodes and hyper predicate nodes) + g = nx.Graph() + ent2idx, idx2ent, rel2idx, idx2rel, node2idx, idx2node = self.gen_index(facts, predicates, dataset) + + edge_type2idx = self.gen_edge_type() + + for node_idx in idx2node: + g.add_node(node_idx) + + for rel in facts.keys(): + for fact in facts[rel]: + val, args = fact + fact_node_idx = node2idx[(rel, args)] + for arg in args: + pos_code = ''.join(['%d' % (arg == v) for v in args]) + g.add_edge(fact_node_idx, node2idx[arg], + 
edge_type=edge_type2idx[(val, pos_code)]) + return g, edge_type2idx, ent2idx, idx2ent, rel2idx, idx2rel, node2idx, idx2node + + def prepare_node_feature(self, graph, transductive=True): + if transductive: + node_feat = torch.zeros(graph.num_nodes, # for transductive GCN + graph.num_ents + graph.num_rels) + + const_nodes = [] + for i in graph.idx2node: + if isinstance(graph.idx2node[i], str): # const (entity) node + const_nodes.append(i) + node_feat[i][i] = 1 + elif isinstance(graph.idx2node[i], tuple): # fact node + rel, args = graph.idx2node[i] + node_feat[i][graph.num_ents + graph.rel2idx[rel]] = 1 + else: + node_feat = torch.zeros(graph.num_nodes, 1 + graph.num_rels) # for inductive GCN + const_nodes = [] + for i in graph.idx2node: + if isinstance(graph.idx2node[i], str): # const (entity) node + node_feat[i][0] = 1 + const_nodes.append(i) + elif isinstance(graph.idx2node[i], tuple): # fact node + rel, args = graph.idx2node[i] + node_feat[i][1 + graph.rel2idx[rel]] = 1 + + return node_feat, torch.LongTensor(const_nodes) + + +class MLP(nn.Module): + def __init__(self, input_size, num_layers, hidden_size, output_size): + super(MLP, self).__init__() + + self.input_linear = nn.Linear(input_size, hidden_size) + + self.hidden = nn.ModuleList() + for _ in range(num_layers - 1): + self.hidden.append(nn.Linear(hidden_size, hidden_size)) + + self.output_linear = nn.Linear(hidden_size, output_size) + + def forward(self, x): + h = F.relu(self.input_linear(x)) + + for layer in self.hidden: + h = F.relu(layer(h)) + + output = self.output_linear(h) + + return output diff --git a/openhgnn/models/Grail.py b/openhgnn/models/Grail.py new file mode 100644 index 00000000..6f93a9fe --- /dev/null +++ b/openhgnn/models/Grail.py @@ -0,0 +1,291 @@ +import os + +import numpy as np +import torch +import torch.nn as nn +import dgl.function as fn +import torch.nn.functional as F +from . 
import BaseModel, register_model +import torch.nn.functional as F +from torch.nn import Identity +from dgl import mean_nodes +import abc + +@register_model('Grail') +class Grail(BaseModel): + @classmethod + def build_model_from_args(cls, args, relation2id): + return cls(args,relation2id) + + def __init__(self, args, relation2id): + super(Grail, self).__init__() + self.params = args + self.relation2id = relation2id + self.gnn = RGCN(args) # in_dim, h_dim, h_dim, num_rels, num_bases) + self.rel_emb = nn.Embedding(self.params.num_rels, self.params.rel_emb_dim, sparse=False) + + if self.params.add_ht_emb: + self.fc_layer = nn.Linear(3 * self.params.num_gcn_layers * self.params.emb_dim + self.params.rel_emb_dim, 1) + else: + self.fc_layer = nn.Linear(self.params.num_gcn_layers * self.params.emb_dim + self.params.rel_emb_dim, 1) + + def forward(self,hg): + g, rel_labels = hg + g.ndata['h'] = self.gnn(g) + + g_out = mean_nodes(g, 'repr') + + head_ids = (g.ndata['id'] == 1).nonzero().squeeze(1) + head_embs = g.ndata['repr'][head_ids] + + tail_ids = (g.ndata['id'] == 2).nonzero().squeeze(1) + tail_embs = g.ndata['repr'][tail_ids] + + if self.params.add_ht_emb: + g_rep = torch.cat([g_out.view(-1, self.params.num_gcn_layers * self.params.emb_dim), + head_embs.view(-1, self.params.num_gcn_layers * self.params.emb_dim), + tail_embs.view(-1, self.params.num_gcn_layers * self.params.emb_dim), + self.rel_emb(rel_labels)], dim=1) + else: + g_rep = torch.cat( + [g_out.view(-1, self.params.num_gcn_layers * self.params.emb_dim), self.rel_emb(rel_labels)], dim=1) + + output = self.fc_layer(g_rep) + return output + + + + +class RGCN(nn.Module): + def __init__(self, params): + super(RGCN, self).__init__() + + self.max_label_value = params.max_label_value + self.inp_dim = params.inp_dim + self.emb_dim = params.emb_dim + self.attn_rel_emb_dim = params.attn_rel_emb_dim + self.num_rels = params.num_rels + self.aug_num_rels = params.aug_num_rels + self.num_bases = params.num_bases + self.num_hidden_layers = params.num_gcn_layers + self.dropout = params.dropout + self.edge_dropout = params.edge_dropout + # self.aggregator_type = params.gnn_agg_type + self.has_attn = params.has_attn + + self.device = params.device + + if self.has_attn: + self.attn_rel_emb = nn.Embedding(self.num_rels, self.attn_rel_emb_dim, sparse=False) + else: + self.attn_rel_emb = None + + if params.gnn_agg_type == "sum": + self.aggregator = SumAggregator(self.emb_dim) + elif params.gnn_agg_type == "mlp": + self.aggregator = MLPAggregator(self.emb_dim) + elif params.gnn_agg_type == "gru": + self.aggregator = GRUAggregator(self.emb_dim) + + self.layers = nn.ModuleList() + #input layer + self.layers.append(RGCNBasisLayer(self.inp_dim, + self.emb_dim, + # self.input_basis_weights, + self.aggregator, + self.attn_rel_emb_dim, + self.aug_num_rels, + self.num_bases, + activation=F.relu, + dropout=self.dropout, + edge_dropout=self.edge_dropout, + is_input_layer=True, + has_attn=self.has_attn)) + #hidden layer + for idx in range(self.num_hidden_layers - 1): + self.layers.append(RGCNBasisLayer(self.emb_dim, + self.emb_dim, + # self.basis_weights, + self.aggregator, + self.attn_rel_emb_dim, + self.aug_num_rels, + self.num_bases, + activation=F.relu, + dropout=self.dropout, + edge_dropout=self.edge_dropout, + has_attn=self.has_attn)) + + + def forward(self,g): + for layer in self.layers: + layer(g, self.attn_rel_emb) + return g.ndata.pop('h') + +class RGCNLayer(nn.Module): + def __init__(self, inp_dim, out_dim, aggregator, bias=None, activation=None, 
dropout=0.0, edge_dropout=0.0, is_input_layer=False): + super(RGCNLayer, self).__init__() + self.bias = bias + self.activation = activation + + if self.bias: + self.bias = nn.Parameter(torch.Tensor(out_dim)) + nn.init.xavier_uniform_(self.bias, + gain=nn.init.calculate_gain('relu')) + + self.aggregator = aggregator + + if dropout: + self.dropout = nn.Dropout(dropout) + else: + self.dropout = None + + if edge_dropout: + self.edge_dropout = nn.Dropout(edge_dropout) + else: + self.edge_dropout = Identity()  # note: Identity() is used here, which differs slightly from the original model + + # define how propagation is done in subclass + def propagate(self, g): + raise NotImplementedError + + def forward(self, g, attn_rel_emb=None): + + self.propagate(g, attn_rel_emb) + + # apply bias and activation + node_repr = g.ndata['h'] + if self.bias: + node_repr = node_repr + self.bias + if self.activation: + node_repr = self.activation(node_repr) + if self.dropout: + node_repr = self.dropout(node_repr) + + g.ndata['h'] = node_repr + + if self.is_input_layer: + g.ndata['repr'] = g.ndata['h'].unsqueeze(1) + else: + g.ndata['repr'] = torch.cat([g.ndata['repr'], g.ndata['h'].unsqueeze(1)], dim=1) + +class RGCNBasisLayer(RGCNLayer): + def __init__(self, inp_dim, out_dim, aggregator, attn_rel_emb_dim, num_rels, num_bases=-1, bias=None, + activation=None, dropout=0.0, edge_dropout=0.0, is_input_layer=False, has_attn=False): + super( + RGCNBasisLayer, + self).__init__( + inp_dim, + out_dim, + aggregator, + bias, + activation, + dropout=dropout, + edge_dropout=edge_dropout, + is_input_layer=is_input_layer) + self.inp_dim = inp_dim + self.out_dim = out_dim + self.attn_rel_emb_dim = attn_rel_emb_dim + self.num_rels = num_rels + self.num_bases = num_bases + self.is_input_layer = is_input_layer + self.has_attn = has_attn + + if self.num_bases <= 0 or self.num_bases > self.num_rels: + self.num_bases = self.num_rels + + # add basis weights + # self.weight = basis_weights + self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.inp_dim, self.out_dim)) + self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases)) + + if self.has_attn: + self.A = nn.Linear(2 * self.inp_dim + 2 * self.attn_rel_emb_dim, inp_dim) + self.B = nn.Linear(inp_dim, 1) + + self.self_loop_weight = nn.Parameter(torch.Tensor(self.inp_dim, self.out_dim)) + + nn.init.xavier_uniform_(self.self_loop_weight, gain=nn.init.calculate_gain('relu')) + nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) + nn.init.xavier_uniform_(self.w_comp, gain=nn.init.calculate_gain('relu')) + + import multiprocessing  # unused import + + def propagate(self, g, attn_rel_emb=None): + # generate all weights from bases + #torch.cuda.init() + weight = self.weight.view(self.num_bases, + self.inp_dim * self.out_dim) + weight = torch.matmul(self.w_comp, weight).view( + self.num_rels, self.inp_dim, self.out_dim) + g = g.to(weight.device) + g.edata['w'] = self.edge_dropout(torch.ones(g.number_of_edges(), 1).to(weight.device)) + + input_ = 'feat' if self.is_input_layer else 'h' + + def msg_func(edges): + w = weight.index_select(0, edges.data['type']) + msg = edges.data['w'] * torch.bmm(edges.src[input_].unsqueeze(1), w).squeeze(1) + curr_emb = torch.mm(edges.dst[input_], self.self_loop_weight) # (B, F) + + if self.has_attn: + e = torch.cat([edges.src[input_], edges.dst[input_], attn_rel_emb(edges.data['type']), attn_rel_emb(edges.data['label'])], dim=1) + a = torch.sigmoid(self.B(F.relu(self.A(e)))) + else: + a = torch.ones((len(edges), 1)).to(device=w.device) + + return {'curr_emb': curr_emb, 'msg': msg, 'alpha': a} 
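+        # Note (added for clarity; not part of the original patch): `msg_func` runs once per edge.
+        # It selects the basis-composed weight for the edge's relation type and produces
+        #   msg      = w_edge * (h_src @ W_rel)          -- relation-specific message
+        #   curr_emb = h_dst @ W_self                    -- self-loop term
+        #   alpha    = sigmoid(B(relu(A(...)))) or ones  -- per-edge attention weight
+        # The aggregator invoked below reads these three mailbox fields; e.g. SumAggregator
+        # computes  h_new = sum_e(alpha_e * msg_e) + curr_emb.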
+ + g.update_all(msg_func, self.aggregator, None) + + +class Aggregator(nn.Module): + def __init__(self, emb_dim): + super(Aggregator, self).__init__() + + def forward(self, node): + curr_emb = node.mailbox['curr_emb'][:, 0, :] # (B, F) + nei_msg = torch.bmm(node.mailbox['alpha'].transpose(1, 2), node.mailbox['msg']).squeeze(1) # (B, F) + # nei_msg, _ = torch.max(node.mailbox['msg'], 1) # (B, F) + + new_emb = self.update_embedding(curr_emb, nei_msg) + + return {'h': new_emb} + + @abc.abstractmethod + def update_embedding(curr_emb, nei_msg): + raise NotImplementedError + + +class SumAggregator(Aggregator): + def __init__(self, emb_dim): + super(SumAggregator, self).__init__(emb_dim) + + def update_embedding(self, curr_emb, nei_msg): + new_emb = nei_msg + curr_emb + + return new_emb + + +class MLPAggregator(Aggregator): + def __init__(self, emb_dim): + super(MLPAggregator, self).__init__(emb_dim) + self.linear = nn.Linear(2 * emb_dim, emb_dim) + + def update_embedding(self, curr_emb, nei_msg): + inp = torch.cat((nei_msg, curr_emb), 1) + new_emb = F.relu(self.linear(inp)) + + return new_emb + + +class GRUAggregator(Aggregator): + def __init__(self, emb_dim): + super(GRUAggregator, self).__init__(emb_dim) + self.gru = nn.GRUCell(emb_dim, emb_dim) + + def update_embedding(self, curr_emb, nei_msg): + new_emb = self.gru(nei_msg, curr_emb) + + return new_emb + + diff --git a/openhgnn/models/Ingram.py b/openhgnn/models/Ingram.py new file mode 100644 index 00000000..74a5e9cd --- /dev/null +++ b/openhgnn/models/Ingram.py @@ -0,0 +1,253 @@ + +from . import BaseModel, register_model +import torch +import torch.nn as nn + + +@register_model('Ingram') +class Ingram(BaseModel): + @classmethod + def build_model_from_args(cls, config): + return cls(config) + + def __init__(self, config): + super().__init__() + + self.model = Model(config) + + def forward(self, *args): + return self.model(*args) + + def extra_loss(self): + pass + + +class Model(nn.Module): + + def __init__(self, args): + super().__init__() + self.args = args + layers_ent = [] + layers_rel = [] + layer_dim_ent = self.args.hdr_e * self.args.d_e + layer_dim_rel = self.args.hdr_r * self.args.d_r + num_layer_ent = self.args.nle + num_head = self.args.num_head + num_layer_rel = self.args.nlr + num_bin = self.args.num_bin + dim_ent = self.args.d_e + dim_rel = self.args.d_r + bias = True + for _ in range(num_layer_ent): + layers_ent.append(InGramEntityLayer(layer_dim_ent, layer_dim_ent, layer_dim_rel, \ + bias=bias, num_head=num_head)) + for _ in range(num_layer_rel): + layers_rel.append(InGramRelationLayer(layer_dim_rel, layer_dim_rel, num_bin, \ + bias=bias, num_head=num_head)) + res_proj_ent = [] + for _ in range(num_layer_ent): + res_proj_ent.append(nn.Linear(layer_dim_ent, layer_dim_ent, bias=bias)) + + res_proj_rel = [] + for _ in range(num_layer_rel): + res_proj_rel.append(nn.Linear(layer_dim_rel, layer_dim_rel, bias=bias)) + + self.res_proj_ent = nn.ModuleList(res_proj_ent) + self.res_proj_rel = nn.ModuleList(res_proj_rel) + self.bias = bias + self.ent_proj1 = nn.Linear(dim_ent, layer_dim_ent, bias=bias) + self.ent_proj2 = nn.Linear(layer_dim_ent, dim_ent, bias=bias) + self.layers_ent = nn.ModuleList(layers_ent) + self.layers_rel = nn.ModuleList(layers_rel) + + self.rel_proj1 = nn.Linear(dim_rel, layer_dim_rel, bias=bias) + self.rel_proj2 = nn.Linear(layer_dim_rel, dim_rel, bias=bias) + self.rel_proj = nn.Linear(dim_rel, dim_ent, bias=bias) + self.num_layer_ent = num_layer_ent + self.num_layer_rel = num_layer_rel + self.act = nn.ReLU() + + 
self.param_init() + + def param_init(self): + nn.init.xavier_normal_(self.ent_proj1.weight, gain=nn.init.calculate_gain('relu')) + nn.init.xavier_normal_(self.ent_proj2.weight, gain=nn.init.calculate_gain('relu')) + nn.init.xavier_normal_(self.rel_proj1.weight, gain=nn.init.calculate_gain('relu')) + nn.init.xavier_normal_(self.rel_proj2.weight, gain=nn.init.calculate_gain('relu')) + nn.init.xavier_normal_(self.rel_proj.weight, gain=nn.init.calculate_gain('relu')) + for layer_idx in range(self.num_layer_ent): + nn.init.xavier_normal_(self.res_proj_ent[layer_idx].weight, gain=nn.init.calculate_gain('relu')) + for layer_idx in range(self.num_layer_rel): + nn.init.xavier_normal_(self.res_proj_rel[layer_idx].weight, gain=nn.init.calculate_gain('relu')) + if self.bias: + nn.init.zeros_(self.ent_proj1.bias) + nn.init.zeros_(self.ent_proj2.bias) + nn.init.zeros_(self.rel_proj1.bias) + nn.init.zeros_(self.rel_proj2.bias) + nn.init.zeros_(self.rel_proj.bias) + for layer_idx in range(self.num_layer_ent): + nn.init.zeros_(self.res_proj_ent[layer_idx].bias) + for layer_idx in range(self.num_layer_rel): + nn.init.zeros_(self.res_proj_rel[layer_idx].bias) + + def forward(self, emb_ent, emb_rel, triplets, relation_triplets): + + layer_emb_ent = self.ent_proj1(emb_ent) + layer_emb_rel = self.rel_proj1(emb_rel) + + for layer_idx, layer in enumerate(self.layers_rel): + layer_emb_rel = layer(layer_emb_rel, relation_triplets) + \ + self.res_proj_rel[layer_idx](layer_emb_rel) + layer_emb_rel = self.act(layer_emb_rel) + + for layer_idx, layer in enumerate(self.layers_ent): + layer_emb_ent = layer(layer_emb_ent, layer_emb_rel, triplets) + \ + self.res_proj_ent[layer_idx](layer_emb_ent) + layer_emb_ent = self.act(layer_emb_ent) + + return self.ent_proj2(layer_emb_ent), self.rel_proj2(layer_emb_rel) + + def score(self, emb_ent, emb_rel, triplets): + + head_idxs = triplets[..., 0] + rel_idxs = triplets[..., 1] + tail_idxs = triplets[..., 2] + head_embs = emb_ent[head_idxs] + tail_embs = emb_ent[tail_idxs] + rel_embs = self.rel_proj(emb_rel[rel_idxs]) + output = (head_embs * rel_embs * tail_embs).sum(dim=-1) + return output + + +class InGramEntityLayer(nn.Module): + def __init__(self, dim_in_ent, dim_out_ent, dim_rel, bias=True, num_head=8): + super(InGramEntityLayer, self).__init__() + + self.dim_out_ent = dim_out_ent + self.dim_hid_ent = dim_out_ent // num_head + assert dim_out_ent == self.dim_hid_ent * num_head + self.num_head = num_head + + self.attn_proj = nn.Linear(2 * dim_in_ent + dim_rel, dim_out_ent, bias=bias) + self.attn_vec = nn.Parameter(torch.zeros((1, num_head, self.dim_hid_ent))) + self.aggr_proj = nn.Linear(dim_in_ent + dim_rel, dim_out_ent, bias=bias) + + self.dim_rel = dim_rel + self.act = nn.LeakyReLU(negative_slope=0.2) + self.bias = bias + self.param_init() + + def param_init(self): + nn.init.xavier_normal_(self.attn_proj.weight, gain=nn.init.calculate_gain('relu')) + nn.init.xavier_normal_(self.attn_vec, gain=nn.init.calculate_gain('relu')) + nn.init.xavier_normal_(self.aggr_proj.weight, gain=nn.init.calculate_gain('relu')) + if self.bias: + nn.init.zeros_(self.attn_proj.bias) + nn.init.zeros_(self.aggr_proj.bias) + + def forward(self, emb_ent, emb_rel, triplets): + num_ent = len(emb_ent) + num_rel = len(emb_rel) + head_idxs = triplets[..., 0] + rel_idxs = triplets[..., 1] + tail_idxs = triplets[..., 2] + + ent_freq = torch.zeros((num_ent,)).cuda().index_add(dim=0, index=tail_idxs, \ + source=torch.ones_like(tail_idxs, + dtype=torch.float).cuda()).unsqueeze( + dim=1) + + self_rel = 
torch.zeros((num_ent, self.dim_rel)).cuda().index_add(dim=0, index=tail_idxs, + source=emb_rel[rel_idxs]) / ent_freq + + # add self-loops + emb_rels = torch.cat([emb_rel[rel_idxs], self_rel], dim=0) + head_idxs = torch.cat([head_idxs, torch.arange(num_ent).cuda()], dim=0) + tail_idxs = torch.cat([tail_idxs, torch.arange(num_ent).cuda()], dim=0) + + concat_mat_att = torch.cat([emb_ent[tail_idxs], emb_ent[head_idxs], \ + emb_rels], dim=-1) + + attn_val_raw = (self.act(self.attn_proj(concat_mat_att).view(-1, self.num_head, self.dim_hid_ent)) * + self.attn_vec).sum(dim=-1, keepdim=True) + + scatter_idx = tail_idxs.unsqueeze(dim=-1).repeat(1, self.num_head).unsqueeze(dim=-1) + + attn_val_max = torch.zeros((num_ent, self.num_head, 1)).cuda().scatter_reduce(dim=0, \ + index=scatter_idx, \ + src=attn_val_raw, reduce='amax', \ + include_self=False) + attn_val = torch.exp(attn_val_raw - attn_val_max[tail_idxs]) + + attn_sums = torch.zeros((num_ent, self.num_head, 1)).cuda().index_add(dim=0, index=tail_idxs, source=attn_val) + + beta = attn_val / (attn_sums[tail_idxs] + 1e-16) + + concat_mat = torch.cat([emb_ent[head_idxs], emb_rels], dim=-1) + + aggr_val = beta * self.aggr_proj(concat_mat).view(-1, self.num_head, self.dim_hid_ent) + + output = torch.zeros((num_ent, self.num_head, self.dim_hid_ent)).cuda().index_add(dim=0, index=tail_idxs, + source=aggr_val) + + return output.flatten(1, -1) + + +class InGramRelationLayer(nn.Module): + def __init__(self, dim_in_rel, dim_out_rel, num_bin, bias=True, num_head=8): + super(InGramRelationLayer, self).__init__() + + self.dim_out_rel = dim_out_rel + self.dim_hid_rel = dim_out_rel // num_head + assert dim_out_rel == self.dim_hid_rel * num_head + + self.attn_proj = nn.Linear(2 * dim_in_rel, dim_out_rel, bias=bias) + self.attn_bin = nn.Parameter(torch.zeros(num_bin, num_head, 1)) + self.attn_vec = nn.Parameter(torch.zeros(1, num_head, self.dim_hid_rel)) + self.aggr_proj = nn.Linear(dim_in_rel, dim_out_rel, bias=bias) + self.num_head = num_head + + self.act = nn.LeakyReLU(negative_slope=0.2) + self.num_bin = num_bin + self.bias = bias + + self.param_init() + + def param_init(self): + nn.init.xavier_normal_(self.attn_proj.weight, gain=nn.init.calculate_gain('relu')) + nn.init.xavier_normal_(self.attn_vec, gain=nn.init.calculate_gain('relu')) + nn.init.xavier_normal_(self.aggr_proj.weight, gain=nn.init.calculate_gain('relu')) + if self.bias: + nn.init.zeros_(self.attn_proj.bias) + nn.init.zeros_(self.aggr_proj.bias) + + def forward(self, emb_rel, relation_triplets): + num_rel = len(emb_rel) + + head_idxs = relation_triplets[..., 0] + tail_idxs = relation_triplets[..., 1] + concat_mat = torch.cat([emb_rel[head_idxs], emb_rel[tail_idxs]], dim=-1) + + attn_val_raw = (self.act(self.attn_proj(concat_mat).view(-1, self.num_head, self.dim_hid_rel)) * \ + self.attn_vec).sum(dim=-1, keepdim=True) + self.attn_bin[relation_triplets[..., 2]] + + scatter_idx = head_idxs.unsqueeze(dim=-1).repeat(1, self.num_head).unsqueeze(dim=-1) + + attn_val_max = torch.zeros((num_rel, self.num_head, 1)).cuda().scatter_reduce(dim=0, \ + index=scatter_idx, \ + src=attn_val_raw, reduce='amax', \ + include_self=False) + attn_val = torch.exp(attn_val_raw - attn_val_max[head_idxs]) + + attn_sums = torch.zeros((num_rel, self.num_head, 1)).cuda().index_add(dim=0, index=head_idxs, source=attn_val) + + beta = attn_val / (attn_sums[head_idxs] + 1e-16) + + output = torch.zeros((num_rel, self.num_head, self.dim_hid_rel)).cuda().index_add(dim=0, \ + index=head_idxs, + source=beta * self.aggr_proj( + 
emb_rel[tail_idxs]).view( + -1, self.num_head, + self.dim_hid_rel)) + + return output.flatten(1, -1) diff --git a/openhgnn/models/LTE.py b/openhgnn/models/LTE.py new file mode 100644 index 00000000..f6f83459 --- /dev/null +++ b/openhgnn/models/LTE.py @@ -0,0 +1,197 @@ + +from . import BaseModel, register_model +import torch +from torch import nn +from torch.nn import functional as F +@register_model('LTE') +class LTE(BaseModel): + + @classmethod + def build_model_from_args(cls, config): + return cls(config) + + def __init__(self, config): + super().__init__() + + self.model = TransE(config) + + def forward(self, *args): + return self.model(*args) + + def extra_loss(self): + pass + + + + +def get_param(shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param.data) + return param + + +class LTEModel(nn.Module): + def __init__(self, params=None): + super(LTEModel, self).__init__() + self.bceloss = torch.nn.BCELoss() + self.p = params + num_ents = self.p.num_ents + num_rels = self.p.num_rels + self.init_embed = get_param((num_ents, self.p.init_dim)) + self.device = "cuda" + + self.init_rel = get_param((num_rels * 2, self.p.init_dim)) + + self.bias = nn.Parameter(torch.zeros(num_ents)) + + self.h_ops_dict = nn.ModuleDict({ + 'p': nn.Linear(self.p.init_dim, self.p.gcn_dim, bias=False), + 'b': nn.BatchNorm1d(self.p.gcn_dim), + 'd': nn.Dropout(self.p.hid_drop), + 'a': nn.Tanh(), + }) + + self.t_ops_dict = nn.ModuleDict({ + 'p': nn.Linear(self.p.init_dim, self.p.gcn_dim, bias=False), + 'b': nn.BatchNorm1d(self.p.gcn_dim), + 'd': nn.Dropout(self.p.hid_drop), + 'a': nn.Tanh(), + }) + + self.r_ops_dict = nn.ModuleDict({ + 'p': nn.Linear(self.p.init_dim, self.p.gcn_dim, bias=False), + 'b': nn.BatchNorm1d(self.p.gcn_dim), + 'd': nn.Dropout(self.p.hid_drop), + 'a': nn.Tanh(), + }) + + self.x_ops = self.p.x_ops + self.r_ops = self.p.r_ops + self.diff_ht = False + + def calc_loss(self, pred, label): + return self.loss(pred, label) + + def loss(self, pred, true_label): + return self.bceloss(pred, true_label) + + def exop(self, x, r, x_ops=None, r_ops=None, diff_ht=False): + x_head = x_tail = x + if len(x_ops) > 0: + for x_op in x_ops.split("."): + if diff_ht: + x_head = self.h_ops_dict[x_op](x_head) + x_tail = self.t_ops_dict[x_op](x_tail) + else: + x_head = x_tail = self.h_ops_dict[x_op](x_head) + + if len(r_ops) > 0: + for r_op in r_ops.split("."): + r = self.r_ops_dict[r_op](r) + + return x_head, x_tail, r + + +class TransE(LTEModel): + def __init__(self, params=None): + super(self.__class__, self).__init__( params) + num_ents=params.num_ents + num_rels=params.num_rels + self.loop_emb = get_param([1, self.p.init_dim]) + + def forward(self,g, sub, rel): + x = self.init_embed + r = self.init_rel + + x_h, x_t, r = self.exop(x - self.loop_emb, r, self.x_ops, self.r_ops) + + sub_emb = torch.index_select(x_h, 0, sub) + rel_emb = torch.index_select(r, 0, rel) + all_ent = x_t + + obj_emb = sub_emb + rel_emb + x = self.p.gamma - \ + torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=2) + score = torch.sigmoid(x) + + return score + + +class DistMult(LTEModel): + def __init__(self, num_ents, num_rels, params=None): + super(self.__class__, self).__init__(num_ents, num_rels, params) + + def forward(self, g, sub, rel): + x = self.init_embed + r = self.init_rel + + x_h, x_t, r = self.exop(x, r, self.x_ops, self.r_ops) + + sub_emb = torch.index_select(x_h, 0, sub) + rel_emb = torch.index_select(r, 0, rel) + all_ent = x_t + + obj_emb = sub_emb * rel_emb + x = torch.mm(obj_emb, 
all_ent.transpose(1, 0)) + x += self.bias.expand_as(x) + score = torch.sigmoid(x) + + return score + + +class ConvE(LTEModel): + def __init__(self, num_ents, num_rels, params=None): + super(self.__class__, self).__init__(num_ents, num_rels, params) + self.bn0 = torch.nn.BatchNorm2d(1) + self.bn1 = torch.nn.BatchNorm2d(self.p.num_filt) + self.bn2 = torch.nn.BatchNorm1d(self.p.embed_dim) + + self.hidden_drop = torch.nn.Dropout(self.p.hid_drop) + self.hidden_drop2 = torch.nn.Dropout(self.p.conve_hid_drop) + self.feature_drop = torch.nn.Dropout(self.p.feat_drop) + self.m_conv1 = torch.nn.Conv2d(1, out_channels=self.p.num_filt, kernel_size=(self.p.ker_sz, self.p.ker_sz), + stride=1, padding=0, bias=self.p.bias) + + flat_sz_h = int(2 * self.p.k_w) - self.p.ker_sz + 1 + flat_sz_w = self.p.k_h - self.p.ker_sz + 1 + self.flat_sz = flat_sz_h * flat_sz_w * self.p.num_filt + self.fc = torch.nn.Linear(self.flat_sz, self.p.embed_dim) + + def concat(self, e1_embed, rel_embed): + e1_embed = e1_embed.view(-1, 1, self.p.embed_dim) + rel_embed = rel_embed.view(-1, 1, self.p.embed_dim) + stack_inp = torch.cat([e1_embed, rel_embed], 1) + stack_inp = torch.transpose(stack_inp, 2, 1).reshape( + (-1, 1, 2 * self.p.k_w, self.p.k_h)) + return stack_inp + + def forward(self, g, sub, rel): + x = self.init_embed + r = self.init_rel + + x_h, x_t, r = self.exop(x, r, self.x_ops, self.r_ops) + + sub_emb = torch.index_select(x_h, 0, sub) + rel_emb = torch.index_select(r, 0, rel) + all_ent = x_t + + stk_inp = self.concat(sub_emb, rel_emb) + x = self.bn0(stk_inp) + x = self.m_conv1(x) + x = self.bn1(x) + x = F.relu(x) + x = self.feature_drop(x) + x = x.view(-1, self.flat_sz) + x = self.fc(x) + x = self.hidden_drop2(x) + x = self.bn2(x) + x = F.relu(x) + + x = torch.mm(x, all_ent.transpose(1, 0)) + x += self.bias.expand_as(x) + + score = torch.sigmoid(x) + return score + + + diff --git a/openhgnn/models/LTE_Transe.py b/openhgnn/models/LTE_Transe.py new file mode 100644 index 00000000..a16cccc9 --- /dev/null +++ b/openhgnn/models/LTE_Transe.py @@ -0,0 +1,180 @@ +from . import BaseModel, register_model + +@register_model('LTE_Transe') +class LTE_Transe(BaseModel): + + @classmethod + def build_model_from_args(cls, config): + return cls(config) + + def __init__(self, config): + super().__init__() + + self.model = GCN_TransE(config) + + def forward(self, *args): + return self.model(*args) + + def extra_loss(self): + pass + +import torch +from torch import nn +import dgl +from ..layers.rgcn_layer import RelGraphConv +from ..layers.compgcn_layer import CompGCNCov +import torch.nn.functional as F + + +class GCNs(nn.Module): + def __init__(self,args, num_ent, num_rel, num_base, init_dim, gcn_dim, embed_dim, n_layer, edge_type, edge_norm, + conv_bias=True, gcn_drop=0., opn='mult', wni=False, wsi=False, encoder='compgcn', use_bn=True, ltr=True): + super(GCNs, self).__init__() + num_ent=args.num_ent + num_rel=args.num_rel + num_base=args.num_base + init_dim=args.init_dim + gcn_dim=args.gcn_dim + embed_dim=args.embed_dim + n_layer=args.n_layer + edge_type=args.edge_type + edge_norm=args.edge_norm + conv_bias = True + if args.conv_bias is not None: + conv_bias=args.conv_bias + gcn_drop = 0. 
+ if args.gcn_drop is not None: + gcn_drop=args.gcn_drop + opn = 'mult' + if args.opn is not None: + opn=args.opn + wni = False + if args.wni is not None: + wni=args.wni + wsi = False + if args.wsi is not None: + wsi=args.wsi + encoder = 'compgcn' + if args.encoder is not None: + encoder=args.encoder + use_bn = True + if args.use_bn is not None: + use_bn=args.use_bn + ltr = True + if args.ltr is not None: + ltr=args.ltr + self.act = torch.tanh + self.loss = nn.BCELoss() + self.num_ent, self.num_rel, self.num_base = num_ent, num_rel, num_base + self.init_dim, self.gcn_dim, self.embed_dim = init_dim, gcn_dim, embed_dim + self.conv_bias = conv_bias + self.gcn_drop = gcn_drop + self.opn = opn + self.edge_type = edge_type # [E] + self.edge_norm = edge_norm # [E] + self.n_layer = n_layer + + self.wni = wni + + self.encoder = encoder + + self.init_embed = self.get_param([self.num_ent, self.init_dim]) + self.init_rel = self.get_param([self.num_rel * 2, self.init_dim]) + + if encoder == 'compgcn': + if n_layer < 3: + self.conv1 = CompGCNCov(self.init_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, opn, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr) + self.conv2 = CompGCNCov(self.gcn_dim, self.embed_dim, self.act, conv_bias, gcn_drop, + opn, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr) if n_layer == 2 else None + else: + self.conv1 = CompGCNCov(self.init_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, opn, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr) + self.conv2 = CompGCNCov(self.gcn_dim, self.gcn_dim, self.act, conv_bias, gcn_drop, opn, num_base=-1, + num_rel=self.num_rel, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr) + self.conv3 = CompGCNCov(self.gcn_dim, self.embed_dim, self.act, conv_bias, gcn_drop, + opn, wni=wni, wsi=wsi, use_bn=use_bn, ltr=ltr) + elif encoder == 'rgcn': + self.conv1 = RelGraphConv(self.init_dim, self.gcn_dim, self.num_rel*2, "bdd", + num_bases=self.num_base, activation=self.act, self_loop=(not wsi), dropout=gcn_drop, wni=wni) + self.conv2 = RelGraphConv(self.gcn_dim, self.embed_dim, self.num_rel*2, "bdd", num_bases=self.num_base, + activation=self.act, self_loop=(not wsi), dropout=gcn_drop, wni=wni) if n_layer == 2 else None + + self.bias = nn.Parameter(torch.zeros(self.num_ent)) + + def get_param(self, shape): + param = nn.Parameter(torch.Tensor(*shape)) + nn.init.xavier_normal_(param, gain=nn.init.calculate_gain('relu')) + return param + + def calc_loss(self, pred, label): + return self.loss(pred, label) + + def forward_base(self, g, subj, rel, drop1, drop2): + """ + :param g: graph + :param sub: subjects in a batch [batch] + :param rel: relations in a batch [batch] + :param drop1: dropout rate in first layer + :param drop2: dropout rate in second layer + :return: sub_emb: [batch, D] + rel_emb: [num_rel*2, D] + x: [num_ent, D] + """ + x, r = self.init_embed, self.init_rel # embedding of relations + + if self.n_layer > 0: + if self.encoder == 'compgcn': + if self.n_layer < 3: + x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2( + g, x, r, self.edge_type, self.edge_norm) if self.n_layer == 2 else (x, r) + x = drop2(x) if self.n_layer == 2 else x + else: + x, r = self.conv1(g, x, r, self.edge_type, self.edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv2(g, x, r, self.edge_type, self.edge_norm) + x = drop1(x) # embeddings of entities [num_ent, dim] + x, r = self.conv3(g, x, r, self.edge_type, 
self.edge_norm) + x = drop2(x) + elif self.encoder == 'rgcn': + x = self.conv1(g, x, self.edge_type, + self.edge_norm.unsqueeze(-1)) + x = drop1(x) # embeddings of entities [num_ent, dim] + x = self.conv2( + g, x, self.edge_type, self.edge_norm.unsqueeze(-1)) if self.n_layer == 2 else x + x = drop2(x) if self.n_layer == 2 else x + + # filter out embeddings of subjects in this batch + sub_emb = torch.index_select(x, 0, subj) + # filter out embeddings of relations in this batch + rel_emb = torch.index_select(r, 0, rel) + + return sub_emb, rel_emb, x + + +class GCN_TransE(GCNs): + def __init__(self, args): + super(GCN_TransE, self).__init__(args) + + self.drop = nn.Dropout(args.hid_drop) + self.gamma = args.gamma + + def forward(self, g, subj, rel): + """ + :param g: dgl graph + :param sub: subject in batch [batch_size] + :param rel: relation in batch [batch_size] + :return: score: [batch_size, ent_num], the prob in link-prediction + """ + sub_emb, rel_emb, all_ent = self.forward_base( + g, subj, rel, self.drop, self.drop) + obj_emb = sub_emb + rel_emb + + x = self.gamma - torch.norm(obj_emb.unsqueeze(1) - all_ent, p=1, dim=2) + + score = torch.sigmoid(x) + + return score + diff --git a/openhgnn/models/NBF.py b/openhgnn/models/NBF.py new file mode 100644 index 00000000..a57f22d6 --- /dev/null +++ b/openhgnn/models/NBF.py @@ -0,0 +1,862 @@ +import copy +from collections.abc import Sequence +import torch +from torch import nn, autograd +from functools import reduce +from torch.nn import functional as F +from torch import Tensor +from typing import Any, Optional +from . import BaseModel, register_model +from typing import Tuple + + + + +def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand(other.size()) + return src + + + +def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + index = broadcast(index, src, dim) + if out is None: + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + else: + return out.scatter_add_(dim, index, src) + + +def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return scatter_sum(src, index, dim, out, dim_size) + + +def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size) + + +def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + out = scatter_sum(src, index, dim, out, dim_size) + dim_size = out.size(dim) + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.dim() + if index.dim() <= index_dim: + index_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter_sum(ones, index, index_dim, None, dim_size) + count[count < 1] = 1 + count = broadcast(count, out, dim) + if 
out.is_floating_point(): + out.true_divide_(count) + else: + out.div_(count, rounding_mode='floor') + return out + + +def scatter_min( + src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size) + + +def scatter_max( + src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size) + + +def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, + reduce: str = "sum") -> torch.Tensor: + + if reduce == 'sum' or reduce == 'add': + return scatter_sum(src, index, dim, out, dim_size) + if reduce == 'mul': + return scatter_mul(src, index, dim, out, dim_size) + elif reduce == 'mean': + return scatter_mean(src, index, dim, out, dim_size) + elif reduce == 'min': + return scatter_min(src, index, dim, out, dim_size)[0] + elif reduce == 'max': + return scatter_max(src, index, dim, out, dim_size)[0] + else: + raise ValueError + + + + + +@register_model('NBF') +class NBFNet(BaseModel): + @classmethod + def build_model_from_args(cls, args, hg): + return cls(input_dim=args.input_dim, + hidden_dims=args.hidden_dims, + num_relation = args.num_relation, + message_func = args.message_func, + aggregate_func = args.aggregate_func, + short_cut = args.short_cut, + layer_norm = args.layer_norm, + dependent = args.dependent, + ) + + + def __init__(self, input_dim, hidden_dims, num_relation, message_func="distmult", aggregate_func="pna", + short_cut=False, layer_norm=False, activation="relu", concat_hidden=False, num_mlp_layer=2, + dependent=True, remove_one_hop=False, num_beam=10, path_topk=10): + super(NBFNet, self).__init__() + + if not isinstance(hidden_dims, Sequence): + hidden_dims = [hidden_dims] + + self.dims = [input_dim] + list(hidden_dims) + self.num_relation = num_relation + self.short_cut = short_cut # whether to use residual connections between GNN layers + self.concat_hidden = concat_hidden # whether to compute final states as a function of all layer outputs or last + self.remove_one_hop = remove_one_hop # whether to dynamically remove one-hop edges from edge_index + self.num_beam = num_beam + self.path_topk = path_topk + + self.layers = nn.ModuleList() + for i in range(len(self.dims) - 1): + self.layers.append(GeneralizedRelationalConv(self.dims[i], self.dims[i + 1], num_relation, + self.dims[0], message_func, aggregate_func, layer_norm, + activation, dependent)) + + feature_dim = (sum(hidden_dims) if concat_hidden else hidden_dims[-1]) + input_dim + + # additional relation embedding which serves as an initial 'query' for the NBFNet forward pass + # each layer has its own learnable relations matrix, so we send the total number of relations, too + self.query = nn.Embedding(num_relation, input_dim) + self.mlp = nn.Sequential() + mlp = [] + for i in range(num_mlp_layer - 1): + mlp.append(nn.Linear(feature_dim, feature_dim)) + mlp.append(nn.ReLU()) + mlp.append(nn.Linear(feature_dim, 1)) + self.mlp = nn.Sequential(*mlp) + + def remove_easy_edges(self, data, h_index, t_index, r_index=None): + # + # we remove training edges (we need to predict them at training time) from the edge index + # think of it as a dynamic edge dropout + h_index_ext = torch.cat([h_index, t_index], 
dim=-1) + t_index_ext = torch.cat([t_index, h_index], dim=-1) + r_index_ext = torch.cat([r_index, r_index + self.num_relation // 2], dim=-1) + if self.remove_one_hop: + # we remove all existing immediate edges between heads and tails in the batch + edge_index = data.edge_index + easy_edge = torch.stack([h_index_ext, t_index_ext]).flatten(1) + index = edge_match(edge_index, easy_edge)[0] + mask = ~index_to_mask(index, data.num_edges) + else: + # we remove existing immediate edges between heads and tails in the batch with the given relation + edge_index = torch.cat([data.edge_index, data.edge_type.unsqueeze(0)]) + # note that here we add relation types r_index_ext to the matching query + easy_edge = torch.stack([h_index_ext, t_index_ext, r_index_ext]).flatten(1) + index = edge_match(edge_index, easy_edge)[0] + mask = ~index_to_mask(index, data.num_edges) + + data = copy.copy(data) + data.edge_index = data.edge_index[:, mask] + data.edge_type = data.edge_type[mask] + return data + + def negative_sample_to_tail(self, h_index, t_index, r_index): + # convert p(h | t, r) to p(t' | h', r') + # h' = t, r' = r^{-1}, t' = h + is_t_neg = (h_index == h_index[:, [0]]).all(dim=-1, keepdim=True) + new_h_index = torch.where(is_t_neg, h_index, t_index) + new_t_index = torch.where(is_t_neg, t_index, h_index) + new_r_index = torch.where(is_t_neg, r_index, r_index + self.num_relation // 2) + return new_h_index, new_t_index, new_r_index + + def bellmanford(self, data, h_index, r_index, separate_grad=False): + batch_size = len(r_index) + + + query = self.query(r_index) + index = h_index.unsqueeze(-1).expand_as(query) + + + boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=h_index.device) + # by the scatter operation we put query (relation) embeddings as init features of source (index) nodes + boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1)) + size = (data.num_nodes, data.num_nodes) + edge_weight = torch.ones(data.num_edges, device=h_index.device) + + hiddens = [] + edge_weights = [] + layer_input = boundary + + for layer in self.layers: + if separate_grad: + edge_weight = edge_weight.clone().requires_grad_() + # Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states + + hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight) + if self.short_cut and hidden.shape == layer_input.shape: + # residual connection here + hidden = hidden + layer_input + hiddens.append(hidden) + edge_weights.append(edge_weight) + layer_input = hidden + + # original query (relation type) embeddings + node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1) # (batch_size, num_nodes, input_dim) + if self.concat_hidden: + output = torch.cat(hiddens + [node_query], dim=-1) + else: + output = torch.cat([hiddens[-1], node_query], dim=-1) + + return { + "node_feature": output, + "edge_weights": edge_weights, + } + + def forward(self, data, batch):# data == train_data + h_index, t_index, r_index = batch.unbind(-1) + + + if self.training: + # Edge dropout in the training mode + # here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types + # to make NBFNet iteration learn non-trivial paths + data = self.remove_easy_edges(data, h_index, t_index, r_index) + + + data.num_edges = data.edge_index.shape[1] + + shape = h_index.shape + # turn all triples in a batch into a tail prediction mode + h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index) + assert 
(h_index[:, [0]] == h_index).all() + assert (r_index[:, [0]] == r_index).all() + + # message passing and updated node representations + output = self.bellmanford(data, h_index[:, 0], r_index[:, 0]) # (num_nodes, batch_size, feature_dim) + feature = output["node_feature"] + index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1]) + # extract representations of tail entities from the updated node states + feature = feature.gather(1, index) # (batch_size, num_negative + 1, feature_dim) + + # probability logit for each tail node in the batch + # (batch_size, num_negative + 1, dim) -> (batch_size, num_negative + 1) + score = self.mlp(feature).squeeze(-1) + return score.view(shape) + + def visualize(self, data, batch): + assert batch.shape == (1, 3) + h_index, t_index, r_index = batch.unbind(-1) + + output = self.bellmanford(data, h_index, r_index, separate_grad=True) + feature = output["node_feature"] + edge_weights = output["edge_weights"] + + index = t_index.unsqueeze(0).unsqueeze(-1).expand(-1, -1, feature.shape[-1]) + feature = feature.gather(1, index).squeeze(0) + score = self.mlp(feature).squeeze(-1) + + edge_grads = autograd.grad(score, edge_weights) + distances, back_edges = self.beam_search_distance(data, edge_grads, h_index, t_index, self.num_beam) + paths, weights = self.topk_average_length(distances, back_edges, t_index, self.path_topk) + + return paths, weights + + @torch.no_grad() + def beam_search_distance(self, data, edge_grads, h_index, t_index, num_beam=10): + # beam search the top-k distance from h to t (and to every other node) + num_nodes = data.num_nodes + input = torch.full((num_nodes, num_beam), float("-inf"), device=h_index.device) + input[h_index, 0] = 0 + edge_mask = data.edge_index[0, :] != t_index + + distances = [] + back_edges = [] + for edge_grad in edge_grads: + # we don't allow any path goes out of t once it arrives at t + node_in, node_out = data.edge_index[:, edge_mask] + relation = data.edge_type[edge_mask] + edge_grad = edge_grad[edge_mask] + + message = input[node_in] + edge_grad.unsqueeze(-1) # (num_edges, num_beam) + # (num_edges, num_beam, 3) + msg_source = torch.stack([node_in, node_out, relation], dim=-1).unsqueeze(1).expand(-1, num_beam, -1) + + # (num_edges, num_beam) + is_duplicate = torch.isclose(message.unsqueeze(-1), message.unsqueeze(-2)) & \ + (msg_source.unsqueeze(-2) == msg_source.unsqueeze(-3)).all(dim=-1) + # pick the first occurrence as the ranking in the previous node's beam + # this makes deduplication easier later + # and store it in msg_source + is_duplicate = is_duplicate.float() - \ + torch.arange(num_beam, dtype=torch.float, device=message.device) / (num_beam + 1) + prev_rank = is_duplicate.argmax(dim=-1, keepdim=True) + msg_source = torch.cat([msg_source, prev_rank], dim=-1) # (num_edges, num_beam, 4) + + node_out, order = node_out.sort() + node_out_set = torch.unique(node_out) + # sort messages w.r.t. 
node_out + message = message[order].flatten() # (num_edges * num_beam) + msg_source = msg_source[order].flatten(0, -2) # (num_edges * num_beam, 4) + size = node_out.bincount(minlength=num_nodes) + msg2out = size_to_index(size[node_out_set] * num_beam) + # deduplicate messages that are from the same source and the same beam + is_duplicate = (msg_source[1:] == msg_source[:-1]).all(dim=-1) + is_duplicate = torch.cat([torch.zeros(1, dtype=torch.bool, device=message.device), is_duplicate]) + message = message[~is_duplicate] + msg_source = msg_source[~is_duplicate] + msg2out = msg2out[~is_duplicate] + size = msg2out.bincount(minlength=len(node_out_set)) + + if not torch.isinf(message).all(): + # take the topk messages from the neighborhood + # distance: (len(node_out_set) * num_beam) + distance, rel_index = scatter_topk(message, size, k=num_beam) + abs_index = rel_index + (size.cumsum(0) - size).unsqueeze(-1) + # store msg_source for backtracking + back_edge = msg_source[abs_index] # (len(node_out_set) * num_beam, 4) + distance = distance.view(len(node_out_set), num_beam) + back_edge = back_edge.view(len(node_out_set), num_beam, 4) + # scatter distance / back_edge back to all nodes + distance = scatter_sum(distance, node_out_set, dim=0, dim_size=num_nodes) # (num_nodes, num_beam) + back_edge = scatter_sum(back_edge, node_out_set, dim=0, dim_size=num_nodes) # (num_nodes, num_beam, 4) + else: + distance = torch.full((num_nodes, num_beam), float("-inf"), device=message.device) + back_edge = torch.zeros(num_nodes, num_beam, 4, dtype=torch.long, device=message.device) + + distances.append(distance) + back_edges.append(back_edge) + input = distance + + return distances, back_edges + + def topk_average_length(self, distances, back_edges, t_index, k=10): + # backtrack distances and back_edges to generate the paths + paths = [] + average_lengths = [] + + for i in range(len(distances)): + distance, order = distances[i][t_index].flatten(0, -1).sort(descending=True) + back_edge = back_edges[i][t_index].flatten(0, -2)[order] + for d, (h, t, r, prev_rank) in zip(distance[:k].tolist(), back_edge[:k].tolist()): + if d == float("-inf"): + break + path = [(h, t, r)] + for j in range(i - 1, -1, -1): + h, t, r, prev_rank = back_edges[j][h, prev_rank].tolist() + path.append((h, t, r)) + paths.append(path[::-1]) + average_lengths.append(d / len(path)) + + if paths: + average_lengths, paths = zip(*sorted(zip(average_lengths, paths), reverse=True)[:k]) + + return paths, average_lengths + + +def index_to_mask(index, size): + index = index.view(-1) + size = int(index.max()) + 1 if size is None else size + mask = index.new_zeros(size, dtype=torch.bool) + mask[index] = True + return mask + + +def size_to_index(size): + range = torch.arange(len(size), device=size.device) + index2sample = range.repeat_interleave(size) + return index2sample + + +def multi_slice_mask(starts, ends, length): + values = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)]) + slices = torch.cat([starts, ends]) + mask = scatter_sum(values, slices, dim=0, dim_size=length + 1)[:-1] + mask = mask.cumsum(0).bool() + return mask + + +def scatter_extend(data, size, input, input_size): + new_size = size + input_size + new_cum_size = new_size.cumsum(0) + new_data = torch.zeros(new_cum_size[-1], *data.shape[1:], dtype=data.dtype, device=data.device) + starts = new_cum_size - new_size + ends = starts + size + index = multi_slice_mask(starts, ends, new_cum_size[-1]) + new_data[index] = data + new_data[~index] = input + return new_data, new_size + + 
+def scatter_topk(input, size, k, largest=True): + index2graph = size_to_index(size) + index2graph = index2graph.view([-1] + [1] * (input.ndim - 1)) + + mask = ~torch.isinf(input) + max = input[mask].max().item() + min = input[mask].min().item() + safe_input = input.clamp(2 * min - max, 2 * max - min) + offset = (max - min) * 4 + if largest: + offset = -offset + input_ext = safe_input + offset * index2graph + index_ext = input_ext.argsort(dim=0, descending=largest) + num_actual = size.clamp(max=k) + num_padding = k - num_actual + starts = size.cumsum(0) - size + ends = starts + num_actual + mask = multi_slice_mask(starts, ends, len(index_ext)).nonzero().flatten() + + if (num_padding > 0).any(): + # special case: size < k, pad with the last valid index + padding = ends - 1 + padding2graph = size_to_index(num_padding) + mask = scatter_extend(mask, num_actual, padding[padding2graph], num_padding)[0] + + index = index_ext[mask] # (N * k, ...) + value = input.gather(0, index) + if isinstance(k, torch.Tensor) and k.shape == size.shape: + value = value.view(-1, *input.shape[1:]) + index = index.view(-1, *input.shape[1:]) + index = index - (size.cumsum(0) - size).repeat_interleave(k).view([-1] + [1] * (index.ndim - 1)) + else: + value = value.view(-1, k, *input.shape[1:]) + index = index.view(-1, k, *input.shape[1:]) + index = index - (size.cumsum(0) - size).view([-1] + [1] * (index.ndim - 1)) + + return value, index + + + + +def edge_match(edge_index, query_index): + # O((n + q)logn) time + # O(n) memory + # edge_index: big underlying graph + # query_index: edges to match + + # preparing unique hashing of edges, base: (max_node, max_relation) + 1 + base = edge_index.max(dim=1)[0] + 1 + # we will map edges to long ints, so we need to make sure the maximum product is less than MAX_LONG_INT + # idea: max number of edges = num_nodes * num_relations + # e.g. 
for a graph of 10 nodes / 5 relations, edge IDs 0...9 mean all possible outgoing edge types from node 0 + # given a tuple (h, r), we will search for all other existing edges starting from head h + assert reduce(int.__mul__, base.tolist()) < torch.iinfo(torch.long).max + scale = base.cumprod(0) + scale = scale[-1] // scale + + # hash both the original edge index and the query index to unique integers + edge_hash = (edge_index * scale.unsqueeze(-1)).sum(dim=0) + edge_hash, order = edge_hash.sort() + query_hash = (query_index * scale.unsqueeze(-1)).sum(dim=0) + + # matched ranges: [start[i], end[i]) + start = torch.bucketize(query_hash, edge_hash) + end = torch.bucketize(query_hash, edge_hash, right=True) + # num_match shows how many edges satisfy the (h, r) pattern for each query in the batch + num_match = end - start + + # generate the corresponding ranges + offset = num_match.cumsum(0) - num_match + range = torch.arange(num_match.sum(), device=edge_index.device) + range = range + (start - offset).repeat_interleave(num_match) + + return order[range], num_match + + +def negative_sampling(data, batch, num_negative, strict=True):# data==train_data + batch_size = len(batch) + pos_h_index, pos_t_index, pos_r_index = batch.t() + + # strict negative sampling vs random negative sampling + if strict: + t_mask, h_mask = strict_negative_mask(data, batch) + t_mask = t_mask[:batch_size // 2] + neg_t_candidate = t_mask.nonzero()[:, 1] + num_t_candidate = t_mask.sum(dim=-1) + # draw samples for negative tails + rand = torch.rand(len(t_mask), num_negative, device=batch.device) + index = (rand * num_t_candidate.unsqueeze(-1)).long() + index = index + (num_t_candidate.cumsum(0) - num_t_candidate).unsqueeze(-1) + neg_t_index = neg_t_candidate[index] + + h_mask = h_mask[batch_size // 2:] + neg_h_candidate = h_mask.nonzero()[:, 1] + num_h_candidate = h_mask.sum(dim=-1) + # draw samples for negative heads + rand = torch.rand(len(h_mask), num_negative, device=batch.device) + index = (rand * num_h_candidate.unsqueeze(-1)).long() + index = index + (num_h_candidate.cumsum(0) - num_h_candidate).unsqueeze(-1) + neg_h_index = neg_h_candidate[index] + else: + neg_index = torch.randint(data.num_nodes, (batch_size, num_negative), device=batch.device) + neg_t_index, neg_h_index = neg_index[:batch_size // 2], neg_index[batch_size // 2:] + + h_index = pos_h_index.unsqueeze(-1).repeat(1, num_negative + 1) + t_index = pos_t_index.unsqueeze(-1).repeat(1, num_negative + 1) + r_index = pos_r_index.unsqueeze(-1).repeat(1, num_negative + 1) + t_index[:batch_size // 2, 1:] = neg_t_index + h_index[batch_size // 2:, 1:] = neg_h_index + + return torch.stack([h_index, t_index, r_index], dim=-1) + + +def all_negative(data, batch): + pos_h_index, pos_t_index, pos_r_index = batch.t() + r_index = pos_r_index.unsqueeze(-1).expand(-1, data.num_nodes) + # generate all negative tails for this batch + all_index = torch.arange(data.num_nodes, device=batch.device) + h_index, t_index = torch.meshgrid(pos_h_index, all_index) + t_batch = torch.stack([h_index, t_index, r_index], dim=-1) + # generate all negative heads for this batch + all_index = torch.arange(data.num_nodes, device=batch.device) + t_index, h_index = torch.meshgrid(pos_t_index, all_index) + h_batch = torch.stack([h_index, t_index, r_index], dim=-1) + + return t_batch, h_batch + + +def strict_negative_mask(data, batch): + # this function makes sure that for a given (h, r) batch we will NOT sample true tails as random negatives + # similarly, for a given (t, r) we will NOT sample 
existing true heads as random negatives + + pos_h_index, pos_t_index, pos_r_index = batch.t() + + # part I: sample hard negative tails + # edge index of all (head, relation) edges from the underlying graph + edge_index = torch.stack([data.edge_index[0], data.edge_type]) + # edge index of current batch (head, relation) for which we will sample negatives + query_index = torch.stack([pos_h_index, pos_r_index]) + # search for all true tails for the given (h, r) batch + edge_id, num_t_truth = edge_match(edge_index, query_index) + # build an index from the found edges + t_truth_index = data.edge_index[1, edge_id] + sample_id = torch.arange(len(num_t_truth), device=batch.device).repeat_interleave(num_t_truth) + t_mask = torch.ones(len(num_t_truth), data.num_nodes, dtype=torch.bool, device=batch.device) + # assign 0s to the mask with the found true tails + t_mask[sample_id, t_truth_index] = 0 + t_mask.scatter_(1, pos_t_index.unsqueeze(-1), 0) + + # part II: sample hard negative heads + # edge_index[1] denotes tails, so the edge index becomes (t, r) + edge_index = torch.stack([data.edge_index[1], data.edge_type]) + # edge index of current batch (tail, relation) for which we will sample heads + query_index = torch.stack([pos_t_index, pos_r_index]) + # search for all true heads for the given (t, r) batch + edge_id, num_h_truth = edge_match(edge_index, query_index) + # build an index from the found edges + h_truth_index = data.edge_index[0, edge_id] + sample_id = torch.arange(len(num_h_truth), device=batch.device).repeat_interleave(num_h_truth) + h_mask = torch.ones(len(num_h_truth), data.num_nodes, dtype=torch.bool, device=batch.device) + # assign 0s to the mask with the found true heads + h_mask[sample_id, h_truth_index] = 0 + h_mask.scatter_(1, pos_h_index.unsqueeze(-1), 0) + + return t_mask, h_mask + + +def compute_ranking(pred, target, mask=None): + pos_pred = pred.gather(-1, target.unsqueeze(-1)) + if mask is not None: + # filtered ranking + ranking = torch.sum((pos_pred <= pred) & mask, dim=-1) + 1 + else: + # unfiltered ranking + ranking = torch.sum(pos_pred <= pred, dim=-1) + 1 + return ranking + + + + + + +def degree(index: Tensor, num_nodes: Optional[int] = None, + dtype: Optional[torch.dtype] = None) -> Tensor: + + N = maybe_num_nodes(index, num_nodes) + out = torch.zeros((N, ), dtype=dtype, device=index.device) + one = torch.ones((index.size(0), ), dtype=out.dtype, device=out.device) + return out.scatter_add_(0, index, one)#Torch + +def maybe_num_nodes(edge_index, num_nodes=None): + if num_nodes is not None: + return num_nodes + elif isinstance(edge_index, Tensor): + if is_torch_sparse_tensor(edge_index): + return max(edge_index.size(0), edge_index.size(1)) + return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0 + else: + return max(edge_index.size(0), edge_index.size(1)) + + +def is_torch_sparse_tensor(src:Any) -> bool: + if isinstance(src, Tensor): + if src.layout == torch.sparse_coo: + return True + if src.layout == torch.sparse_csr: + return True + if src.layout == torch.sparse_csc: + return True + return False + + +class GeneralizedRelationalConv(torch.nn.Module): + + eps = 1e-6 + + message2mul = { + "transe": "add", + "distmult": "mul", + } + + def __init__(self, input_dim, output_dim, num_relation, query_input_dim, message_func="distmult", + aggregate_func="pna", layer_norm=False, activation="relu", dependent=True): + super(GeneralizedRelationalConv, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.num_relation = num_relation + 
self.query_input_dim = query_input_dim + self.message_func = message_func + self.aggregate_func = aggregate_func + self.dependent = dependent + self.node_dim = -2 + + + if layer_norm: + self.layer_norm = nn.LayerNorm(output_dim) + else: + self.layer_norm = None + if isinstance(activation, str): + self.activation = getattr(F, activation) + else: + self.activation = activation + + if self.aggregate_func == "pna": + self.linear = nn.Linear(input_dim * 13, output_dim) + else: + self.linear = nn.Linear(input_dim * 2, output_dim) + + if dependent: + # obtain relation embeddings as a projection of the query relation + self.relation_linear = nn.Linear(query_input_dim, num_relation * input_dim) + else: + # relation embeddings as an independent embedding matrix per each layer + self.relation = nn.Embedding(num_relation, input_dim) + + + + + + + def forward(self, input, query, boundary, + edge_index, edge_type, + size, edge_weight=None): + batch_size = len(query) + + + if self.dependent: + relation = self.relation_linear(query).view(batch_size, self.num_relation, self.input_dim) + else:# falese + relation = self.relation.weight.expand(batch_size, -1, -1) + + if edge_weight is None: + edge_weight = torch.ones(len(edge_type), device=input.device) + + # input.shape == 64,2746,32 , input_j .shape == [64,10692,32] + # input_j = torch.gather(input=input, dim=1, index=edge_index[0] ) + input_j = input.index_select(1, edge_index[0]) + + + + message_res = self.message(input_j=input_j,relation=relation,boundary=boundary,edge_type=edge_type) + aggr_res = self.aggregate(input=message_res,edge_weight=edge_weight, index=edge_index[1],dim_size=input.shape[1]) + + return self.update(update=aggr_res,input=input) + + + def message(self, input_j, relation, boundary, edge_type): + + relation_j = relation.index_select(self.node_dim, edge_type)#Torch + #input_j .shape == [64,10692,32] + if self.message_func == "transe": + message = input_j + relation_j + + elif self.message_func == "distmult": + message = input_j * relation_j + + elif self.message_func == "rotate": + x_j_re, x_j_im = input_j.chunk(2, dim=-1)#Torch + r_j_re, r_j_im = relation_j.chunk(2, dim=-1) + message_re = x_j_re * r_j_re - x_j_im * r_j_im + message_im = x_j_re * r_j_im + x_j_im * r_j_re + message = torch.cat([message_re, message_im], dim=-1) + else: + raise ValueError("Unknown message function `%s`" % self.message_func) + + # augment messages with the boundary condition + message = torch.cat([message, boundary], dim=self.node_dim) # (num_edges + num_nodes, batch_size, input_dim) + + return message + + def aggregate(self, input, edge_weight, index, dim_size): + + + + index = torch.cat([index, torch.arange(dim_size, device=input.device)]) # (num_edges + num_nodes == ) + edge_weight = torch.cat([edge_weight, torch.ones(dim_size, device=input.device)]) + + shape = [1] * input.ndim + shape[self.node_dim] = -1 + edge_weight = edge_weight.view(shape) + + + + if self.aggregate_func == "pna": + mean = scatter_mean(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size) + sq_mean = scatter_mean(input ** 2 * edge_weight, index, dim=self.node_dim, dim_size=dim_size) + + max = scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size, reduce="max") + min= scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size, reduce="min") + + std = (sq_mean - mean ** 2).clamp(min=self.eps).sqrt()#Torch + features = torch.cat([mean.unsqueeze(-1), max.unsqueeze(-1), min.unsqueeze(-1), std.unsqueeze(-1)], dim=-1) + features = 
features.flatten(-2) + degree_out = degree(index, dim_size).unsqueeze(0).unsqueeze(-1) + scale = degree_out.log() + scale = scale / scale.mean() + scales = torch.cat([torch.ones_like(scale), scale, 1 / scale.clamp(min=1e-2)], dim=-1) + output = (features.unsqueeze(-1) * scales.unsqueeze(-2)).flatten(-2) + + + return output + + def update(self, update, input): + + output = self.linear(torch.cat([input, update], dim=-1)) + if self.layer_norm: + output = self.layer_norm(output) + if self.activation: + output = self.activation(output) + return output + + + def message_and_aggregate(self, edge_index, input, relation, boundary, edge_type, edge_weight, index, dim_size): + # fused computation of message and aggregate steps with the custom rspmm cuda kernel + # speed up computation by several times + # reduce memory complexity from O(|E|d) to O(|V|d), so we can apply it to larger graphs + + + + batch_size, num_node = input.shape[:2] + input = input.transpose(0, 1).flatten(1) + relation = relation.transpose(0, 1).flatten(1) + boundary = boundary.transpose(0, 1).flatten(1) + degree_out = degree(index, dim_size).unsqueeze(-1) + 1 + + if self.message_func in self.message2mul:# self.message_func == "distmult" + mul = self.message2mul[self.message_func] # mul == "mul" + else: + raise ValueError("Unknown message function `%s`" % self.message_func) + + + if self.aggregate_func == "sum": + update = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul=mul) + update = update + boundary + elif self.aggregate_func == "mean": + update = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul=mul) + update = (update + boundary) / degree_out + elif self.aggregate_func == "max": + update = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="max", mul=mul) + update = torch.max(update, boundary) + ############ msg_func = mul aggr_func = pna, + elif self.aggregate_func == "pna": + # we use PNA with 4 aggregators (mean / max / min / std) + # and 3 scalars (identity / log degree / reciprocal of log degree) + sum = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul=mul) + sq_sum = generalized_rspmm(edge_index, edge_type, edge_weight, relation ** 2, input ** 2, sum="add", + mul=mul) + max = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="max", mul=mul) + min = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="min", mul=mul) + mean = (sum + boundary) / degree_out + sq_mean = (sq_sum + boundary ** 2) / degree_out + max = torch.max(max, boundary) + min = torch.min(min, boundary) # (node, batch_size * input_dim) + std = (sq_mean - mean ** 2).clamp(min=self.eps).sqrt() + features = torch.cat([mean.unsqueeze(-1), max.unsqueeze(-1), min.unsqueeze(-1), std.unsqueeze(-1)], dim=-1) + features = features.flatten(-2) # (node, batch_size * input_dim * 4) + scale = degree_out.log() + scale = scale / scale.mean() + scales = torch.cat([torch.ones_like(scale), scale, 1 / scale.clamp(min=1e-2)], dim=-1) # (node, 3) + update = (features.unsqueeze(-1) * scales.unsqueeze(-2)).flatten(-2) # (node, batch_size * input_dim * 4 * 3) + else: + raise ValueError("Unknown aggregation function `%s`" % self.aggregate_func) + + update = update.view(num_node, batch_size, -1).transpose(0, 1) + return update + + def propagate(self, edge_index, size=None, **kwargs): + if kwargs["edge_weight"].requires_grad or self.message_func == "rotate": + # the rspmm cuda kernel only works for 
TransE and DistMult message functions + # otherwise we invoke separate message & aggregate functions + return super(GeneralizedRelationalConv, self).propagate(edge_index, size, **kwargs) + + for hook in self._propagate_forward_pre_hooks.values(): + res = hook(self, (edge_index, size, kwargs)) + if res is not None: + edge_index, size, kwargs = res + + size = self._check_input(edge_index, size) + coll_dict = self._collect(self._fused_user_args, edge_index, + size, kwargs) + + msg_aggr_kwargs = self.inspector.distribute("message_and_aggregate", coll_dict) + for hook in self._message_and_aggregate_forward_pre_hooks.values(): + res = hook(self, (edge_index, msg_aggr_kwargs)) + if res is not None: + edge_index, msg_aggr_kwargs = res + out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs) + for hook in self._message_and_aggregate_forward_hooks.values(): + res = hook(self, (edge_index, msg_aggr_kwargs), out) + if res is not None: + out = res + + update_kwargs = self.inspector.distribute("update", coll_dict) + out = self.update(out, **update_kwargs) + + for hook in self._propagate_forward_hooks.values(): + res = hook(self, (edge_index, size, kwargs), out) + if res is not None: + out = res + + return out + + + diff --git a/openhgnn/models/RedGNN.py b/openhgnn/models/RedGNN.py new file mode 100644 index 00000000..7774f04d --- /dev/null +++ b/openhgnn/models/RedGNN.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +from ..utils.utils import scatter +from . import BaseModel, register_model +import dgl +# from dgl import +from scipy.sparse import csr_matrix +import numpy as np + + +@register_model('RedGNN') +class RedGNN(BaseModel): + @classmethod + def build_model_from_args(cls, args, loader): + return cls(args, loader) + + def __init__(self, args, loader): + super(RedGNN, self).__init__() + self.device = args.device + self.hidden_dim = args.hidden_dim + self.attn_dim = args.attn_dim + self.n_layer = args.n_layer + self.loader = loader + self.n_rel = self.loader.n_rel + acts = {'relu': nn.ReLU(), 'tanh': torch.tanh, 'idd': lambda x: x} + act = acts[args.act] + self.act = act + self.gnn_layers = [] + for i in range(self.n_layer): + self.gnn_layers.append(RedGNNLayer(self.hidden_dim, self.hidden_dim, self.attn_dim, self.n_rel, act=act)) + self.gnn_layers = nn.ModuleList(self.gnn_layers) + + self.dropout = nn.Dropout(args.dropout) + self.W_final = nn.Linear(self.hidden_dim, 1, bias=False) # get score + self.gate = nn.GRU(self.hidden_dim, self.hidden_dim) + + + def forward(self, subs, rels, mode='transductive'): # source node, rels + n = len(subs) + + n_ent = self.loader.n_ent if mode=='transductive' else self.loader.n_ent_ind + + q_sub = torch.LongTensor(subs).to(self.device) + q_rel = torch.LongTensor(rels).to(self.device) + + h0 = torch.zeros((1, n, self.hidden_dim)).to(self.device) # 1 * n * d + nodes = torch.cat([torch.arange(n).unsqueeze(1).to(self.device), q_sub.unsqueeze(1)], 1) + hidden = torch.zeros(n, self.hidden_dim).to(self.device) + + for i in range(self.n_layer): + nodes, edges, old_nodes_new_idx = self.loader.get_neighbors(nodes.data.cpu().numpy(), mode=mode) + edges = edges.to(self.device) + old_nodes_new_idx = old_nodes_new_idx.to(self.device) + hidden = self.gnn_layers[i](q_sub, q_rel, hidden, edges, nodes.size(0), old_nodes_new_idx) + h0 = torch.zeros(1, nodes.size(0), hidden.size(1)).to(self.device).index_copy_(1, old_nodes_new_idx, h0) + hidden = self.dropout(hidden) + hidden, h0 = self.gate(hidden.unsqueeze(0), h0) + hidden = hidden.squeeze(0) + + scores = 
self.W_final(hidden).squeeze(-1) + scores_all = torch.zeros((n, n_ent)).to(self.device) + scores_all[[nodes[:, 0], nodes[:,1]]] = scores + return scores_all + + +class RedGNNLayer(torch.nn.Module): + def __init__(self, in_dim, out_dim, attn_dim, n_rel, act=lambda x:x): + super(RedGNNLayer, self).__init__() + self.n_rel = n_rel + self.in_dim = in_dim + self.out_dim = out_dim + self.attn_dim = attn_dim + self.act = act + + self.rela_embed = nn.Embedding(2*n_rel+1, in_dim) + + self.Ws_attn = nn.Linear(in_dim, attn_dim, bias=False) + self.Wr_attn = nn.Linear(in_dim, attn_dim, bias=False) + self.Wqr_attn = nn.Linear(in_dim, attn_dim) + self.w_alpha = nn.Linear(attn_dim, 1) + + self.W_h = nn.Linear(in_dim, out_dim, bias=False) + + def forward(self, q_sub, q_rel, hidden, edges, n_node, old_nodes_new_idx): + # edges: [batch_idx, head, rela, tail, old_idx, new_idx] + sub = edges[:,4] + rel = edges[:,2] + obj = edges[:,5] + + hs = hidden[sub] + hr = self.rela_embed(rel) + + r_idx = edges[:,0] + h_qr = self.rela_embed(q_rel)[r_idx] + + message = hs + hr + alpha = torch.sigmoid(self.w_alpha(nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wqr_attn(h_qr)))) + message = alpha * message + + message_agg = scatter(message, index=obj, dim=0, dim_size=n_node, reduce='sum') + + hidden_new = self.act(self.W_h(message_agg)) + + return hidden_new + + + + + + + diff --git a/openhgnn/models/RedGNNT.py b/openhgnn/models/RedGNNT.py new file mode 100644 index 00000000..2cb9a592 --- /dev/null +++ b/openhgnn/models/RedGNNT.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +from ..utils.utils import scatter +from . import BaseModel, register_model +import dgl +# from dgl import +from scipy.sparse import csr_matrix +import numpy as np + + +@register_model('RedGNNT') +class RedGNNT(BaseModel): + @classmethod + def build_model_from_args(cls, args, loader): + return cls(args, loader) + + def __init__(self, args, loader): + super(RedGNNT, self).__init__() + self.device = args.device + self.hidden_dim = args.hidden_dim + self.attn_dim = args.attn_dim + self.n_layer = args.n_layer + self.loader = loader + self.n_rel = self.loader.n_rel + acts = {'relu': nn.ReLU(), 'tanh': torch.tanh, 'idd': lambda x: x} + act = acts[args.act] + self.act = act + self.gnn_layers = [] + for i in range(self.n_layer): + self.gnn_layers.append(RedGNNLayer(self.hidden_dim, self.hidden_dim, self.attn_dim, self.n_rel, act=act)) + self.gnn_layers = nn.ModuleList(self.gnn_layers) + + self.dropout = nn.Dropout(args.dropout) + self.W_final = nn.Linear(self.hidden_dim, 1, bias=False) # get score + self.gate = nn.GRU(self.hidden_dim, self.hidden_dim) + + + def forward(self, subs, rels, mode='train'): # source node, rels + n = len(subs) + q_sub = torch.LongTensor(subs).to(self.device) + q_rel = torch.LongTensor(rels).to(self.device) + + h0 = torch.zeros((1, n, self.hidden_dim)).to(self.device) # 1 * n * d + nodes = torch.cat([torch.arange(n).unsqueeze(1).to(self.device), q_sub.unsqueeze(1)], 1) + hidden = torch.zeros(n, self.hidden_dim).to(self.device) + + for i in range(self.n_layer): + nodes, edges, old_nodes_new_idx = self.loader.get_neighbors(nodes.data.cpu().numpy(), mode=mode) + edges = edges.to(self.device) + old_nodes_new_idx = old_nodes_new_idx.to(self.device) + hidden = self.gnn_layers[i](q_sub, q_rel, hidden, edges, nodes.size(0), old_nodes_new_idx) + h0 = torch.zeros(1, nodes.size(0), hidden.size(1)).to(self.device).index_copy_(1, old_nodes_new_idx, h0) + hidden = self.dropout(hidden) + hidden, h0 = self.gate(hidden.unsqueeze(0), h0) + 
hidden = hidden.squeeze(0) + + scores = self.W_final(hidden).squeeze(-1) + scores_all = torch.zeros((n, self.loader.n_ent)).to(self.device) + scores_all[[nodes[:, 0], nodes[:,1]]] = scores + return scores_all + + + +class RedGNNLayer(torch.nn.Module): + def __init__(self, in_dim, out_dim, attn_dim, n_rel, act=lambda x:x): + super(RedGNNLayer, self).__init__() + self.n_rel = n_rel + self.in_dim = in_dim + self.out_dim = out_dim + self.attn_dim = attn_dim + self.act = act + + self.rela_embed = nn.Embedding(2*n_rel+1, in_dim) + + self.Ws_attn = nn.Linear(in_dim, attn_dim, bias=False) + self.Wr_attn = nn.Linear(in_dim, attn_dim, bias=False) + self.Wqr_attn = nn.Linear(in_dim, attn_dim) + self.w_alpha = nn.Linear(attn_dim, 1) + + self.W_h = nn.Linear(in_dim, out_dim, bias=False) + + def forward(self, q_sub, q_rel, hidden, edges, n_node, old_nodes_new_idx): + # edges: [batch_idx, head, rela, tail, old_idx, new_idx] + sub = edges[:,4] + rel = edges[:,2] + obj = edges[:,5] + + hs = hidden[sub] + hr = self.rela_embed(rel) + + r_idx = edges[:,0] + h_qr = self.rela_embed(q_rel)[r_idx] + + message = hs + hr + alpha = torch.sigmoid(self.w_alpha(nn.ReLU()(self.Ws_attn(hs) + self.Wr_attn(hr) + self.Wqr_attn(h_qr)))) + message = alpha * message + + message_agg = scatter(message, index=obj, dim=0, dim_size=n_node, reduce='sum') + + hidden_new = self.act(self.W_h(message_agg)) + + return hidden_new + + + + + + + diff --git a/openhgnn/models/SACN.py b/openhgnn/models/SACN.py new file mode 100644 index 00000000..5dea82ad --- /dev/null +++ b/openhgnn/models/SACN.py @@ -0,0 +1,227 @@ +from . import BaseModel, register_model +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import xavier_normal_ +from dgl import function as fn +from numpy.random.mtrand import set_state +import pandas +import torch +import math +@register_model('SACN') +class SACN(BaseModel): + @classmethod + def build_model_from_args(cls, config): + return cls(config) + + def __init__(self, config): + super().__init__() + + self.model = WGCN_Base(config) + + def forward(self, *args): + return self.model(*args) + + def extra_loss(self): + pass +class GraphConvolution(torch.nn.Module): + def __init__(self, in_features, out_features, num_relations, bias=True, wsi=False): + super(GraphConvolution, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + torch.FloatTensor(in_features, out_features)) + self.num_relations = num_relations + self.alpha = torch.nn.Embedding(num_relations + 1, 1, padding_idx=0) + if bias: + self.bias = nn.Parameter(torch.FloatTensor(out_features)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + self.wsi = wsi + + def reset_parameters(self): + stdv = 1. 
/ math.sqrt(self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.uniform_(-stdv, stdv) + + def forward(self, g, all_edge_type, input): + with g.local_scope(): + feats = torch.mm(input, self.weight) + g.srcdata['ft'] = feats + if not self.wsi: + train_edge_num = int( + (all_edge_type.shape[0] - input.shape[0]) / 2) + transpose_all_edge_type = torch.cat((all_edge_type[train_edge_num:train_edge_num * 2], + all_edge_type[:train_edge_num], all_edge_type[-input.shape[0]:])) + else: + train_edge_num = int((all_edge_type.shape[0])) + transpose_all_edge_type = torch.cat((all_edge_type[train_edge_num:train_edge_num * 2], + all_edge_type[:train_edge_num])) + alp = self.alpha(all_edge_type) + \ + self.alpha(transpose_all_edge_type) + g.edata['a'] = alp + + g.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft')) + + output = g.dstdata['ft'] + + if self.bias is not None: + return output + self.bias + else: + return output + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + str(self.in_features) + ' -> ' \ + + str(self.out_features) + ')' +class WGCN_Base(torch.nn.Module): + def __init__(self, args): + super(WGCN_Base, self).__init__() + num_entities=args.num_entities + num_relations=args.num_relations + self.rat = args.rat + self.wni = args.wni + + self.fa = args.final_act + self.fb = args.final_bn + self.fd = args.final_drop + + self.decoder_name = args.decoder + self.num_layers = args.n_layer + self.emb_e = torch.nn.Embedding( + num_entities, args.init_emb_size, padding_idx=0) + self.emb_rel = torch.nn.Embedding( + num_relations, args.embedding_dim, padding_idx=0) + + nn.init.xavier_normal_( + self.emb_e.weight, gain=nn.init.calculate_gain('relu')) + nn.init.xavier_normal_(self.emb_rel.weight, + gain=nn.init.calculate_gain('relu')) + + if self.num_layers == 3: + self.gc1 = GraphConvolution( + args.init_emb_size, args.gc1_emb_size, num_relations, wsi=args.wsi) + self.gc2 = GraphConvolution( + args.gc1_emb_size, args.gc1_emb_size, num_relations, wsi=args.wsi) + self.gc3 = GraphConvolution( + args.gc1_emb_size, args.embedding_dim, num_relations, wsi=args.wsi) + elif self.num_layers == 2: + self.gc2 = GraphConvolution( + args.init_emb_size, args.gc1_emb_size, num_relations, wsi=args.wsi) + self.gc3 = GraphConvolution( + args.gc1_emb_size, args.embedding_dim, num_relations, wsi=args.wsi) + else: + self.gc3 = GraphConvolution( + args.init_emb_size, args.embedding_dim, num_relations, wsi=args.wsi) + + self.inp_drop = torch.nn.Dropout(args.input_dropout) + self.hidden_drop = torch.nn.Dropout(args.dropout_rate) + self.feature_map_drop = torch.nn.Dropout(args.dropout_rate) + self.loss = torch.nn.BCELoss() + self.conv1 = nn.Conv1d(2, args.channels, args.kernel_size, stride=1, + padding=int(math.floor(args.kernel_size / 2))) + self.bn0 = torch.nn.BatchNorm1d(2) + self.bn1 = torch.nn.BatchNorm1d(args.channels) + self.bn2 = torch.nn.BatchNorm1d(args.embedding_dim) + self.register_parameter('b', nn.Parameter(torch.zeros(num_entities))) + self.fc = torch.nn.Linear( + args.embedding_dim * args.channels, args.embedding_dim) + self.bn3 = torch.nn.BatchNorm1d(args.gc1_emb_size) + self.bn4 = torch.nn.BatchNorm1d(args.embedding_dim) + self.bn5 = torch.nn.BatchNorm1d(args.gc1_emb_size) + self.bn_init = torch.nn.BatchNorm1d(args.init_emb_size) + self.args=args + print(num_entities, num_relations) + + if args.decoder == "transe": + self.decoder = self.transe + self.gamma = args.gamma + elif args.decoder == "distmult": + self.decoder = self.distmult + 
self.bias = nn.Parameter(torch.zeros(num_entities)) + elif args.decoder == "conve": + self.decoder = self.conve + else: + raise NotImplementedError + + def conve(self, e1_embedded, rel_embedded, e1_embedded_all): + stacked_inputs = torch.cat([e1_embedded, rel_embedded], 1) + stacked_inputs = self.bn0(stacked_inputs) + x = self.inp_drop(stacked_inputs) + + x = self.conv1(x) + x = self.bn1(x) + + x = F.relu(x) + x = self.feature_map_drop(x) + x = x.view(e1_embedded.shape[0], -1) + x = self.fc(x) + x = self.hidden_drop(x) + x = self.bn2(x) + x = F.relu(x) + + x = torch.mm(x, e1_embedded_all.transpose(1, 0)) + pred = torch.sigmoid(x) + return pred + + def transe(self, e1_embedded, rel_embedded, e1_embedded_all): + obj_emb = e1_embedded + rel_embedded + + x = self.gamma - \ + torch.norm(obj_emb - e1_embedded_all.unsqueeze(0), p=1, dim=2) + pred = torch.sigmoid(x) + + return pred + + def distmult(self, e1_embedded, rel_embedded, e1_embedded_all): + obj_emb = e1_embedded * rel_embedded + + x = torch.mm(obj_emb.squeeze(1), e1_embedded_all.transpose(1, 0)) + x += self.bias.expand_as(x) + pred = torch.sigmoid(x) + + return pred + + def init(self): + xavier_normal_(self.emb_e.weight.data) + xavier_normal_(self.emb_rel.weight.data) + xavier_normal_(self.gc1.weight.data) + xavier_normal_(self.gc2.weight.data) + xavier_normal_(self.gc3.weight.data) + + def forward(self, g, all_edge, e1, rel, entity_id): + emb_initial = self.emb_e(entity_id) + + if self.num_layers == 3: + x = self.gc1(g, all_edge, emb_initial) + x = self.bn5(x) + x = torch.tanh(x) + x = F.dropout(x, self.args.dropout_rate, training=self.training) + else: + x = emb_initial + + if self.num_layers >= 2: + x = self.gc2(g, all_edge, x) + x = self.bn3(x) + x = torch.tanh(x) + x = F.dropout(x, self.args.dropout_rate, training=self.training) + + if self.num_layers >= 1: + x = self.gc3(g, all_edge, x) + + if self.fb: + x = self.bn4(x) + if self.fa: + x = torch.tanh(x) + if self.fd: + x = F.dropout(x, self.args.dropout_rate, training=self.training) + + e1_embedded_all = x + + e1_embedded = e1_embedded_all[e1] + rel_embedded = self.emb_rel(rel) + + pred = self.decoder(e1_embedded, rel_embedded, e1_embedded_all) + return pred \ No newline at end of file diff --git a/openhgnn/models/__init__.py b/openhgnn/models/__init__.py index 4a38f085..a3c8b11a 100644 --- a/openhgnn/models/__init__.py +++ b/openhgnn/models/__init__.py @@ -111,6 +111,18 @@ def build_model_from_args(args, hg): 'DSSL': 'openhgnn.models.DSSL', 'HGCL': 'openhgnn.models.HGCL', 'lightGCN': 'openhgnn.models.lightGCN', + 'Grail': 'openhgnn.models.Grail', + 'ComPILE': 'openhgnn.models.ComPILE', + 'AdapropT': 'openhgnn.models.AdapropT', + 'AdapropI':'openhgnn.models.AdapropI', + 'LTE': 'openhgnn.models.LTE', + 'LTE_Transe': 'openhgnn.models.LTE_Transe', + 'SACN':'openhgnn.models.SACN', + 'ExpressGNN': 'openhgnn.models.ExpressGNN', + 'NBF': 'openhgnn.models.NBF', + 'Ingram': 'openhgnn.models.Ingram', + 'RedGNN': 'openhgnn.models.RedGNN', + 'RedGNNT': 'openhgnn.models.RedGNNT', } from .HGCL import HGCL @@ -153,6 +165,17 @@ def build_model_from_args(args, hg): from .KGAT import KGAT from .DSSL import DSSL from .lightGCN import lightGCN +from .Grail import Grail +from .ComPILE import ComPILE +from .AdapropT import AdapropT +from .AdapropI import AdapropI +from .LTE import LTE +from .LTE_Transe import LTE_Transe +from .SACN import SACN +from .ExpressGNN import ExpressGNN +from .Ingram import Ingram +from .RedGNN import RedGNN +from .RedGNNT import RedGNNT __all__ = [ 'BaseModel', @@ -192,5 
+215,14 @@ def build_model_from_args(args, hg):
     'KGAT',
     'DSSL',
     'lightGCN',
+    'Grail',
+    'ComPILE',
+    'AdapropT',
+    'AdapropI',
+    'LTE',
+    'LTE_Transe',
+    'SACN',
+    'ExpressGNN',
+    'Ingram',
 ]
 classes = __all__
diff --git a/openhgnn/output/Adaprop/README.md b/openhgnn/output/Adaprop/README.md
new file mode 100644
index 00000000..002d805c
--- /dev/null
+++ b/openhgnn/output/Adaprop/README.md
@@ -0,0 +1,92 @@
+# AdaProp[KDD 2023]
+
+- paper: https://arxiv.org/pdf/2205.15319.pdf
+- Code from author: [AdaProp](https://github.com/LARS-research/AdaProp)
+
+This model covers two scenarios, transductive and inductive, so we integrate it as two models: AdapropT and AdapropI.
+
+## AdapropT
+
+### How to run
+
+- Clone the Openhgnn-DGL
+
+  ```bash
+  python main.py -m AdapropT -d AdapropT -t link_prediction -g 0
+  ```
+
+  If you do not have gpu, set -gpu -1.
+
+  The dataset AdapropT is supported.
+
+### Performance: Family
+
+- Device: GPU, **GeForce RTX 2080Ti**
+- Dataset: AdapropT
+
+| transductive | Family |
+|:------------:|:-----------------------------:|
+| MRR  | paper: 0.988 OpenHGNN: 0.9883 |
+| H@1  | paper: 98.6 OpenHGNN: 0.9864  |
+| H@10 | paper: 99.0 OpenHGNN: 0.9907  |
+
+### TrainerFlow: AdapropT_trainer
+
+#### model
+
+- AdapropT
+  - AdapropT is the transductive version of AdaProp, a GNN-based knowledge graph reasoning model that learns an adaptive propagation path instead of propagating over all entities.
+
+## AdapropI
+
+### How to run
+
+- Clone the Openhgnn-DGL
+
+  ```bash
+  python main.py -m AdapropI -d AdapropI -t link_prediction -g 0
+  ```
+
+  If you do not have gpu, set -gpu -1.
+
+  The dataset AdapropI is supported.
+
+### Performance: fb15k237_v1-v4
+
+- Device: GPU, **GeForce RTX 2080Ti**
+- Dataset: AdapropI
+
+| inductive | fb15k237_v1 | fb15k237_v2 | fb15k237_v3 | fb15k237_v4 |
+|:---------:|:-----------:|:-----------:|:-----------:|:-----------:|
+| MRR  | paper: 0.310 OpenHGNN: 0.3121 | paper: 0.471 OpenHGNN: 0.4667 | paper: 0.471 OpenHGNN: 0.3121 | paper: 0.454 OpenHGNN: 0.4468 |
+| H@1  | paper: 0.191 OpenHGNN: 0.1946 | paper: 0.372 OpenHGNN: 0.3643 | paper: 0.377 OpenHGNN: 0.3121 | paper: 0.353 OpenHGNN: 0.3521 |
+| H@10 | paper: 0.551 OpenHGNN: 0.5462 | paper: 0.659 OpenHGNN: 0.6505 | paper: 0.637 OpenHGNN: 0.3121 | paper: 0.638 OpenHGNN: 0.6381 |
+
+### TrainerFlow: AdapropI_trainer
+
+#### model
+
+- AdapropI
+  - AdapropI is the inductive version of AdaProp, using the same adaptive propagation idea but evaluated on inductive splits.
+
+#### Contributor
+
+Zikai Zhou [GAMMA LAB]
+
+#### If you have any questions,
+
+Submit an issue or email to 460813395@qq.com
diff --git a/openhgnn/output/ComPILE/README.md b/openhgnn/output/ComPILE/README.md
new file mode 100644
index 00000000..1004c20f
--- /dev/null
+++ b/openhgnn/output/ComPILE/README.md
@@ -0,0 +1,35 @@
+# ComPILE[AAAI 2021]
+
+- Paper: [Communicative Message Passing for Inductive Relation Reasoning](https://arxiv.org/pdf/2012.08911.pdf)
+- Author's code: https://github.com/TmacMai/CoMPILE_Inductive_Knowledge_Graph
+- Note: The difference between ComPILE and Grail is mainly in the model structure.
+
+## How to run
+
+* Clone the Openhgnn-DGL
+
+```bash
+python main.py -m ComPILE -d WN18RR_v1 -t link_prediction -g 0
+```
+
+| inductive | WN18RR_v1 |
+|:---------:|:-----------------------------:|
+| H@10 | paper: 83.60 OpenHGNN: 87.23 |
+
+We report the result of the best valid epoch.
+
+If you do not have gpu, set -gpu -1.
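Grail- and ComPILE-style inductive evaluation ranks each positive test triple against a set of corrupted triples (see `get_neg_samples_replacing_head_tail` in the trainer later in this PR, which samples 50 negatives per triple). The snippet below is only an illustration of how such scores turn into the Hits@10 figure reported above; the arrays `pos_scores` and `neg_scores` are hypothetical and not produced by the OpenHGNN code.

```python
import numpy as np

def hits_at_k(pos_scores, neg_scores, k=10):
    """pos_scores: (n,) score of each true triple.
    neg_scores: (n, m) scores of m corrupted triples per true triple.
    Returns the fraction of true triples ranked within the top k."""
    pos = np.asarray(pos_scores)[:, None]   # (n, 1)
    neg = np.asarray(neg_scores)            # (n, m)
    # rank = 1 + number of negatives scored at least as high as the positive
    ranks = 1 + (neg >= pos).sum(axis=1)
    return float((ranks <= k).mean())

# toy example: 2 test triples, 5 sampled negatives each
print(hits_at_k([0.9, 0.2], [[0.1, 0.3, 0.5, 0.2, 0.4],
                             [0.6, 0.7, 0.1, 0.8, 0.3]], k=3))  # 0.5
```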
+#### Contributor + +Shuaikun Liu, Fengqi Liang[GAMMA LAB] + +#### If you have any questions, + +Submit an issue or email to +liushuaikun@bupt.edu.cn, lfq@bupt.edu.cn + + + + + + + + diff --git a/openhgnn/output/DisenKGAT/README.md b/openhgnn/output/DisenKGAT/README.md new file mode 100644 index 00000000..3cdf4772 --- /dev/null +++ b/openhgnn/output/DisenKGAT/README.md @@ -0,0 +1,129 @@ +# DisenKGAT[CIKM 2021] + +Paper: [**DisenKGAT: Knowledge Graph Embedding with Disentangled +Graph Attention Network**](https://dl.acm.org/doi/10.1145/3459637.3482424) + +Code from author: https://github.com/Wjk666/DisenKGAT + +#### How to run + +Clone the Openhgnn-DGL + +```bash +python main.py -m DisenKGAT -t link_prediction -d DisenKGAT_WN18RR -g 0 +``` + +If you do not have gpu, set -gpu -1. + +Candidate dataset: DisenKGAT_WN18RR , DisenKGAT_FB15k-237 + +#### Performance + + +| Metric | DisenKGAT_WN18RR | DisenKGAT_FB15k-237 | +| ------------------- | ------- | ------- | +| MR |1406 | 167 | +| MRR | 0.455 | 0.354 | +| HITS@1 | 0.431 | 0.244 | +| HITS@3 | 0.485 | 0.387 | +| HITS@10 | 0.521 | 0.511 | + + +#### Model + +We implement DisenKGAT with DisenKGAT_TransE,DisenKGAT_DistMult,DisenKGAT_ConvE,DisenKGAT_InteractE + +### Dataset + +Supported dataset: WN18RR , FB15k-237 + +You can download the dataset by + +``` +wget https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/DisenKGAT_WN18RR.zip +wget https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/DisenKGAT_FB15k-237.zip +``` + +### Hyper-parameter specific to the model + +```python +# str +name = Disen_Model +# data = DisenKGAT_WN18RR +# model = DisenKGAT +score_func = interacte +opn = cross +# gpu = 2 +logdir = ./log/ +config = ./config/ +strategy = one_to_n +form = plain +mi_method = club_b +att_mode = dot_weight +score_method = dot_rel +score_order = after +gamma_method = norm + + +# int +k_w = 10 +batch = 2048 +test_batch = 2048 +epoch = 1500 +num_workers = 10 +seed = 41504 +init_dim = 100 +gcn_dim = 200 +embed_dim = 200 +gcn_layer = 1 +k_h = 20 +num_filt = 200 +ker_sz = 7 +num_bases = -1 +neg_num = 1000 +ik_w = 10 +ik_h = 20 +inum_filt = 200 +iker_sz = 9 +iperm = 1 +head_num = 1 +num_factors = 3 +early_stop = 200 +mi_epoch = 1 + +# float +feat_drop = 0.3 +hid_drop2 = 0.3 +hid_drop = 0.3 +gcn_drop = 0.4 +gamma = 9.0 +l2 = 0.0 +lr = 0.001 +lbl_smooth = 0.1 +iinp_drop = 0.3 +ifeat_drop = 0.4 +ihid_drop = 0.3 +alpha = 1e-1 +max_gamma = 5.0 +init_gamma = 9.0 + +# boolean +restore = False +bias = False +no_act = False +mi_train = True +no_enc = False +mi_drop = True +fix_gamma = False + +``` + +All config can be found in [config.ini](../../config.ini) + + + +## More + +#### If you have any questions, + +Submit an issue or email to [zhaozihao@bupt.edu.cn](mailto:zhaozihao@bupt.edu.cn). diff --git a/openhgnn/output/ExpressGNN/README.md b/openhgnn/output/ExpressGNN/README.md new file mode 100644 index 00000000..c26c3bc7 --- /dev/null +++ b/openhgnn/output/ExpressGNN/README.md @@ -0,0 +1,69 @@ +# ExpressGNN[ICLR2020] + +- paper: [Efficient Probabilistic Logic Reasoning With Graph Neural Networks](https://arxiv.org/abs/2001.11850) +- Code from author: [ExpressGNN](https://github.com/expressGNN/ExpressGNN) + + +## How to run + +- Clone the Openhgnn-DGL + + ```bash + python main.py -m ExpressGNN -d EXP_FB15k-237 -t link_prediction -g 0 --use_best_config + ``` + + If you do not have gpu, set -gpu -1. 
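The table below reports MRR and Hits@10, which are typically computed from the filtered rank of the true entity among all candidates. As a minimal, hedged illustration (not the evaluation code used by the OpenHGNN trainer), given an array of 1-based filtered ranks:

```python
import torch

def rank_metrics(ranks):
    """ranks: 1-based filtered ranks of the true entities, shape (n,)."""
    ranks = torch.as_tensor(ranks, dtype=torch.float)
    return {
        "MRR": (1.0 / ranks).mean().item(),
        "Hits@10": (ranks <= 10).float().mean().item(),
    }

print(rank_metrics([1, 3, 12, 2]))  # {'MRR': ..., 'Hits@10': 0.75}
```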
+
+## Performance: link_prediction
+
+|      | fb15k237 |
+|:----:|:-----------------------------:|
+| MRR  | paper: 0.49 OpenHGNN: 0.4399  |
+| H@10 | paper: 0.608 OpenHGNN: 0.5668 |
+
+## Dataset
+
+Supported dataset: cora, fb15k-237, kinship, uw_cse
+
+### Cora
+
+The Cora dataset is a widely used academic citation network, containing scientific papers categorized into various research topics, making it valuable for research in machine learning and graph analysis.
+
+### fb15k-237
+
+The FB15k-237 dataset is a knowledge graph dataset that focuses on entity and relation prediction tasks, derived from Freebase with 237 relations.
+
+### kinship
+
+The Kinship dataset is a collection of genealogical information, encompassing family relationships and demographics, often used for studying kinship recognition in computer vision and social sciences.
+
+### uw_cse
+
+The UW CSE (Computer Science and Engineering) dataset is a repository of academic information, including faculty, courses, and research projects, commonly used for educational and research purposes in computer science.
+
+## TrainerFlow: link_prediction
+
+#### model
+
+ExpressGNN is an extension of Graph Neural Networks (GNNs).
+
+### Graph Neural Networks (GNNs)
+
+GNNs model node embeddings in graph-structured data by recursively aggregating neighbor information.
+
+### ExpressGNN Extension
+
+ExpressGNN extends GNNs for probabilistic logic reasoning: it combines a compact GNN over the knowledge graph with tunable entity embeddings and uses the resulting representations as the variational posterior for inference in a Markov Logic Network.
+
+## More
+
+#### Contributor
+
+Ziyao Lin, Fengqi Liang [GAMMA LAB]
+
+#### If you have any questions,
+
+Submit an issue or email to ziyao_lin@hust.edu.cn, lfq@bupt.edu.cn
\ No newline at end of file
diff --git a/openhgnn/output/Grail/README.md b/openhgnn/output/Grail/README.md
new file mode 100644
index 00000000..2896e4b8
--- /dev/null
+++ b/openhgnn/output/Grail/README.md
@@ -0,0 +1,33 @@
+# Grail[ICML 2020]
+
+- Paper: [Inductive Relation Prediction by Subgraph Reasoning](https://arxiv.org/pdf/1911.06962.pdf)
+- Author's code: https://github.com/kkteru/grail
+
+## How to run
+
+* Clone the Openhgnn-DGL
+
+```bash
+python main.py -m Grail -d WN18RR_v1 -t link_prediction -g 0
+```
+
+| inductive | WN18RR_v1 |
+|:---------:|:-----------------------------:|
+| H@10 | paper: 82.45 OpenHGNN: 0.8431 |
+
+If you do not have gpu, set -gpu -1.
+
+#### Contributor
+
+Shuaikun Liu, Fengqi Liang [GAMMA LAB]
+
+#### If you have any questions,
+
+Submit an issue or email to liushuaikun@bupt.edu.cn, lfq@bupt.edu.cn
+
+
diff --git a/openhgnn/output/Ingram/README.md b/openhgnn/output/Ingram/README.md
new file mode 100644
index 00000000..9b43d233
--- /dev/null
+++ b/openhgnn/output/Ingram/README.md
@@ -0,0 +1,45 @@
+# InGram[ICML 2023]
+
+Paper: [InGram: Inductive Knowledge Graph Embedding via Relation Graphs](https://proceedings.mlr.press/v202/lee23c/lee23c.pdf)
+
+#### How to run
+
+Clone the Openhgnn-DGL
+
+```bash
+python main.py -m Ingram -d NL-100 -t Ingram -g 0
+```
+
+If you do not have gpu, set -gpu -1.
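InGram embeds relations that were unseen during training by building a relation graph, a weighted graph over relations whose edges reflect how strongly two relations share entities, and attending over it. The snippet below is only a simplified illustration of that relation-graph construction (plain co-occurrence counts, not the exact affinity weighting used in the paper), with hypothetical toy triples.

```python
from collections import defaultdict
from itertools import combinations

def relation_graph(triples):
    """triples: iterable of (head, relation, tail).
    Returns {(r1, r2): weight}, counting entities shared by two relations."""
    rels_of_entity = defaultdict(set)
    for h, r, t in triples:
        rels_of_entity[h].add(r)
        rels_of_entity[t].add(r)
    weights = defaultdict(int)
    for rels in rels_of_entity.values():
        for r1, r2 in combinations(sorted(rels), 2):
            weights[(r1, r2)] += 1
    return dict(weights)

print(relation_graph([(0, "born_in", 1), (0, "works_in", 2), (3, "born_in", 1)]))
# {('born_in', 'works_in'): 1} -- entity 0 links the two relations
```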
+ +Candidate dataset: NL-100 + +#### Performance + +| InGram[OpenHGNN] | MR | MRR | H@10 | H@1 | +|------------------|---------------------------|-----------------------------|-----------------------------|------------------------------| +| NL-100 | Paper:92.6 openhgnn:97.0 | Paper:0.309 openhgnn:0.295 | Paper:0.506 openhgnn:0.494 | Paper: 0.212 openhgnn: 0.193 | + + + + +### TrainerFlow: Ingram + + +#### model + +Ingram + +### Dataset + +NL-100 + + +#### Contirbutor + +Zikai Zhou[GAMMA LAB] + +#### If you have any questions, + +Submit an issue or email to 460813395@qq.com. \ No newline at end of file diff --git a/openhgnn/output/LTE/README.md b/openhgnn/output/LTE/README.md new file mode 100644 index 00000000..0ea1b9b1 --- /dev/null +++ b/openhgnn/output/LTE/README.md @@ -0,0 +1,38 @@ +# LTE + +-paper: [Rethinking Graph Convolutional Networks in Knowledge +Graph Completion +](https://arxiv.org/pdf/2202.05679.pdf) + + +## How to run +- Clone the Openhgnn-DGL + ```bash + python main.py -m LTE -d LTE -t link_prediction -g 0 + ``` + +for high efficiency, only gpu + +## Performance: Recommendation + +- Device: GPU, **GeForce GTX 3090** +- Dataset:LTE_dataset + + +| LTE-TransE | FB237 | +|:--------------:|:---------------------------------:| +| MRR | paper: 0.334 OpenHGNN: 0.3272 | +| H@1 | paper: 0.241 OpenHGNN: 0.23527 | +| H@3 | paper: 0.370 OpenHGNN: 0.3616 | +| H@10 | paper: 0.519 OpenHGNN: 0.5111 | + + +#### Contributor + +Zikai Zhou[GAMMA LAB] + +#### If you have any questions, + +Submit an issue or email to 460813395@qq.com + + diff --git a/openhgnn/output/NBF/README.md b/openhgnn/output/NBF/README.md new file mode 100644 index 00000000..0822a8e4 --- /dev/null +++ b/openhgnn/output/NBF/README.md @@ -0,0 +1,74 @@ +# NBF_Net[NIPS 2021] + +Paper: [**Neural Bellman-Ford Networks: A General Graph Neural Network Framework for Link Prediction**](https://proceedings.neurips.cc/paper_files/paper/2021/file/f6a673f09493afcd8b129a0bcf1cd5bc-Paper.pdf) + +Code from author: https://github.com/DeepGraphLearning/NBFNet + +#### How to run + +Clone the Openhgnn-DGL + +```bash +python main.py -m NBF -t link_prediction -d NBF_WN18RR -g 0 +``` + +If you do not have gpu, set -gpu -1. + +Candidate dataset: NBF_WN18RR , NBF_FB15k-237 + +#### Performance + + +| Metric | WN18RR | +| ------------------- | ------- | +| MRR | 69.72 | +| HITS@1 | 60.37 | +| HITS@3 | 77.66 | +| HITS@10 | 82.71 | +| HITS@10_50 | 95.40 | + +#### Model + +We implement NBF_Net with GeneralizedRelationalConv + +### Dataset + +Supported dataset: WN18RR , FB15k-237 + +You can download the dataset by + +``` +wget https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/WN18RR.zip +wget https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/FB15k-237.zip +``` + +### Hyper-parameter specific to the model + +```python +input_dim = 32 +hidden_dims = [32, 32, 32, 32, 32, 32] +message_func = distmult +aggregate_func = pna +short_cut = True +layer_norm = True +dependent = False +num_negative = 32 +strict_negative = True +adversarial_temperature = 1 +lr = 0.005 +gpus = [0] +batch_size = 64 +num_epoch = 20 +log_interval = 100 + +``` + +All config can be found in [config.ini](../../config.ini) + + + +## More + +#### If you have any questions, + +Submit an issue or email to [zhaozihao@bupt.edu.cn](mailto:zhaozihao@bupt.edu.cn). 
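The config above sets `message_func = distmult` and `aggregate_func = pna`: each edge message is the element-wise product of the source node state and a relation vector, and PNA then combines mean/max/min/std statistics of the incoming messages with degree-based scalers (see `GeneralizedRelationalConv` earlier in this PR). The sketch below is a stripped-down illustration of the DistMult message plus only the mean/std part of that aggregation; tensor names are made up, and the max/min statistics and degree scalers are omitted.

```python
import torch

def distmult_mean_std(node_h, rel_emb, edge_index, edge_type, eps=1e-6):
    """node_h: (n, d) node states; rel_emb: (num_rel, d) relation embeddings;
    edge_index: (2, e) [src, dst] indices; edge_type: (e,) relation ids.
    Returns (n, 2*d): per-node [mean | std] of incoming DistMult messages."""
    src, dst = edge_index
    msg = node_h[src] * rel_emb[edge_type]            # element-wise DistMult message
    n, d = node_h.shape
    deg = torch.zeros(n).index_add_(0, dst, torch.ones(len(dst))).clamp(min=1)
    mean = torch.zeros(n, d).index_add_(0, dst, msg) / deg.unsqueeze(-1)
    sq_mean = torch.zeros(n, d).index_add_(0, dst, msg ** 2) / deg.unsqueeze(-1)
    std = (sq_mean - mean ** 2).clamp(min=eps).sqrt()
    return torch.cat([mean, std], dim=-1)

# toy graph: 3 nodes, 2 relations, 3 edges
h, r = torch.randn(3, 4), torch.randn(2, 4)
ei = torch.tensor([[0, 1, 2], [1, 2, 1]])
et = torch.tensor([0, 1, 0])
print(distmult_mean_std(h, r, ei, et).shape)  # torch.Size([3, 8])
```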
diff --git a/openhgnn/output/RedGNN/README.md b/openhgnn/output/RedGNN/README.md
new file mode 100644
index 00000000..a04c4cc0
--- /dev/null
+++ b/openhgnn/output/RedGNN/README.md
@@ -0,0 +1,89 @@
+# RED-GNN[WWW 2022]
+
+- paper: https://arxiv.org/pdf/2108.06040.pdf
+- Code from author: [RedGNN](https://github.com/LARS-research/RED-GNN)
+
+This model covers two scenarios, transductive and inductive, so we integrate it as two models: RedGNNT and RedGNN.
+
+## RedGNNT
+
+### How to run
+
+- Clone the Openhgnn-DGL
+
+  ```bash
+  python main.py -m RedGNNT -d family -t link_prediction -g 0
+  ```
+
+  If you do not have gpu, set -gpu -1.
+
+  The dataset family is supported.
+
+### Performance: Family
+
+- Device: GPU, **GeForce RTX 3090**
+- Dataset: family
+
+| transductive | Family |
+|:------------:|:-----------------------------:|
+| MRR  | paper: 0.992 OpenHGNN: 0.9836 |
+| H@1  | paper: 98.8 OpenHGNN: 0.9771  |
+| H@10 | paper: 99.7 OpenHGNN: 0.9907  |
+
+### TrainerFlow: RedGNNT_trainer
+
+#### model
+
+- RedGNNT
+  - RedGNNT is the transductive version of RED-GNN, a GNN-based knowledge graph reasoning model.
+
+## RedGNN
+
+### How to run
+
+- Clone the Openhgnn-DGL
+
+  ```bash
+  python main.py -m RedGNN -d fb237_v1 -t link_prediction -g 0
+  ```
+
+  If you do not have gpu, set -gpu -1.
+
+  The dataset fb237_v1 is supported.
+
+### Performance: fb15k237_v1
+
+- Device: GPU, **GeForce RTX 2080Ti**
+- Dataset: fb237_v1
+
+| inductive | fb15k237_v1 |
+|:---------:|:-----------------------------:|
+| MRR  | paper: 0.369 OpenHGNN: 0.3543 |
+| H@1  | paper: 0.302 OpenHGNN: 0.2968 |
+| H@10 | paper: 0.483 OpenHGNN: 0.4672 |
+
+### TrainerFlow: RedGNN
+
+#### model
+
+- RedGNN
+  - RedGNN is the inductive version of RED-GNN, a GNN-based knowledge graph reasoning model.
+
+#### Contributor
+
+Fengqi Liang [GAMMA LAB]
+
+#### If you have any questions,
+
+Submit an issue or email to lfq@bupt.edu.cn
diff --git a/openhgnn/output/SACN/README.md b/openhgnn/output/SACN/README.md
new file mode 100644
index 00000000..60db4a54
--- /dev/null
+++ b/openhgnn/output/SACN/README.md
@@ -0,0 +1,43 @@
+# SACN
+
+Paper: [End-to-end Structure-Aware Convolutional Networks for Knowledge Base Completion](https://arxiv.org/pdf/1811.04441.pdf)
+
+## How to run
+
+* Clone the Openhgnn-DGL
+
+```bash
+python main.py -m SACN -d SACN -t link_prediction -g 0
+```
+
+If you do not have gpu, set -gpu -1.
+
+## Performance
+
+| Dataset | FB15k-237                  |
+|---------|----------------------------|
+| MRR     | Paper:0.35 OpenHGNN:0.3528 |
+| H@1     | Paper:0.26 OpenHGNN:0.2575 |
+| H@3     | Paper:0.39 OpenHGNN:0.3938 |
+| H@10    | Paper:0.54 OpenHGNN:0.5397 |
+
+### TrainerFlow
+
+```SACN_trainer```
+
+### model
+
+```SACN```
+
+### Dataset
+
+SACN_dataset
+
+#### Contributor
+
+Zikai Zhou [GAMMA LAB]
+
+#### If you have any questions,
+
+Submit an issue or email to 460813395@qq.com
\ No newline at end of file
diff --git a/openhgnn/tasks/Ingram_task.py b/openhgnn/tasks/Ingram_task.py
new file mode 100644
index 00000000..f5483144
--- /dev/null
+++ b/openhgnn/tasks/Ingram_task.py
@@ -0,0 +1,17 @@
+import torch.nn.functional as F
+from . 
import BaseTask, register_task +from ..dataset import build_dataset +from ..utils import Evaluator + + +@register_task("Ingram") +class Ingram(BaseTask): + """Recommendation tasks.""" + def __init__(self, args): + super().__init__() + self.logger = args.logger + self.name_dataset = args.dataset + self.train_dataloader, self.valid_dataloader, self.test_dataloader = build_dataset(args.dataset, 'Ingram') + def evaluate(self, y_true, y_score): + pass + diff --git a/openhgnn/tasks/__init__.py b/openhgnn/tasks/__init__.py index 61498723..fa5fe5a8 100644 --- a/openhgnn/tasks/__init__.py +++ b/openhgnn/tasks/__init__.py @@ -59,6 +59,9 @@ def try_import_task(task): 'pretrain': 'openhgnn.tasks.pretrain', 'abnorm_event_detection': 'openhgnn.tasks.AbnormEventDetection', 'DSSL_trainer': 'openhgnn.tasks.DSSL_task', + 'NBF_link_prediction':'openhgnn.tasks.link_prediction', + 'Ingram': 'openhgnn.tasks.Ingram', + 'DisenKGAT_link_prediction':'openhgnn.tasks.link_prediction', } from .node_classification import NodeClassification @@ -67,6 +70,8 @@ def try_import_task(task): from .edge_classification import EdgeClassification from .hypergraph import hypergraph from .node_classification import DSSL_task +from .Ingram_task import Ingram + __all__ = [ 'BaseTask', diff --git a/openhgnn/tasks/link_prediction.py b/openhgnn/tasks/link_prediction.py index 8470285b..ba1a13e2 100644 --- a/openhgnn/tasks/link_prediction.py +++ b/openhgnn/tasks/link_prediction.py @@ -29,12 +29,30 @@ class LinkPrediction(BaseTask): """ def __init__(self, args): - super(LinkPrediction, self).__init__() + super(LinkPrediction, self).__init__( ) self.name_dataset = args.dataset self.logger = args.logger - self.dataset = build_dataset(args.dataset, 'link_prediction', logger=self.logger) + + if args.model=="Grail" or args.model =="ComPILE": + self.dataset = build_dataset(args.dataset, 'link_prediction', logger=self.logger,args = args) + return + if(args.dataset=='AdapropT'): + self.dataloader = build_dataset(args.dataset, 'Adaprop', logger=self.logger,args = args) + return + if(args.dataset=='AdapropI'): + self.dataloader = build_dataset(args.dataset, 'AdapropI', logger=self.logger, args = args) + return + if (args.dataset_name == 'LTE'): + build_dataset(args, 'LTE') + return + if (args.dataset_name == 'SACN'): + build_dataset(args, 'SACN') + return + self.dataset = build_dataset(args.dataset, 'link_prediction', logger=self.logger, args=args) # self.evaluator = Evaluator() - self.train_hg, self.val_hg, self.test_hg, self.neg_val_graph, self.neg_test_graph = self.dataset.get_split() + if args.model == 'ExpressGNN' or args.model == 'RedGNN' or args.model == 'RedGNNT' or args.model == 'DisenKGAT': + return + self.train_hg, self.val_hg, self.test_hg, self.neg_val_graph, self.neg_test_graph = self.dataset.get_split( ) self.pred_hg = getattr(self.dataset, 'pred_graph', None) if self.val_hg is None and self.test_hg is None: pass @@ -43,14 +61,14 @@ def __init__(self, args): self.test_hg = self.test_hg.to(args.device) self.evaluator = Evaluator(args.seed) if not hasattr(args, 'score_fn'): - self.ScorePredictor = HeteroDistMultPredictor() + self.ScorePredictor = HeteroDistMultPredictor( ) args.score_fn = 'distmult' elif args.score_fn == 'dot-product': - self.ScorePredictor = HeteroDotProductPredictor() + self.ScorePredictor = HeteroDotProductPredictor( ) elif args.score_fn == 'distmult': - self.ScorePredictor = HeteroDistMultPredictor() + self.ScorePredictor = HeteroDistMultPredictor( ) # deprecated, new score predictor of these score_fn are in their 
model - #elif args.score_fn in ['transe', 'transh', 'transr', 'transd', 'gie'] : + # elif args.score_fn in ['transe', 'transh', 'transr', 'transd', 'gie'] : # self.ScorePredictor = HeteroTransXPredictor(args.dis_norm) self.negative_sampler = Uniform(1) @@ -58,7 +76,7 @@ def __init__(self, args): self.evaluation_metric = getattr(args, 'evaluation_metric', 'roc_auc') # default evaluation_metric is roc_auc if args.dataset in ['wn18', 'FB15k', 'FB15k-237']: self.evaluation_metric = 'mrr' - self.filtered = args.filtered + # self.filtered = args.filtered if hasattr(args, "valid_percent"): self.dataset.modify_size(args.valid_percent, 'valid') if hasattr(args, "test_percent"): @@ -130,7 +148,7 @@ def evaluate(self, n_embedding, r_embedding=None, mode='test'): n_score = th.sigmoid(self.ScorePredictor(neg_hg, n_embedding, r_embedding)) p_label = th.ones(len(p_score), device=p_score.device) n_label = th.zeros(len(n_score), device=p_score.device) - roc_auc = self.evaluator.cal_roc_auc(th.cat((p_label, n_label)).cpu(), th.cat((p_score, n_score)).cpu()) + roc_auc = self.evaluator.cal_roc_auc(th.cat((p_label, n_label)).cpu( ), th.cat((p_score, n_score)).cpu( )) loss = F.binary_cross_entropy_with_logits(th.cat((p_score, n_score)), th.cat((p_label, n_label))) return dict(roc_auc=roc_auc, loss=loss) else: @@ -138,13 +156,13 @@ def evaluate(self, n_embedding, r_embedding=None, mode='test'): def predict(self, n_embedding, r_embedding, **kwargs): score = th.sigmoid(self.ScorePredictor(self.pred_hg, n_embedding, r_embedding)) - indices = self.pred_hg.edges() + indices = self.pred_hg.edges( ) return indices, score def tranX_predict(self): pred_triples_T = self.dataset.pred_triples.T score = th.sigmoid(self.ScorePredictor(pred_triples_T[0], pred_triples_T[1], pred_triples_T[2])) - indices = self.pred_hg.edges() + indices = self.pred_hg.edges( ) return indices, score def downstream_evaluate(self, logits, evaluation_metric): @@ -160,13 +178,13 @@ def get_train(self): return self.train_hg def get_labels(self): - return self.dataset.get_labels() + return self.dataset.get_labels( ) def dict2emd(self, r_embedding): r_emd = [] for i in range(self.dataset.num_rels): r_emd.append(r_embedding[str(i)]) - return th.stack(r_emd).squeeze() + return th.stack(r_emd).squeeze( ) def construct_negative_graph(self, hg): e_dict = { @@ -181,7 +199,7 @@ def construct_negative_graph(self, hg): class HeteroDotProductPredictor(th.nn.Module): """ References: `documentation of dgl _` - + """ def forward(self, edge_subgraph, x, *args, **kwargs): @@ -192,14 +210,14 @@ def forward(self, edge_subgraph, x, *args, **kwargs): the prediction graph only contains the edges of the target link x: dict[str: th.Tensor] the embedding dict. The key only contains the nodes involving with the target link. 
- + Returns ------- score: th.Tensor the prediction of the edges in edge_subgraph """ - with edge_subgraph.local_scope(): + with edge_subgraph.local_scope( ): for ntype in edge_subgraph.ntypes: edge_subgraph.nodes[ntype].data['x'] = x[ntype] for etype in edge_subgraph.canonical_etypes: @@ -208,10 +226,10 @@ def forward(self, edge_subgraph, x, *args, **kwargs): score = edge_subgraph.edata['score'] if isinstance(score, dict): result = [] - for _, value in score.items(): + for _, value in score.items( ): result.append(value) score = th.cat(result) - return score.squeeze() + return score.squeeze( ) class HeteroDistMultPredictor(th.nn.Module): @@ -220,13 +238,13 @@ def forward(self, edge_subgraph, x, r_embedding, *args, **kwargs): """ DistMult factorization (Yang et al. 2014) as the scoring function, which is known to perform well on standard link prediction benchmarks when used on its own. - + In DistMult, every relation r is associated with a diagonal matrix :math:`R_{r} \in \mathbb{R}^{d \times d}` and a triple (s, r, o) is scored as - + .. math:: f(s, r, o)=e_{s}^{T} R_{r} e_{o} - + Parameters ---------- edge_subgraph: dgl.Heterograph @@ -235,13 +253,13 @@ def forward(self, edge_subgraph, x, r_embedding, *args, **kwargs): the node embedding dict. The key only contains the nodes involving with the target link. r_embedding: th.Tensor the all relation types embedding - + Returns ------- score: th.Tensor the prediction of the edges in edge_subgraph """ - with edge_subgraph.local_scope(): + with edge_subgraph.local_scope( ): for ntype in edge_subgraph.ntypes: edge_subgraph.nodes[ntype].data['x'] = x[ntype] for etype in edge_subgraph.canonical_etypes: @@ -259,14 +277,14 @@ def forward(self, edge_subgraph, x, r_embedding, *args, **kwargs): score = edge_subgraph.edata['score'] if isinstance(score, dict): result = [] - for _, value in score.items(): + for _, value in score.items( ): result.append(th.sum(value, dim=1)) score = th.cat(result) else: score = th.sum(score, dim=1) return score -#class HeteroTransXPredictor(th.nn.Module): +# class HeteroTransXPredictor(th.nn.Module): # def __init__(self, dis_norm): # super(HeteroTransXPredictor, self).__init__() # self.dis_norm = dis_norm @@ -277,3 +295,36 @@ def forward(self, edge_subgraph, x, r_embedding, *args, **kwargs): # t = F.normalize(t, 2, -1) # dist = th.norm(h+r-t, self.dis_norm, dim=-1) # return dist + +@register_task("NBF_link_prediction") +class NBF_LinkPrediction(BaseTask): + r""" + Link prediction tasks for NBF + + """ + + def __init__(self, args): + super(NBF_LinkPrediction, self).__init__() + self.logger = None + # dataset = 'NBF_WN18RR' or 'NBF_FB15k-237' + self.dataset = build_dataset(args.dataset, 'link_prediction',logger=self.logger,args=args) + + def evaluate(self): + return None + + + +@register_task("DisenKGAT_link_prediction") +class DisenKGAT_LinkPrediction(BaseTask): + + def __init__(self, args ) : + super(DisenKGAT_LinkPrediction, self).__init__() + self.logger = None + + self.dataset = build_dataset(dataset = args.dataset, task='link_prediction', + logger=self.logger, args=args) + + + def evaluate(self): + return None + diff --git a/openhgnn/trainerflow/AdapropI_trainer.py b/openhgnn/trainerflow/AdapropI_trainer.py new file mode 100644 index 00000000..f0368783 --- /dev/null +++ b/openhgnn/trainerflow/AdapropI_trainer.py @@ -0,0 +1,124 @@ +import random +import os +import argparse +import torch +import numpy as np +from ..models import build_model +from . 
import BaseFlow, register_flow +from ..tasks import build_task +from ..utils.AdapropI_utils import * +@register_flow("AdapropI_trainer") +class AdapropITrainer(BaseFlow): + def __init__(self, args): + class Options(object): + pass + + dataset = args.data_path + dataset = dataset.split('/') + if len(dataset[-1]) > 0: + dataset = dataset[-1] + else: + dataset = dataset[-2] + # args.dataset_name=dataset + args.dataset_name=dataset + self.model_name='AdapropI' + self.args = Options + self.args=args + self.args.hidden_dim = 64 + self.args.init_dim = 10 + self.args.attn_dim = 5 + self.args.n_layer = 3 + self.args.n_batch = 50 + self.args.lr = 0.001 + self.args.decay_rate = 0.999 + self.args.perf_file = './results.txt' + self.args.task_dir=args.data_path + gpu = args.device + torch.cuda.set_device(gpu) + print('==> selected GPU id: ', gpu) + args.n_batch=self.args.n_batch + self.task = build_task(self.args) + self.loader=self.task.dataloader + loader=self.loader + # loader = DataLoader(args.data_path, n_batch=self.args.n_batch) + self.args.n_ent = loader.n_ent + self.args.n_rel = loader.n_rel + + + + params = {} + if 'fb237_v1' in args.data_path: + params['lr'], params['decay_rate'], params["lamb"], params['hidden_dim'], params['init_dim'], params[ + 'attn_dim'], params['n_layer'], params['n_batch'], params['dropout'], params['act'], params['topk'], \ + params['increase'] = 0.0005, 0.9968, 0.000081, 32, 32, 5, 3, 100, 0.4137, 'relu', 100, True + if 'fb237_v2' in args.data_path: + params['lr'], params['decay_rate'], params["lamb"], params['hidden_dim'], params['init_dim'], params[ + 'attn_dim'], params['n_layer'], params['n_batch'], params['dropout'], params['act'], params['topk'], \ + params['increase'] = 0.0087, 0.9937, 0.000025, 16, 16, 5, 5, 20, 0.3265, 'relu', 200, True + + if 'fb237_v3' in args.data_path: + params['lr'], params['decay_rate'], params["lamb"], params['hidden_dim'], params['init_dim'], params[ + 'attn_dim'], params['n_layer'], params['n_batch'], params['dropout'], params['act'], params['topk'], \ + params['increase'] = 0.0079, 0.9934, 0.000187, 48, 48, 5, 7, 20, 0.4632, 'relu', 200, True + + if 'fb237_v4' in args.data_path: + params['lr'], params['decay_rate'], params["lamb"], params['hidden_dim'], params['init_dim'], params[ + 'attn_dim'], params['n_layer'], params['n_batch'], params['dropout'], params['act'], params['topk'], \ + params['increase'] = 0.0010, 0.9997, 0.000186, 16, 16, 5, 7, 50, 0.4793, 'relu', 500, True + + print(params) + + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + self.args.lr = params['lr'] + self.args.lamb = params["lamb"] + self.args.decay_rate = params['decay_rate'] + self.args.hidden_dim = params['hidden_dim'] + self.args.init_dim = params['hidden_dim'] + self.args.attn_dim = params['attn_dim'] + self.args.dropout = params['dropout'] + self.args.act = params['act'] + self.args.n_layer = params['n_layer'] + self.args.n_batch = params['n_batch'] + self.args.topk = params['topk'] + self.args.increase = params['increase'] + + config_str = '%.4f, %.4f, %.6f, %d, %d, %d, %d, %d, %.4f, %s %d, %s\n' % ( + self.args.lr, self.args.decay_rate, self.args.lamb, self.args.hidden_dim, self.args.init_dim, self.args.attn_dim, self.args.n_layer, + self.args.n_batch, self.args.dropout, self.args.act, self.args.topk, str(self.args.increase)) + print(args.data_path) + print(config_str) + try: + self.model = build_model(self.model_name).build_model_from_args( + self.args, self.loader).model + model = self.model + best_mrr = 0 + best_tmrr = 0 + early_stop = 
0 + for epoch in range(30): + print("epoch:"+str(epoch)) + mrr, t_mrr, out_str = model.train_batch() + print(mrr, t_mrr, out_str) + if mrr > best_mrr: + best_mrr = mrr + best_tmrr = t_mrr + best_str = out_str + early_stop = 0 + else: + early_stop += 1 + + with open(self.args.perf_file, 'a') as f: + f.write(args.data_path + '\n') + f.write(config_str) + f.write(best_str + '\n') + print('\n\n') + + except RuntimeError: + best_tmrr = 0 + + print( + 'self.time_1, self.time_2, time_3, v_mrr, v_mr, v_h1, v_h3, v_h10, v_h1050, t_mrr, t_mr, t_h1, t_h3, t_h10, t_h1050') + print(best_str) + def train(self): + print(2) \ No newline at end of file diff --git a/openhgnn/trainerflow/AdapropT_trainer.py b/openhgnn/trainerflow/AdapropT_trainer.py new file mode 100644 index 00000000..fbeb321a --- /dev/null +++ b/openhgnn/trainerflow/AdapropT_trainer.py @@ -0,0 +1,88 @@ +import random +import os +import argparse +import torch +import numpy as np +from ..models import build_model +from . import BaseFlow, register_flow +from ..tasks import build_task +from ..utils.Adaprop_utils import * + + +@register_flow("AdapropT_trainer") +class AdapropTTrainer(BaseFlow): + def __init__(self, args): + opts = args + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.set_num_threads(8) + args.dataset="AdapropT" + dataset = opts.data_path + dataset = dataset.split('/') + if len(dataset[-1]) > 0: + self.dataset_name = dataset[-1] + else: + self.dataset_name = dataset[-2] + + torch.cuda.set_device(args.gpu) + print('==> gpu:', args.gpu) + self.task=build_task(args) + self.loader=self.task.dataloader + self.model_name='AdapropT' + args.n_ent = self.loader.n_ent + args.n_rel = self.loader.n_rel + self.args=args + self.model=build_model(self.model_name).build_model_from_args( + self.args,self.loader).model + # loader = DataLoader(opts) + + def train(self): + print("111") + opts=self.args + # check all output paths + checkPath('./results/') + checkPath(f'./results/{self.dataset_name}/') + checkPath(f'{self.loader.task_dir}/saveModel/') + model = self.model + opts.perf_file = f'results/{self.dataset_name}/{model.modelName}_perf.txt' + print(f'==> perf_file: {self.args.perf_file}') + + config_str = '%.4f, %.4f, %.6f, %d, %d, %d, %d, %.4f,%s\n' % ( + opts.lr, opts.decay_rate, opts.lamb, opts.hidden_dim, opts.attn_dim, opts.n_layer, opts.n_batch, opts.dropout, + opts.act) + print(config_str) + # with open(opts.perf_file, 'a+') as f: + # f.write(config_str) + + # if self.args.weight != None: + # model.loadModel(self.args.weight) + # model._update() + # model.model.updateTopkNums(opts.n_node_topk) + + if opts.train: + # training mode + best_v_mrr = 0 + for epoch in range(opts.epoch): + model.train_batch() + # eval on val/test set + if (epoch + 1) % self.args.eval_interval == 0: + result_dict, out_str = model.evaluate(eval_val=True, eval_test=True) + v_mrr, t_mrr = result_dict['v_mrr'], result_dict['t_mrr'] + print(out_str) + with open(opts.perf_file, 'a+') as f: + f.write(out_str) + if v_mrr > best_v_mrr: + best_v_mrr = v_mrr + best_str = out_str + print(str(epoch) + '\t' + best_str) + BestMetricStr = f'ValMRR_{str(v_mrr)[:5]}_TestMRR_{str(t_mrr)[:5]}' + model.saveModelToFiles(BestMetricStr, deleteLastFile=False) + + # show the final result + print(best_str) + + if opts.eval: + # evaluate on test set with loaded weight to save time + result_dict, out_str = model.evaluate(eval_val=False, eval_test=True, verbose=True) + print(result_dict, '\n', out_str) + diff --git a/openhgnn/trainerflow/ComPILE_trainer.py 
b/openhgnn/trainerflow/ComPILE_trainer.py new file mode 100644 index 00000000..96e9044a --- /dev/null +++ b/openhgnn/trainerflow/ComPILE_trainer.py @@ -0,0 +1,707 @@ +import statistics +import timeit +import os +import logging +import pdb +import numpy as np +import time + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.data import DataLoader +from openhgnn.models import build_model +from . import BaseFlow, register_flow +from ..tasks import build_task +import scipy.sparse as ssp +from sklearn import metrics +from ..utils.Grail_utils import collate_dgl2,move_batch_to_device_dgl,ssp_multigraph_to_dgl +import dgl +import tqdm +import random +params = None + +@register_flow('ComPILE_trainer') +class ComPILETrainer(BaseFlow): + def __init__(self,args): + super(ComPILETrainer, self).__init__(args) + #self.train_hg = self.task.get_train() + self.trainset = self.task.dataset.train + self.valid = self.task.dataset.valid + self.args.num_rels = self.trainset.num_rels + self.args.aug_num_rels = self.trainset.aug_num_rels + self.args.inp_dim = self.trainset.n_feat_dim + + self.args.collate_fn = collate_dgl2 + self.args.move_batch_to_device = move_batch_to_device_dgl + self.args.max_label_value = self.trainset.max_n_label + self.params = self.args + self.params.adj_list = [] + self.params.dgl_adj_list = [] + self.params.triplets = [] + self.params.entity2id = [] + self.params.relation2id = [] + self.params.id2entity = [] + self.params.id2relation = [] + + # Log the max label value to save it in the model. This will be used to cap the labels generated on test set. + + self.model = build_model(self.model).build_model_from_args(self.args, self.task.dataset.relation2id).to( + self.device) + self.updates_counter = 0 + model_params = list(self.model.parameters()) + #logging.info('Total number of parameters: %d' % sum(map(lambda x: x.numel(), model_params))) + + if self.args.optimizer == "SGD": + self.optimizer = optim.SGD(model_params, lr=self.args.lr, momentum=self.args.momentum, + weight_decay=self.args.l2) + if self.args.optimizer == "Adam": + self.optimizer = optim.Adam(model_params, lr=self.args.lr, weight_decay=self.args.l2) + + self.criterion = nn.MarginRankingLoss(self.args.margin, reduction='sum') + + self.reset_training_state() + #graph_classifier = initialize_model(params, dgl_model, params.load_model) + + self.logger.info(f"Device: {args.device}") + self.logger.info( + f"Input dim : {args.inp_dim}, # Relations : {args.num_rels}, # Augmented relations : {args.aug_num_rels}") + + self.args.save_path = os.path.dirname(os.path.abspath('__file__')) + '/openhgnn/output/' + self.model_name + self.valid_evaluator = Evaluator(self.args, self.model, self.valid) + #self.save_path = os.path.dirname(os.path.abspath('__file__')) + '/openhgnn/output/' + self.model_name + + self.logger.info('Starting training with full batch...') + + + + + def reset_training_state(self): + self.best_metric = 0 + self.last_metric = 0 + self.not_improved_count = 0 + + def train_epoch(self): + total_loss = 0 + all_preds = [] + all_labels = [] + all_scores = [] + + dataloader = DataLoader(self.trainset, batch_size=self.args.batch_size, shuffle=True, + num_workers=self.args.num_workers, collate_fn=self.args.collate_fn) + # dataloader = DataLoader(self.train_data, batch_size=self.params.batch_size, shuffle=True, num_workers=self.params.num_workers) + self.model.train() + model_params = list(self.model.parameters()) + 
torch.multiprocessing.set_sharing_strategy('file_system') + for b_idx, batch in enumerate(dataloader): + (graphs_pos, r_labels_pos), g_labels_pos, (graph_neg, r_labels_neg), g_labels_neg = batch + #(graphs_pos, r_labels_pos), g_labels_pos, (graph_neg, r_labels_neg), g_labels_neg = self.args.move_batch_to_device(batch, self.args.device) + + g_labels_pos = torch.LongTensor(g_labels_pos).to(device=self.args.device) + r_labels_pos = torch.LongTensor(r_labels_pos).to(device=self.args.device) + + g_labels_neg = torch.LongTensor(g_labels_neg).to(device=self.args.device) + r_labels_neg = torch.LongTensor(r_labels_neg).to(device=self.args.device) + + self.model.train() + # data_pos, targets_pos, data_neg, targets_neg = self.params.move_batch_to_device(batch, self.params.device) + self.optimizer.zero_grad() + # print('batch size ', len(targets_pos), ' ', len(targets_neg)) + # print('r label pos ', len(data_pos[1]), ' r label neg ', len(data_neg[1])) + score_pos = self.model(graphs_pos) + score_neg = self.model(graph_neg) + loss = self.criterion(score_pos.squeeze(), score_neg.view(len(score_pos), -1).mean(dim=1), + torch.Tensor([1]).to(device=self.args.device)) + # print(score_pos, score_neg, loss) + loss.backward() + self.optimizer.step() + self.updates_counter += 1 + + with torch.no_grad(): + # print(score_pos.shape, score_neg.shape) + # print(score_pos) + all_scores += score_pos.squeeze(1).detach().cpu().tolist() + score_neg.squeeze( + 1).detach().cpu().tolist() + all_labels += g_labels_pos.tolist() + g_labels_neg.tolist() + total_loss += loss + + if self.valid_evaluator and self.args.eval_every_iter and self.updates_counter % self.args.eval_every_iter == 0: + tic = time.time() + result = self.valid_evaluator.eval() + logging.info('\nPerformance:' + str(result) + 'in ' + str(time.time() - tic)) + + if result['auc'] >= self.best_metric: + self.save_classifier() + self.best_metric = result['auc'] + self.not_improved_count = 0 + + else: + self.not_improved_count += 1 + if self.not_improved_count > self.args.early_stop: + logging.info( + f"Validation performance didn\'t improve for {self.args.early_stop} epochs. 
Training stops.") + break + self.last_metric = result['auc'] + + auc = metrics.roc_auc_score(all_labels, all_scores) + auc_pr = metrics.average_precision_score(all_labels, all_scores) + + weight_norm = sum(map(lambda x: torch.norm(x), model_params)) + + return total_loss, auc, auc_pr, weight_norm + + def train(self): + self.reset_training_state() + for epoch in range(1, self.args.num_epochs + 1): + time_start = time.time() + loss, auc, auc_pr, weight_norm = self.train_epoch() + time_elapsed = time.time() - time_start + self.logger.info( + f'Epoch {epoch} with loss: {loss}, training auc: {auc}, training auc_pr: {auc_pr}, best validation AUC: {self.best_metric}, weight_norm: {weight_norm} in {time_elapsed}') + + if epoch % self.args.save_every == 0: + #save_path = os.path.dirname(os.path.abspath('__file__')) + '/openhgnn/output/' + self.model_name + torch.save(self.model, self.args.save_path + '/ComPILE_chk.pth') + self.params.model_path = self.args.save_path + '/ComPILE_chk.pth' + self.params.file_paths = { + 'graph': os.path.join(f'./openhgnn/dataset/data/{self.args.dataset}_ind/train.txt'), + 'links': os.path.join(f'./openhgnn/dataset/data/{self.args.dataset}_ind/test.txt') + } + global params + params = self.params + eval_rank(self.logger) + return + + def save_classifier(self): + #save_path = os.path.dirname(os.path.abspath('__file__')) + '/openhgnn/output/' + self.model_name + torch.save(self.model, self.args.save_path + '/best.pth') + self.logger.info('Better models found w.r.t accuracy. Saved it!') + + +class Evaluator(): + def __init__(self, params, graph_classifier, data): + self.params = params + self.graph_classifier = graph_classifier + self.data = data + + def eval(self, save=False): + pos_scores = [] + pos_labels = [] + neg_scores = [] + neg_labels = [] + dataloader = DataLoader(self.data, batch_size=self.params.batch_size, shuffle=False, + num_workers=self.params.num_workers, collate_fn=self.params.collate_fn) + + self.graph_classifier.eval() + with torch.no_grad(): + for b_idx, batch in enumerate(dataloader): + (graphs_pos, r_labels_pos), g_labels_pos, (graph_neg, r_labels_neg), g_labels_neg = batch + # data_pos, targets_pos, data_neg, targets_neg = self.params.move_batch_to_device(batch, self.params.device) + # print([self.data.id2relation[r.item()] for r in data_pos[1]]) + # pdb.set_trace() + + g_labels_pos = torch.LongTensor(g_labels_pos).to(device=self.params.device) + r_labels_pos = torch.LongTensor(r_labels_pos).to(device=self.params.device) + + g_labels_neg = torch.LongTensor(g_labels_neg).to(device=self.params.device) + r_labels_neg = torch.LongTensor(r_labels_neg).to(device=self.params.device) + + score_pos = self.graph_classifier(graphs_pos) + score_neg = self.graph_classifier(graph_neg) + + # preds += torch.argmax(logits.detach().cpu(), dim=1).tolist() + pos_scores += score_pos.squeeze(1).detach().cpu().tolist() + neg_scores += score_neg.squeeze(1).detach().cpu().tolist() + pos_labels += g_labels_pos.tolist() + neg_labels += g_labels_neg.tolist() + + # acc = metrics.accuracy_score(labels, preds) + auc = metrics.roc_auc_score(pos_labels + neg_labels, pos_scores + neg_scores) + auc_pr = metrics.average_precision_score(pos_labels + neg_labels, pos_scores + neg_scores) + + if save: + pos_test_triplets_path = os.path.join(self.params.save_path, + 'data/{}/{}.txt'.format(self.params.dataset, self.data.file_name)) + with open(pos_test_triplets_path) as f: + pos_triplets = [line.split() for line in f.read().split('\n')[:-1]] + pos_file_path = 
os.path.join(self.params.save_path, + 'data/{}/grail_{}_predictions.txt'.format(self.params.dataset, + self.data.file_name)) + with open(pos_file_path, "w") as f: + for ([s, r, o], score) in zip(pos_triplets, pos_scores): + f.write('\t'.join([s, r, o, str(score)]) + '\n') + + neg_test_triplets_path = os.path.join(self.params.save_path, + 'data/{}/neg_{}_0.txt'.format(self.params.dataset, + self.data.file_name)) + with open(neg_test_triplets_path) as f: + neg_triplets = [line.split() for line in f.read().split('\n')[:-1]] + neg_file_path = os.path.join(self.params.save_path, + 'data/{}/grail_neg_{}_{}_predictions.txt'.format(self.params.dataset, + self.data.file_name, + self.params.constrained_neg_prob)) + with open(neg_file_path, "w") as f: + for ([s, r, o], score) in zip(neg_triplets, neg_scores): + f.write('\t'.join([s, r, o, str(score)]) + '\n') + + return {'auc': auc, 'auc_pr': auc_pr} + +def process_files(files, saved_relation2id, add_traspose_rels): + ''' + files: Dictionary map of file paths to read the triplets from. + saved_relation2id: Saved relation2id (mostly passed from a trained model) which can be used to map relations to pre-defined indices and filter out the unknown ones. + ''' + entity2id = {} + relation2id = saved_relation2id + + triplets = {} + + ent = 0 + rel = 0 + + for file_type, file_path in files.items(): + + data = [] + with open(file_path) as f: + file_data = [line.split() for line in f.read().split('\n')[:-1]] + + for triplet in file_data: + if triplet[0] not in entity2id: + entity2id[triplet[0]] = ent + ent += 1 + if triplet[2] not in entity2id: + entity2id[triplet[2]] = ent + ent += 1 + + # Save the triplets corresponding to only the known relations + if triplet[1] in saved_relation2id: + data.append([entity2id[triplet[0]], entity2id[triplet[2]], saved_relation2id[triplet[1]]]) + + triplets[file_type] = np.array(data) + + id2entity = {v: k for k, v in entity2id.items()} + id2relation = {v: k for k, v in relation2id.items()} + + # Construct the list of adjacency matrix each corresponding to eeach relation. Note that this is constructed only from the train data. + adj_list = [] + for i in range(len(saved_relation2id)): + idx = np.argwhere(triplets['graph'][:, 2] == i) + adj_list.append(ssp.csc_matrix((np.ones(len(idx), dtype=np.uint8), (triplets['graph'][:, 0][idx].squeeze(1), triplets['graph'][:, 1][idx].squeeze(1))), shape=(len(entity2id), len(entity2id)))) + + # Add transpose matrices to handle both directions of relations. 
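+    # Editorial sketch, not part of the original patch: the transpose step just
+    # below relies on adj.T reading every edge of a relation in the opposite
+    # direction, which is what the add_traspose_rels option uses to add reverse
+    # relations. The helper name _transpose_demo is hypothetical and is never
+    # called by this module.
+    def _transpose_demo():
+        import numpy as _np
+        import scipy.sparse as _ssp
+        # one toy edge 0 -> 1 for a single relation
+        demo = _ssp.csc_matrix((_np.ones(1, dtype=_np.uint8), ([0], [1])), shape=(2, 2))
+        # the transposed matrix holds the same edge read as 1 -> 0
+        return demo[0, 1] == 1 and demo.T[1, 0] == 1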
+ adj_list_aug = adj_list + if add_traspose_rels: + adj_list_t = [adj.T for adj in adj_list] + adj_list_aug = adj_list + adj_list_t + + dgl_adj_list = ssp_multigraph_to_dgl(adj_list_aug) + + return adj_list, dgl_adj_list, triplets, entity2id, relation2id, id2entity, id2relation + + +def intialize_worker(model, adj_list, dgl_adj_list, id2entity, params, node_features, kge_entity2id): + global model_, adj_list_, dgl_adj_list_, id2entity_, params_, node_features_, kge_entity2id_ + model_, adj_list_, dgl_adj_list_, id2entity_, params_, node_features_, kge_entity2id_ = model, adj_list, dgl_adj_list, id2entity, params, node_features, kge_entity2id + + +def get_neg_samples_replacing_head_tail(test_links, adj_list, num_samples=50): + + n, r = adj_list[0].shape[0], len(adj_list) + heads, tails, rels = test_links[:, 0], test_links[:, 1], test_links[:, 2] + + neg_triplets = [] + for i, (head, tail, rel) in enumerate(zip(heads, tails, rels)): + neg_triplet = {'head': [[], 0], 'tail': [[], 0]} + neg_triplet['head'][0].append([head, tail, rel]) + while len(neg_triplet['head'][0]) < num_samples: + neg_head = head + neg_tail = np.random.choice(n) + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_triplet['head'][0].append([neg_head, neg_tail, rel]) + + neg_triplet['tail'][0].append([head, tail, rel]) + while len(neg_triplet['tail'][0]) < num_samples: + neg_head = np.random.choice(n) + neg_tail = tail + # neg_head, neg_tail, rel = np.random.choice(n), np.random.choice(n), np.random.choice(r) + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_triplet['tail'][0].append([neg_head, neg_tail, rel]) + + neg_triplet['head'][0] = np.array(neg_triplet['head'][0]) + neg_triplet['tail'][0] = np.array(neg_triplet['tail'][0]) + + neg_triplets.append(neg_triplet) + + return neg_triplets + + +def get_neg_samples_replacing_head_tail_all(test_links, adj_list): + + n, r = adj_list[0].shape[0], len(adj_list) + heads, tails, rels = test_links[:, 0], test_links[:, 1], test_links[:, 2] + + neg_triplets = [] + print('sampling negative triplets...') + for i, (head, tail, rel) in tqdm(enumerate(zip(heads, tails, rels)), total=len(heads)): + neg_triplet = {'head': [[], 0], 'tail': [[], 0]} + neg_triplet['head'][0].append([head, tail, rel]) + for neg_tail in range(n): + neg_head = head + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_triplet['head'][0].append([neg_head, neg_tail, rel]) + + neg_triplet['tail'][0].append([head, tail, rel]) + for neg_head in range(n): + neg_tail = tail + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_triplet['tail'][0].append([neg_head, neg_tail, rel]) + + neg_triplet['head'][0] = np.array(neg_triplet['head'][0]) + neg_triplet['tail'][0] = np.array(neg_triplet['tail'][0]) + + neg_triplets.append(neg_triplet) + + return neg_triplets + + +def get_neg_samples_replacing_head_tail_from_ruleN(ruleN_pred_path, entity2id, saved_relation2id): + with open(ruleN_pred_path) as f: + pred_data = [line.split() for line in f.read().split('\n')[:-1]] + + neg_triplets = [] + for i in range(len(pred_data) // 3): + neg_triplet = {'head': [[], 10000], 'tail': [[], 10000]} + if pred_data[3 * i][1] in saved_relation2id: + head, rel, tail = entity2id[pred_data[3 * i][0]], saved_relation2id[pred_data[3 * i][1]], entity2id[pred_data[3 * i][2]] + for j, new_head in enumerate(pred_data[3 * i + 1][1::2]): + neg_triplet['head'][0].append([entity2id[new_head], tail, rel]) + if entity2id[new_head] == head: + 
neg_triplet['head'][1] = j + for j, new_tail in enumerate(pred_data[3 * i + 2][1::2]): + neg_triplet['tail'][0].append([head, entity2id[new_tail], rel]) + if entity2id[new_tail] == tail: + neg_triplet['tail'][1] = j + + neg_triplet['head'][0] = np.array(neg_triplet['head'][0]) + neg_triplet['tail'][0] = np.array(neg_triplet['tail'][0]) + + neg_triplets.append(neg_triplet) + + return neg_triplets + + +def incidence_matrix(adj_list): + ''' + adj_list: List of sparse adjacency matrices + ''' + + rows, cols, dats = [], [], [] + dim = adj_list[0].shape + for adj in adj_list: + adjcoo = adj.tocoo() + rows += adjcoo.row.tolist() + cols += adjcoo.col.tolist() + dats += adjcoo.data.tolist() + row = np.array(rows) + col = np.array(cols) + data = np.array(dats) + return ssp.csc_matrix((data, (row, col)), shape=dim) + + +def _bfs_relational(adj, roots, max_nodes_per_hop=None): + """ + BFS for graphs with multiple edge types. Returns list of level sets. + Each entry in list corresponds to relation specified by adj_list. + Modified from dgl.contrib.data.knowledge_graph to node accomodate sampling + """ + visited = set() + current_lvl = set(roots) + + next_lvl = set() + + while current_lvl: + + for v in current_lvl: + visited.add(v) + + next_lvl = _get_neighbors(adj, current_lvl) + next_lvl -= visited # set difference + + if max_nodes_per_hop and max_nodes_per_hop < len(next_lvl): + next_lvl = set(random.sample(next_lvl, max_nodes_per_hop)) + + yield next_lvl + + current_lvl = set.union(next_lvl) + + +def _get_neighbors(adj, nodes): + """Takes a set of nodes and a graph adjacency matrix and returns a set of neighbors. + Directly copied from dgl.contrib.data.knowledge_graph""" + sp_nodes = _sp_row_vec_from_idx_list(list(nodes), adj.shape[1]) + sp_neighbors = sp_nodes.dot(adj) + neighbors = set(ssp.find(sp_neighbors)[1]) # convert to set of indices + return neighbors + + +def _sp_row_vec_from_idx_list(idx_list, dim): + """Create sparse vector of dimensionality dim from a list of indices.""" + shape = (1, dim) + data = np.ones(len(idx_list)) + row_ind = np.zeros(len(idx_list)) + col_ind = list(idx_list) + return ssp.csr_matrix((data, (row_ind, col_ind)), shape=shape) + + +def get_neighbor_nodes(roots, adj, h=1, max_nodes_per_hop=None): + bfs_generator = _bfs_relational(adj, roots, max_nodes_per_hop) + lvls = list() + for _ in range(h): + try: + lvls.append(next(bfs_generator)) + except StopIteration: + pass + return set().union(*lvls) + + +def subgraph_extraction_labeling(ind, rel, A_list, h=1, enclosing_sub_graph=False, max_nodes_per_hop=None, node_information=None, max_node_label_value=None): + # extract the h-hop enclosing subgraphs around link 'ind' + A_incidence = incidence_matrix(A_list) + A_incidence += A_incidence.T + + # could pack these two into a function + root1_nei = get_neighbor_nodes(set([ind[0]]), A_incidence, h, max_nodes_per_hop) + root2_nei = get_neighbor_nodes(set([ind[1]]), A_incidence, h, max_nodes_per_hop) + + subgraph_nei_nodes_int = root1_nei.intersection(root2_nei) + subgraph_nei_nodes_un = root1_nei.union(root2_nei) + + # Extract subgraph | Roots being in the front is essential for labelling and the model to work properly. 
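+    # Editorial sketch, not part of the original patch: the enclosing variant
+    # keeps only nodes reachable from *both* roots within h hops (the
+    # intersection computed above), while the non-enclosing variant keeps the
+    # union; in both cases the two roots go first so node_label_new can
+    # identify them. The helper below is hypothetical and unused.
+    def _select_subgraph_nodes(roots, nei_a, nei_b, enclosing=True):
+        shared = nei_a & nei_b if enclosing else nei_a | nei_b
+        # roots first, then the selected h-hop neighbours
+        return list(roots) + list(shared)
+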
+ if enclosing_sub_graph: + subgraph_nodes = list(ind) + list(subgraph_nei_nodes_int) + else: + subgraph_nodes = list(ind) + list(subgraph_nei_nodes_un) + + subgraph = [adj[subgraph_nodes, :][:, subgraph_nodes] for adj in A_list] + + labels, enclosing_subgraph_nodes = node_label_new(incidence_matrix(subgraph), max_distance=h) + + pruned_subgraph_nodes = np.array(subgraph_nodes)[enclosing_subgraph_nodes].tolist() + pruned_labels = labels[enclosing_subgraph_nodes] + + if max_node_label_value is not None: + pruned_labels = np.array([np.minimum(label, max_node_label_value).tolist() for label in pruned_labels]) + + return pruned_subgraph_nodes, pruned_labels + + +def remove_nodes(A_incidence, nodes): + idxs_wo_nodes = list(set(range(A_incidence.shape[1])) - set(nodes)) + return A_incidence[idxs_wo_nodes, :][:, idxs_wo_nodes] + + +def node_label_new(subgraph, max_distance=1): + # an implementation of the proposed double-radius node labeling (DRNd L) + roots = [0, 1] + sgs_single_root = [remove_nodes(subgraph, [root]) for root in roots] + dist_to_roots = [np.clip(ssp.csgraph.dijkstra(sg, indices=[0], directed=False, unweighted=True, limit=1e6)[:, 1:], 0, 1e7) for r, sg in enumerate(sgs_single_root)] + dist_to_roots = np.array(list(zip(dist_to_roots[0][0], dist_to_roots[1][0])), dtype=int) + + # dist_to_roots[np.abs(dist_to_roots) > 1e6] = 0 + # dist_to_roots = dist_to_roots + 1 + target_node_labels = np.array([[0, 1], [1, 0]]) + labels = np.concatenate((target_node_labels, dist_to_roots)) if dist_to_roots.size else target_node_labels + + enclosing_subgraph_nodes = np.where(np.max(labels, axis=1) <= max_distance)[0] + # print(len(enclosing_subgraph_nodes)) + return labels, enclosing_subgraph_nodes + + + + + +def prepare_features(subgraph, n_labels, max_n_label, n_feats=None): + # One hot encode the node label feature and concat to n_featsure + n_nodes = subgraph.number_of_nodes() + label_feats = np.zeros((n_nodes, max_n_label[0] + 1 + max_n_label[1] + 1)) + label_feats[np.arange(n_nodes), n_labels[:, 0]] = 1 + label_feats[np.arange(n_nodes), max_n_label[0] + 1 + n_labels[:, 1]] = 1 + n_feats = np.concatenate((label_feats, n_feats), axis=1) if n_feats is not None else label_feats + subgraph.ndata['feat'] = torch.FloatTensor(n_feats) + + head_id = np.argwhere([label[0] == 0 and label[1] == 1 for label in n_labels]) + tail_id = np.argwhere([label[0] == 1 and label[1] == 0 for label in n_labels]) + n_ids = np.zeros(n_nodes) + n_ids[head_id] = 1 # head + n_ids[tail_id] = 2 # tail + subgraph.ndata['id'] = torch.FloatTensor(n_ids) + + return subgraph + + +def get_subgraphs(all_links, adj_list, dgl_adj_list, max_node_label_value, id2entity, node_features=None, kge_entity2id=None): + # dgl_adj_list = ssp_multigraph_to_dgl(adj_list) + + subgraphs = [] + r_labels = [] + + for link in all_links: + head, tail, rel = link[0], link[1], link[2] + nodes, node_labels = subgraph_extraction_labeling((head, tail), rel, adj_list, h=params_.hop, enclosing_sub_graph=params.enclosing_sub_graph, max_node_label_value=max_node_label_value) + + subgraph = dgl_adj_list.subgraph(nodes) + subgraph.edata['type'] = dgl_adj_list.edata['type'][dgl_adj_list.subgraph(nodes).edata[dgl.EID]] + subgraph.edata['label'] = torch.tensor(rel * np.ones(subgraph.edata['type'].shape), dtype=torch.long) + + # edges_btw_roots = subgraph.edge_id(0, 1) + try: + edges_btw_roots = subgraph.edge_ids(0, 1) + edges_btw_roots = torch.tensor([edges_btw_roots]) + except: + edges_btw_roots = torch.tensor([]) + edges_btw_roots = edges_btw_roots.numpy() + 
rel_link = np.nonzero(subgraph.edata['type'][edges_btw_roots] == rel) + + if rel_link.squeeze().nelement() == 0: + # subgraph.add_edge(0, 1, {'type': torch.tensor([rel]), 'label': torch.tensor([rel])}) + subgraph = dgl.add_edges(subgraph, 0, 1) + subgraph.edata['type'][-1] = torch.tensor(rel).type(torch.LongTensor) + subgraph.edata['label'][-1] = torch.tensor(rel).type(torch.LongTensor) + + kge_nodes = [kge_entity2id[id2entity[n]] for n in nodes] if kge_entity2id else None + n_feats = node_features[kge_nodes] if node_features is not None else None + subgraph = prepare_features(subgraph, node_labels, max_node_label_value, n_feats) + + subgraphs.append(subgraph) + r_labels.append(rel) + + # batched_graph = dgl.batch(subgraphs) + r_labels = torch.LongTensor(r_labels) + + return (subgraphs, r_labels) + + +def get_rank(neg_links): + head_neg_links = neg_links['head'][0] + head_target_id = neg_links['head'][1] + + if head_target_id != 10000: + data = get_subgraphs(head_neg_links, adj_list_, dgl_adj_list_, model_.max_label_value, id2entity_, node_features_, kge_entity2id_) + head_scores = model_(data[0]).squeeze(1).detach().numpy() + head_rank = np.argwhere(np.argsort(head_scores)[::-1] == head_target_id) + 1 + else: + head_scores = np.array([]) + head_rank = 10000 + + tail_neg_links = neg_links['tail'][0] + tail_target_id = neg_links['tail'][1] + + if tail_target_id != 10000: + data = get_subgraphs(tail_neg_links, adj_list_, dgl_adj_list_, params.max_label_value, id2entity_, node_features_, kge_entity2id_) + tail_scores = model_(data[0]).squeeze(1).detach().numpy() + tail_rank = np.argwhere(np.argsort(tail_scores)[::-1] == tail_target_id) + 1 + else: + tail_scores = np.array([]) + tail_rank = 10000 + + return head_scores, head_rank, tail_scores, tail_rank + + +def save_to_file(neg_triplets, id2entity, id2relation): + + with open(os.path.join('./data', params.dataset, 'ranking_head.txt'), "w") as f: + for neg_triplet in neg_triplets: + for s, o, r in neg_triplet['head'][0]: + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o]]) + '\n') + + with open(os.path.join('./data', params.dataset, 'ranking_tail.txt'), "w") as f: + for neg_triplet in neg_triplets: + for s, o, r in neg_triplet['tail'][0]: + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o]]) + '\n') + + +def save_score_to_file(neg_triplets, all_head_scores, all_tail_scores, id2entity, id2relation): + + with open(os.path.join('./data', params.dataset, 'grail_ranking_head_predictions.txt'), "w") as f: + for i, neg_triplet in enumerate(neg_triplets): + for [s, o, r], head_score in zip(neg_triplet['head'][0], all_head_scores[50 * i:50 * (i + 1)]): + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o], str(head_score)]) + '\n') + + with open(os.path.join('./data', params.dataset, 'grail_ranking_tail_predictions.txt'), "w") as f: + for i, neg_triplet in enumerate(neg_triplets): + for [s, o, r], tail_score in zip(neg_triplet['tail'][0], all_tail_scores[50 * i:50 * (i + 1)]): + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o], str(tail_score)]) + '\n') + + +def save_score_to_file_from_ruleN(neg_triplets, all_head_scores, all_tail_scores, id2entity, id2relation): + + with open(os.path.join('./data', params.dataset, 'grail_ruleN_ranking_head_predictions.txt'), "w") as f: + for i, neg_triplet in enumerate(neg_triplets): + for [s, o, r], head_score in zip(neg_triplet['head'][0], all_head_scores[50 * i:50 * (i + 1)]): + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o], str(head_score)]) + '\n') + 
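+    # Editorial note, not part of the original patch: the fixed-width slices
+    # all_head_scores[50 * i:50 * (i + 1)] above and below appear to assume
+    # exactly 50 scored candidates per test triplet, matching the num_samples
+    # default used when the negatives were sampled. A hypothetical helper that
+    # makes the chunking explicit instead of hard-coding the width:
+    def _chunk_scores(scores, per_triplet=50):
+        # yield one list of candidate scores per evaluated triplet
+        for start in range(0, len(scores), per_triplet):
+            yield scores[start:start + per_triplet]
+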
+ with open(os.path.join('./data', params.dataset, 'grail_ruleN_ranking_tail_predictions.txt'), "w") as f: + for i, neg_triplet in enumerate(neg_triplets): + for [s, o, r], tail_score in zip(neg_triplet['tail'][0], all_tail_scores[50 * i:50 * (i + 1)]): + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o], str(tail_score)]) + '\n') + +import json +def get_kge_embeddings(dataset, kge_model): + + path = './experiments/kge_baselines/{}_{}'.format(kge_model, dataset) + node_features = np.load(os.path.join(path, 'entity_embedding.npy')) + with open(os.path.join(path, 'id2entity.json')) as json_file: + kge_id2entity = json.load(json_file) + kge_entity2id = {v: int(k) for k, v in kge_id2entity.items()} + + return node_features, kge_entity2id + + + +def eval_rank(logger): + # print(params.file_paths) + model = torch.load(params.model_path, map_location='cpu') + + adj_list, dgl_adj_list, triplets, entity2id, relation2id, id2entity, id2relation = process_files(params.file_paths, model.relation2id, params.add_traspose_rels) + + node_features, kge_entity2id = None, None + + if params.mode == 'sample': + neg_triplets = get_neg_samples_replacing_head_tail(triplets['links'], adj_list) + elif params.mode == 'all': + neg_triplets = get_neg_samples_replacing_head_tail_all(triplets['links'], adj_list) + elif params.mode == 'ruleN': + neg_triplets = get_neg_samples_replacing_head_tail_from_ruleN(params.ruleN_pred_path, entity2id, relation2id) + print(len(neg_triplets)) + ranks = [] + all_head_scores = [] + all_tail_scores = [] + intialize_worker(model, adj_list, dgl_adj_list, id2entity, params, node_features, kge_entity2id) + # with mp.Pool(processes=None, initializer=intialize_worker, initargs=(model, adj_list, dgl_adj_list, id2entity, params, node_features, kge_entity2id)) as p: + # intialize_worker(model, adj_list, dgl_adj_list, id2entity, params, node_features, kge_entity2id) + for link in neg_triplets: + head_scores, head_rank, tail_scores, tail_rank = get_rank(link) + ranks.append(head_rank) + ranks.append(tail_rank) + + all_head_scores += head_scores.tolist() + all_tail_scores += tail_scores.tolist() + + + + + isHit1List = [x for x in ranks if x <= 1] + isHit5List = [x for x in ranks if x <= 5] + isHit10List = [x for x in ranks if x <= 10] + hits_1 = len(isHit1List) / len(ranks) + hits_5 = len(isHit5List) / len(ranks) + hits_10 = len(isHit10List) / len(ranks) + + mrr = np.mean(1 / np.array(ranks)) + + logger.info(f'MRR | Hits@1 | Hits@5 | Hits@10 : {mrr} | {hits_1} | {hits_5} | {hits_10}') \ No newline at end of file diff --git a/openhgnn/trainerflow/DisenKGAT_trainer.py b/openhgnn/trainerflow/DisenKGAT_trainer.py new file mode 100644 index 00000000..df3301d3 --- /dev/null +++ b/openhgnn/trainerflow/DisenKGAT_trainer.py @@ -0,0 +1,743 @@ +from ..models import build_model +from . 
import BaseFlow, register_flow +from ..tasks import build_task +import os +import sys +import torch +from .demo import * +from ..models import DisenKGAT +sys.path.append(os.path.dirname(os.path.dirname(__file__))) +import numpy as np, sys, os, json, time +from pprint import pprint +import logging, logging.config +from collections import defaultdict as ddict +from ordered_set import OrderedSet +import traceback +from torch.utils.data import DataLoader +np.set_printoptions(precision=4) +from torch.utils.data import Dataset + + + + + + +def get_combined_results(left_results, right_results): + results = {} + count = float(left_results['count']) + + results['left_mr'] = round(left_results['mr'] / count, 5) + results['left_mrr'] = round(left_results['mrr'] / count, 5) + results['right_mr'] = round(right_results['mr'] / count, 5) + results['right_mrr'] = round(right_results['mrr'] / count, 5) + results['mr'] = round((left_results['mr'] + right_results['mr']) / (2 * count), 5) + results['mrr'] = round((left_results['mrr'] + right_results['mrr']) / (2 * count), 5) + + for k in range(10): + results['left_hits@{}'.format(k + 1)] = round(left_results['hits@{}'.format(k + 1)] / count, 5) + results['right_hits@{}'.format(k + 1)] = round(right_results['hits@{}'.format(k + 1)] / count, 5) + results['hits@{}'.format(k + 1)] = round( + (left_results['hits@{}'.format(k + 1)] + right_results['hits@{}'.format(k + 1)]) / (2 * count), 5) + return results + + + + +@register_flow("DisenKGAT_trainer") +class Runner(BaseFlow): + + def load_data(self): + ent_set, rel_set = OrderedSet(), OrderedSet() + + + for split in ['train', 'test', 'valid']: + path = os.path.join(self.raw_dir, split + ".txt") + for line in open(path): + sub, rel, obj = map(str.lower, line.strip().split('\t')) + ent_set.add(sub) + rel_set.add(rel) + ent_set.add(obj) + + self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)} + self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)} + self.rel2id.update({rel + '_reverse': idx + len(self.rel2id) for idx, rel in enumerate(rel_set)}) + + self.id2ent = {idx: ent for ent, idx in self.ent2id.items()} + self.id2rel = {idx: rel for rel, idx in self.rel2id.items()} + + self.p.num_ent = len(self.ent2id) + self.p.num_rel = len(self.rel2id) // 2 + self.p.embed_dim = self.p.k_w * self.p.k_h if self.p.embed_dim is None else self.p.embed_dim + # self.logger.info('num_ent {} num_rel {}'.format(self.p.num_ent, self.p.num_rel)) + self.data = ddict(list) + sr2o = ddict(set) + + for split in ['train', 'test', 'valid']: + + path = os.path.join(self.raw_dir, split + ".txt") + for line in open(path): + sub, rel, obj = map(str.lower, line.strip().split('\t')) + sub, rel, obj = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj] + self.data[split].append((sub, rel, obj)) + + if split == 'train': + sr2o[(sub, rel)].add(obj) + sr2o[(obj, rel + self.p.num_rel)].add(sub) + # self.data: all origin train + valid + test triplets + self.data = dict(self.data) + # self.sr2o: train origin edges and reverse edges + self.sr2o = {k: list(v) for k, v in sr2o.items()} + for split in ['test', 'valid']: + for sub, rel, obj in self.data[split]: + sr2o[(sub, rel)].add(obj) + sr2o[(obj, rel + self.p.num_rel)].add(sub) + + self.sr2o_all = {k: list(v) for k, v in sr2o.items()} + self.triples = ddict(list) + + # for (sub, rel), obj in self.sr2o.items(): + # self.triples['train'].append({'triple': (sub, rel, -1), 'label': self.sr2o[(sub, rel)], 'sub_samp': 1}) + if self.p.strategy == 'one_to_n': + for (sub, rel), obj in 
self.sr2o.items(): + self.triples['train'].append({'triple': (sub, rel, -1), 'label': self.sr2o[(sub, rel)], 'sub_samp': 1}) + else: + for sub, rel, obj in self.data['train']: + rel_inv = rel + self.p.num_rel + sub_samp = len(self.sr2o[(sub, rel)]) + len(self.sr2o[(obj, rel_inv)]) + sub_samp = np.sqrt(1 / sub_samp) + + self.triples['train'].append( + {'triple': (sub, rel, obj), 'label': self.sr2o[(sub, rel)], 'sub_samp': sub_samp}) + self.triples['train'].append( + {'triple': (obj, rel_inv, sub), 'label': self.sr2o[(obj, rel_inv)], 'sub_samp': sub_samp}) + + for split in ['test', 'valid']: + for sub, rel, obj in self.data[split]: + rel_inv = rel + self.p.num_rel + self.triples['{}_{}'.format(split, 'tail')].append( + {'triple': (sub, rel, obj), 'label': self.sr2o_all[(sub, rel)]}) + self.triples['{}_{}'.format(split, 'head')].append( + {'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]}) + + # self.logger.info('{}_{} num is {}'.format(split, 'tail', len(self.triples['{}_{}'.format(split, 'tail')]))) + # self.logger.info('{}_{} num is {}'.format(split, 'head', len(self.triples['{}_{}'.format(split, 'head')]))) + + self.triples = dict(self.triples) + + def get_data_loader(dataset_class, split, batch_size, shuffle=True): + return DataLoader( + dataset_class(self.triples[split], self.p), + batch_size=batch_size, + shuffle=shuffle, + num_workers=max(0, self.p.num_workers), + collate_fn=dataset_class.collate_fn + ) + + self.data_iter = { + 'train': get_data_loader(TrainDataset, 'train', self.p.batch), + 'valid_head': get_data_loader(TestDataset, 'valid_head', self.p.test_batch), + 'valid_tail': get_data_loader(TestDataset, 'valid_tail', self.p.test_batch), + 'test_head': get_data_loader(TestDataset, 'test_head', self.p.test_batch), + 'test_tail': get_data_loader(TestDataset, 'test_tail', self.p.test_batch), + } + # self.logger.info('num_ent {} num_rel {}\n'.format(self.p.num_ent, self.p.num_rel)) + # self.logger.info('train set num is {}\n'.format(len(self.triples['train']))) + # self.logger.info('{}_{} num is {}\n'.format('test', 'tail', len(self.triples['{}_{}'.format('test', 'tail')]))) + # self.logger.info('{}_{} num is {}\n'.format('valid', 'tail', len(self.triples['{}_{}'.format('valid', 'tail')]))) + self.edge_index, self.edge_type = self.construct_adj() + + def construct_adj(self): + """ + Constructor of the runner class + + Parameters + ---------- + + Returns + ------- + Constructs the adjacency matrix for GCN + + """ + edge_index, edge_type = [], [] + + for sub, rel, obj in self.data['train']: + edge_index.append((sub, obj)) + edge_type.append(rel) + + # Adding inverse edges + for sub, rel, obj in self.data['train']: + edge_index.append((obj, sub)) + edge_type.append(rel + self.p.num_rel) + # edge_index: 2 * 2E, edge_type: 2E * 1 + edge_index = torch.LongTensor(edge_index).to(self.device).t() + edge_type = torch.LongTensor(edge_type).to(self.device) + + return edge_index, edge_type + + def __init__(self, args): # args == self.config + + args.task = args.model +"_" +args.task + self.args = args + self.model_name = args.model + self.device = args.device + self.hg = None + self.task = build_task(self.args) + self.raw_dir = self.task.dataset.raw_dir + self.process_dir = self.task.dataset.raw_dir + self.p = args + self.logger = logging.getLogger(__file__) + pprint(vars(self.p)) + + if self.p.gpu != '-1' and torch.cuda.is_available(): + self.device = torch.device('cuda') + torch.cuda.set_rng_state(torch.cuda.get_rng_state()) + torch.backends.cudnn.deterministic = True + 
else: + self.device = torch.device('cpu') + + self.load_data() + self.model = self.add_model(self.p.model, self.p.score_func) # disenkgat , interacte + self.optimizer, self.optimizer_mi = self.add_optimizer(self.model) + if not args.restore: + args.name = args.name + '_' + time.strftime('%d_%m_%Y') + '_' + time.strftime('%H:%M:%S') + set_gpu(args.gpu) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + + def add_model(self, model, score_func): + """ + """ + + model_name = '{}_{}'.format(model, score_func) + + if model_name.lower() == 'disenkgat_transe': + model = DisenKGAT.DisenKGAT_TransE(edge_index=self.edge_index, + edge_type=self.edge_type, + params=self.p) + + elif model_name.lower() == 'disenkgat_distmult': + model = DisenKGAT.DisenKGAT_DistMult(self.edge_index, self.edge_type, params=self.p) + elif model_name.lower() == 'disenkgat_conve': + model = DisenKGAT.DisenKGAT_ConvE(self.edge_index, self.edge_type, params=self.p) + elif model_name.lower() == 'disenkgat_interacte': + model = DisenKGAT.DisenKGAT_InteractE(self.edge_index, self.edge_type, params=self.p) + else: + raise NotImplementedError + + model.to(self.device) + return model + + def add_optimizer(self, model): + """ + Creates an optimizer for training the parameters + + Parameters + ---------- + parameters: The parameters of the model + + Returns + ------- + Returns an optimizer for learning the parameters of the model + + """ + if self.p.mi_train and self.p.mi_method.startswith('club'): + mi_disc_params = list(map(id,model.mi_Discs.parameters())) + rest_params = filter(lambda x:id(x) not in mi_disc_params, model.parameters()) + for m in model.mi_Discs.modules(): + self.logger.info(m) + for name, parameters in model.named_parameters(): + print(name,':',parameters.size()) + return torch.optim.Adam(rest_params, lr=self.p.lr, weight_decay=self.p.l2), torch.optim.Adam(model.mi_Discs.parameters(), lr=self.p.lr, weight_decay=self.p.l2) + else: + return torch.optim.Adam(model.parameters(), lr=self.p.lr, weight_decay=self.p.l2), None + + def read_batch(self, batch, split): + """ + Function to read a batch of data and move the tensors in batch to CPU/GPU + + Parameters + ---------- + batch: the batch to process + split: (string) If split == 'train', 'valid' or 'test' split + + + Returns + ------- + Head, Relation, Tails, labels + """ + # if split == 'train': + # triple, label = [_.to(self.device) for _ in batch] + # return triple[:, 0], triple[:, 1], triple[:, 2], label + # else: + # triple, label = [_.to(self.device) for _ in batch] + # return triple[:, 0], triple[:, 1], triple[:, 2], label + if split == 'train': + if self.p.strategy == 'one_to_x': + triple, label, neg_ent, sub_samp = [_.to(self.device) for _ in batch] + return triple[:, 0], triple[:, 1], triple[:, 2], label, neg_ent, sub_samp + else: + triple, label = [_.to(self.device) for _ in batch] + return triple[:, 0], triple[:, 1], triple[:, 2], label, None, None + else: + triple, label = [_.to(self.device) for _ in batch] + return triple[:, 0], triple[:, 1], triple[:, 2], label + + def save_model(self, save_path): + """ + Function to save a model. It saves the model parameters, best validation scores, + best epoch corresponding to best validation, state of the optimizer and all arguments for the run. 
+ + Parameters + ---------- + save_path: path where the model is saved + + Returns + ------- + """ + state = { + 'state_dict': self.model.state_dict(), + 'best_val': self.best_val, + 'best_epoch': self.best_epoch, + 'optimizer': self.optimizer.state_dict(), + 'args': vars(self.p) + } + torch.save(state, save_path) + + def load_model(self, load_path): + """ + Function to load a saved model + + Parameters + ---------- + load_path: path to the saved model + + Returns + ------- + """ + state = torch.load(load_path) + state_dict = state['state_dict'] + self.best_val = state['best_val'] + self.best_val_mrr = self.best_val['mrr'] + + self.model.load_state_dict(state_dict) + self.optimizer.load_state_dict(state['optimizer']) + + def evaluate(self, split, epoch): + """ + Function to evaluate the model on validation or test set + + Parameters + ---------- + split: (string) If split == 'valid' then evaluate on the validation set, else the test set + epoch: (int) Current epoch count + + Returns + ------- + resutls: The evaluation results containing the following: + results['mr']: Average of ranks_left and ranks_right + results['mrr']: Mean Reciprocal Rank + results['hits@k']: Probability of getting the correct preodiction in top-k ranks based on predicted score + + """ + left_results = self.predict(split=split, mode='tail_batch') + right_results = self.predict(split=split, mode='head_batch') + results = get_combined_results(left_results, right_results) + res_mrr = '\n\tMRR: Tail : {:.5}, Head : {:.5}, Avg : {:.5}\n'.format(results['left_mrr'], + results['right_mrr'], + results['mrr']) + res_mr = '\tMR: Tail : {:.5}, Head : {:.5}, Avg : {:.5}\n'.format(results['left_mr'], + results['right_mr'], + results['mr']) + res_hit1 = '\tHit-1: Tail : {:.5}, Head : {:.5}, Avg : {:.5}\n'.format(results['left_hits@1'], + results['right_hits@1'], + results['hits@1']) + res_hit3 = '\tHit-3: Tail : {:.5}, Head : {:.5}, Avg : {:.5}\n'.format(results['left_hits@3'], + results['right_hits@3'], + results['hits@3']) + res_hit10 = '\tHit-10: Tail : {:.5}, Head : {:.5}, Avg : {:.5}'.format(results['left_hits@10'], + results['right_hits@10'], + results['hits@10']) + log_res = res_mrr + res_mr + res_hit1 + res_hit3 + res_hit10 + if (epoch + 1) % 10 == 0 or split == 'test': + self.logger.info( + '[Evaluating Epoch {} {}]: {}'.format(epoch, split, log_res)) + else: + self.logger.info( + '[Evaluating Epoch {} {}]: {}'.format(epoch, split, res_mrr)) + + return results + + def predict(self, split='valid', mode='tail_batch'): + """ + Function to run model evaluation for a given mode + + Parameters + ---------- + split: (string) If split == 'valid' then evaluate on the validation set, else the test set + mode: (string): Can be 'head_batch' or 'tail_batch' + + Returns + ------- + resutls: The evaluation results containing the following: + results['mr']: Average of ranks_left and ranks_right + results['mrr']: Mean Reciprocal Rank + results['hits@k']: Probability of getting the correct preodiction in top-k ranks based on predicted score + + """ + self.model.eval() + + with torch.no_grad(): + results = {} + train_iter = iter(self.data_iter['{}_{}'.format(split, mode.split('_')[0])]) + + for step, batch in enumerate(train_iter): + sub, rel, obj, label = self.read_batch(batch, split) + pred, _ = self.model.forward(sub, rel, None, split) + b_range = torch.arange(pred.size()[0], device=self.device) + target_pred = pred[b_range, obj] + # filter setting + pred = torch.where(label.byte(), -torch.ones_like(pred) * 10000000, pred) + pred[b_range, 
obj] = target_pred + ranks = 1 + torch.argsort(torch.argsort(pred, dim=1, descending=True), dim=1, descending=False)[ + b_range, obj] + + ranks = ranks.float() + results['count'] = torch.numel(ranks) + results.get('count', 0.0) + results['mr'] = torch.sum(ranks).item() + results.get('mr', 0.0) + results['mrr'] = torch.sum(1.0 / ranks).item() + results.get('mrr', 0.0) + for k in range(10): + results['hits@{}'.format(k + 1)] = torch.numel(ranks[ranks <= (k + 1)]) + results.get( + 'hits@{}'.format(k + 1), 0.0) + + # if step % 100 == 0: + # self.logger.info('[{}, {} Step {}]\t{}'.format(split.title(), mode.title(), step, self.p.name)) + + return results + + def run_epoch(self, epoch, val_mrr=0): + """ + Function to run one epoch of training + + Parameters + ---------- + epoch: current epoch count + + Returns + ------- + loss: The loss value after the completion of one epoch + """ + self.model.train() + losses = [] + losses_train = [] + corr_losses = [] + lld_losses = [] + train_iter = iter(self.data_iter['train']) + + for step, batch in enumerate(train_iter): + self.optimizer.zero_grad() + if self.p.mi_train and self.p.mi_method.startswith('club'): + self.model.mi_Discs.eval() + sub, rel, obj, label, neg_ent, sub_samp = self.read_batch(batch, 'train') + + pred, corr = self.model.forward(sub, rel, neg_ent, 'train') + + loss = self.model.loss(pred, label) + if self.p.mi_train: + losses_train.append(loss.item()) + loss = loss + self.p.alpha * corr + corr_losses.append(corr.item()) + + loss.backward() + self.optimizer.step() + losses.append(loss.item()) + # start to compute mi_loss + if self.p.mi_train and self.p.mi_method.startswith('club'): + for i in range(self.p.mi_epoch): + self.model.mi_Discs.train() + lld_loss = self.model.lld_best(sub, rel) + self.optimizer_mi.zero_grad() + lld_loss.backward() + self.optimizer_mi.step() + lld_losses.append(lld_loss.item()) + + if step % 100 == 0: + if self.p.mi_train: + self.logger.info('[E:{}| {}]: total Loss:{:.5}, Train Loss:{:.5}, Corr Loss:{:.5}, Val MRR:{:.5}\t{}'.format(epoch, step, np.mean(losses), + np.mean(losses_train), np.mean(corr_losses), self.best_val_mrr, + self.p.name)) + else: + self.logger.info('[E:{}| {}]: Train Loss:{:.5}, Val MRR:{:.5}\t{}'.format(epoch, step, np.mean(losses), + self.best_val_mrr, + self.p.name)) + + loss = np.mean(losses_train) if self.p.mi_train else np.mean(losses) + if self.p.mi_train: + loss_corr = np.mean(corr_losses) + if self.p.mi_method.startswith('club') and self.p.mi_epoch == 1: + loss_lld = np.mean(lld_losses) + return loss, loss_corr, loss_lld + return loss, loss_corr, 0. + # self.logger.info('[Epoch:{}]: Training Loss:{:.4}\n'.format(epoch, loss)) + return loss, 0., 0. + + def train(self): + """ + Function to run training and evaluation of model + + """ + try: + self.best_val_mrr, self.best_val, self.best_epoch, val_mrr = 0., {}, 0, 0. 
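+            # Editorial sketch, not part of the original patch: the validation
+            # MRR tracked below comes from the filtered ranking computed in
+            # predict() above, where every known true tail except the target is
+            # masked out before ranking. A tiny, hypothetical demo of that
+            # masking, never called by this flow:
+            def _filtered_rank_demo():
+                # scores for 5 candidate tails; entity 2 is the gold answer,
+                # entities 0 and 4 are other known true tails to be filtered
+                pred = torch.tensor([[0.9, 0.1, 0.7, 0.3, 0.8]])
+                label = torch.tensor([[1, 0, 1, 0, 1]], dtype=torch.bool)
+                obj = torch.tensor([2])
+                b_range = torch.arange(pred.size(0))
+                target_pred = pred[b_range, obj]
+                pred = torch.where(label, torch.full_like(pred, -1e7), pred)
+                pred[b_range, obj] = target_pred
+                ranks = 1 + torch.argsort(torch.argsort(pred, dim=1, descending=True), dim=1)[b_range, obj]
+                return ranks  # tensor([1]): the gold tail ranks first once rival true tails are masked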
+ save_path = os.path.join('./save_models', self.p.name) + + if not os.path.exists('./save_models'): # 如果 节点目录不存在,则创建目录 + os.makedirs('./save_models') + + # if self.p.restore: + # self.load_model(save_path) + # self.logger.info('Successfully Loaded previous model') + + val_results = {} + val_results['mrr'] = 0 + kill_cnt = 0 + for epoch in range(self.p.epoch): + train_loss, corr_loss, lld_loss = self.run_epoch(epoch, val_mrr) + val_results = self.evaluate('valid', epoch) + + if val_results['mrr'] > self.best_val_mrr: + self.best_val = val_results + self.best_val_mrr = val_results['mrr'] + self.best_epoch = epoch + self.save_model(save_path) + kill_cnt = 0 + else: + kill_cnt += 1 + if kill_cnt % 10 == 0 and self.p.gamma > self.p.max_gamma: + self.p.gamma -= 5 + # self.logger.info('Gamma decay on saturation, updated value of gamma: {}'.format(self.p.gamma)) + if kill_cnt > self.p.early_stop: + self.logger.info("Early Stopping!!") + break + if self.p.mi_train: + if self.p.mi_method == 'club_s' or self.p.mi_method == 'club_b': + self.logger.info( + '[Epoch {}]: Training Loss: {:.5}, corr Loss: {:.5}, lld loss :{:.5}, Best valid MRR: {:.5}\n\n'.format(epoch, train_loss, corr_loss, + lld_loss, self.best_val_mrr)) + else: + self.logger.info( + '[Epoch {}]: Training Loss: {:.5}, corr Loss: {:.5}, Best valid MRR: {:.5}\n\n'.format(epoch, train_loss, corr_loss, + self.best_val_mrr)) + else: + self.logger.info( + '[Epoch {}]: Training Loss: {:.5}, Best valid MRR: {:.5}\n\n'.format(epoch, train_loss, + self.best_val_mrr)) + + + self.logger.info('Loading best model, Evaluating on Test data') + self.load_model(save_path) + test_results = self.evaluate('test', self.best_epoch) + except Exception as e: + self.logger.debug("%s____%s\n" + "traceback.format_exc():____%s" % (Exception, e, traceback.format_exc())) + + + + + + +def get_logger(name, log_dir, config_dir): + """ + Creates a logger object + + Parameters + ---------- + name: Name of the logger file + log_dir: Directory where logger file needs to be stored + config_dir: Directory from where log_config.json needs to be read + + Returns + ------- + A logger object which writes to both file and stdout + + """ + config_dict = json.load(open(config_dir + 'log_config.json')) + config_dict['handlers']['file_handler']['filename'] = log_dir + name.replace('/', '-') + logging.config.dictConfig(config_dict) + logger = logging.getLogger(name) + + std_out_format = '%(asctime)s - [%(levelname)s] - %(message)s' + + consoleHandler = logging.StreamHandler(sys.stdout) + consoleHandler.setFormatter(logging.Formatter(std_out_format)) + logger.addHandler(consoleHandler) + + return logger + + + +def set_gpu(gpus): + """ + Sets the GPU to be used for the run + + Parameters + ---------- + gpus: List of GPUs to be used for the run + + Returns + ------- + + """ + + gpus = str(gpus) + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = gpus + + + +class TrainDataset(Dataset): + """ + Training Dataset class. 
+ + Parameters + ---------- + triples: The triples used for training the model + params: Parameters for the experiments + + Returns + ------- + A training Dataset class instance used by DataLoader + """ + + def __init__(self, triples, params): + self.triples = triples + self.p = params + self.entities = np.arange(self.p.num_ent, dtype=np.int32) + + def __len__(self): + return len(self.triples) + + def __getitem__(self, idx): + ele = self.triples[idx] + triple, label, sub_samp = torch.LongTensor(ele['triple']), np.int32(ele['label']), np.float32(ele['sub_samp']) + trp_label = self.get_label(label) + + if self.p.lbl_smooth != 0.0: + trp_label = (1.0 - self.p.lbl_smooth) * trp_label + (1.0 / self.p.num_ent) + + if self.p.strategy == 'one_to_n': + return triple, trp_label, None, None + + elif self.p.strategy == 'one_to_x': + sub_samp = torch.FloatTensor([sub_samp]) + neg_ent = torch.LongTensor(self.get_neg_ent(triple, label)) + return triple, trp_label, neg_ent, sub_samp + else: + raise NotImplementedError + + # return triple, trp_label, None, None + + @staticmethod + def collate_fn(data): + triple = torch.stack([_[0] for _ in data], dim=0) + trp_label = torch.stack([_[1] for _ in data], dim=0) + # triple: (batch-size) * 3(sub, rel, -1) trp_label (batch-size) * num entity + # return triple, trp_label + if not data[0][2] is None: # one_to_x + neg_ent = torch.stack([_[2] for _ in data], dim=0) + sub_samp = torch.cat([_[3] for _ in data], dim=0) + return triple, trp_label, neg_ent, sub_samp + else: + return triple, trp_label + + # def get_neg_ent(self, triple, label): + # def get(triple, label): + # pos_obj = label + # mask = np.ones([self.p.num_ent], dtype=np.bool) + # mask[label] = 0 + # neg_ent = np.int32( + # np.random.choice(self.entities[mask], self.p.neg_num - len(label), replace=False)).reshape([-1]) + # neg_ent = np.concatenate((pos_obj.reshape([-1]), neg_ent)) + # + # return neg_ent + # + # neg_ent = get(triple, label) + # return neg_ent + def get_neg_ent(self, triple, label): + def get(triple, label): + if self.p.strategy == 'one_to_x': + pos_obj = triple[2] + mask = np.ones([self.p.num_ent], dtype=np.bool) + mask[label] = 0 + neg_ent = np.int32(np.random.choice(self.entities[mask], self.p.neg_num, replace=False)).reshape([-1]) + neg_ent = np.concatenate((pos_obj.reshape([-1]), neg_ent)) + else: + pos_obj = label + mask = np.ones([self.p.num_ent], dtype=np.bool) + mask[label] = 0 + neg_ent = np.int32( + np.random.choice(self.entities[mask], self.p.neg_num - len(label), replace=False)).reshape([-1]) + neg_ent = np.concatenate((pos_obj.reshape([-1]), neg_ent)) + + if len(neg_ent) > self.p.neg_num: + import pdb; + pdb.set_trace() + + return neg_ent + + neg_ent = get(triple, label) + return neg_ent + + def get_label(self, label): + # y = np.zeros([self.p.num_ent], dtype=np.float32) + # for e2 in label: y[e2] = 1.0 + # return torch.FloatTensor(y) + if self.p.strategy == 'one_to_n': + y = np.zeros([self.p.num_ent], dtype=np.float32) + for e2 in label: y[e2] = 1.0 + elif self.p.strategy == 'one_to_x': + y = [1] + [0] * self.p.neg_num + else: + raise NotImplementedError + return torch.FloatTensor(y) + +class TestDataset(Dataset): + """ + Evaluation Dataset class. 
+ + Parameters + ---------- + triples: The triples used for evaluating the model + params: Parameters for the experiments + + Returns + ------- + An evaluation Dataset class instance used by DataLoader for model evaluation + """ + + def __init__(self, triples, params): + self.triples = triples + self.p = params + + def __len__(self): + return len(self.triples) + + def __getitem__(self, idx): + ele = self.triples[idx] + triple, label = torch.LongTensor(ele['triple']), np.int32(ele['label']) + label = self.get_label(label) + + return triple, label + + @staticmethod + def collate_fn(data): + triple = torch.stack([_[0] for _ in data], dim=0) + label = torch.stack([_[1] for _ in data], dim=0) + return triple, label + + def get_label(self, label): + y = np.zeros([self.p.num_ent], dtype=np.float32) + for e2 in label: y[e2] = 1.0 + return torch.FloatTensor(y) + diff --git a/openhgnn/trainerflow/ExpressGNN_trainer.py b/openhgnn/trainerflow/ExpressGNN_trainer.py new file mode 100644 index 00000000..af94f659 --- /dev/null +++ b/openhgnn/trainerflow/ExpressGNN_trainer.py @@ -0,0 +1,599 @@ +import torch +import torch.optim as optim +import networkx as nx +from itertools import product +from tqdm import tqdm +from itertools import chain +import numpy as np +from sklearn.metrics import roc_auc_score, average_precision_score +import math +from collections import Counter +from . import BaseFlow, register_flow +from ..models import build_model +from ..utils import EarlyStopping + + +class KnowledgeGraph(object): + def __init__(self, facts, predicates, dataset): + self.dataset = dataset + self.PRED_DICT = predicates + self.graph, self.edge_type2idx, \ + self.ent2idx, self.idx2ent, self.rel2idx, self.idx2rel, \ + self.node2idx, self.idx2node = self.gen_graph(facts, dataset) + + self.num_ents = len(self.ent2idx) + self.num_rels = len(self.rel2idx) + + self.num_nodes = len(self.graph.nodes()) + self.num_edges = len(self.graph.edges()) + + x, y, v = zip(*sorted(self.graph.edges(data=True), key=lambda t: t[:2])) + self.edge_types = [d['edge_type'] for d in v] + self.edge_pairs = np.ndarray(shape=(self.num_edges, 2), dtype=np.int64) + self.edge_pairs[:, 0] = x + self.edge_pairs[:, 1] = y + + self.idx2edge = dict() + idx = 0 + for x, y in self.edge_pairs: + self.idx2edge[idx] = (self.idx2node[x], self.idx2node[y]) + idx += 1 + self.idx2edge[idx] = (self.idx2node[y], self.idx2node[x]) + idx += 1 + + def gen_graph(self, facts, dataset): + """ + generate directed knowledge graph, where each edge is from subject to object + :param facts: + dictionary of facts + :param dataset: + dataset object + :return: + graph object, entity to index, index to entity, relation to index, index to relation + """ + + # build bipartite graph (constant nodes and hyper predicate nodes) + g = nx.Graph() + ent2idx, idx2ent, rel2idx, idx2rel, node2idx, idx2node = self.gen_index(facts, dataset) + + edge_type2idx = self.gen_edge_type() + + for node_idx in idx2node: + g.add_node(node_idx) + + for rel in facts.keys(): + for fact in facts[rel]: + val, args = fact + fact_node_idx = node2idx[(rel, args)] + for arg in args: + pos_code = ''.join(['%d' % (arg == v) for v in args]) + g.add_edge(fact_node_idx, node2idx[arg], + edge_type=edge_type2idx[(val, pos_code)]) + return g, edge_type2idx, ent2idx, idx2ent, rel2idx, idx2rel, node2idx, idx2node + + def gen_index(self, facts, dataset): + rel2idx = dict() + idx_rel = 0 + for rel in sorted(self.PRED_DICT.keys()): + if rel not in rel2idx: + rel2idx[rel] = idx_rel + idx_rel += 1 + idx2rel = 
dict(zip(rel2idx.values(), rel2idx.keys())) + + ent2idx = dict() + idx_ent = 0 + for type_name in sorted(dataset.const_sort_dict.keys()): + for const in dataset.const_sort_dict[type_name]: + ent2idx[const] = idx_ent + idx_ent += 1 + idx2ent = dict(zip(ent2idx.values(), ent2idx.keys())) + + node2idx = ent2idx.copy() + idx_node = len(node2idx) + for rel in sorted(facts.keys()): + for fact in sorted(list(facts[rel])): + val, args = fact + if (rel, args) not in node2idx: + node2idx[(rel, args)] = idx_node + idx_node += 1 + idx2node = dict(zip(node2idx.values(), node2idx.keys())) + + return ent2idx, idx2ent, rel2idx, idx2rel, node2idx, idx2node + + def gen_edge_type(self): + edge_type2idx = dict() + num_args_set = set() + for rel in self.PRED_DICT: + num_args = self.PRED_DICT[rel].num_args + num_args_set.add(num_args) + idx = 0 + for num_args in sorted(list(num_args_set)): + for pos_code in product(['0', '1'], repeat=num_args): + if '1' in pos_code: + edge_type2idx[(0, ''.join(pos_code))] = idx + idx += 1 + edge_type2idx[(1, ''.join(pos_code))] = idx + idx += 1 + return edge_type2idx + + +@register_flow("ExpressGNN_trainer") +class ExpressGNNTrainer(BaseFlow): + + def __init__(self, args): + super(ExpressGNNTrainer, self).__init__(args) + + + self.model_name = args.model + self.device = args.device + self.dataset = self.task.dataset + self.args.rule_list = self.dataset.rule_ls + self.kg = KnowledgeGraph(self.dataset.fact_dict, self.dataset.PRED_DICT, self.dataset) + self.args.PRED_DICT = self.dataset.PRED_DICT + self.model = build_model(self.model).build_model_from_args(self.args, self.kg).to(self.device) + + self.stopper = EarlyStopping(self.args.patience, self._checkpoint) + self.scheduler = None + self.pred_aggregated_hid_args = dict() + self.preprocess() + + def preprocess(self): + all_params = chain.from_iterable([self.model.parameters()]) + self.optimizer = optim.Adam(all_params, lr=self.args.learning_rate, weight_decay=self.args.l2_coef) + self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'max', factor=self.args.lr_decay_factor, + patience=self.args.lr_decay_patience, + min_lr=self.args.lr_decay_min) + if self.args.no_train == 1: + self.args.num_epochs = 0 + + # for Freebase data + if self.args.load_method == 1: + # prepare data for M-step + tqdm.write('preparing data for M-step...') + pred_arg1_set_arg2 = dict() + pred_arg2_set_arg1 = dict() + pred_fact_set = dict() + for pred in self.dataset.fact_dict_2: + pred_arg1_set_arg2[pred] = dict() + pred_arg2_set_arg1[pred] = dict() + pred_fact_set[pred] = set() + for _, args in self.dataset.fact_dict_2[pred]: + if args[0] not in pred_arg1_set_arg2[pred]: + pred_arg1_set_arg2[pred][args[0]] = set() + if args[1] not in pred_arg2_set_arg1[pred]: + pred_arg2_set_arg1[pred][args[1]] = set() + pred_arg1_set_arg2[pred][args[0]].add(args[1]) + pred_arg2_set_arg1[pred][args[1]].add(args[0]) + pred_fact_set[pred].add(args) + grounded_rules = [] + for rule_idx, rule in enumerate(self.dataset.rule_ls): + grounded_rules.append(set()) + body_atoms = [] + head_atom = None + for atom in rule.atom_ls: + if atom.neg: + body_atoms.append(atom) + elif head_atom is None: + head_atom = atom + # atom in body must be observed + assert len(body_atoms) <= 2 + if len(body_atoms) > 0: + body1 = body_atoms[0] + for _, body1_args in self.dataset.fact_dict_2[body1.pred_name]: + var2arg = dict() + var2arg[body1.var_name_ls[0]] = body1_args[0] + var2arg[body1.var_name_ls[1]] = body1_args[1] + for body2 in body_atoms[1:]: + if body2.var_name_ls[0] in 
var2arg: + if var2arg[body2.var_name_ls[0]] in pred_arg1_set_arg2[body2.pred_name]: + for body2_arg2 in pred_arg1_set_arg2[body2.pred_name][ + var2arg[body2.var_name_ls[0]]]: + var2arg[body2.var_name_ls[1]] = body2_arg2 + grounded_rules[rule_idx].add(tuple(sorted(var2arg.items()))) + elif body2.var_name_ls[1] in var2arg: + if var2arg[body2.var_name_ls[1]] in pred_arg2_set_arg1[body2.pred_name]: + for body2_arg1 in pred_arg2_set_arg1[body2.pred_name][ + var2arg[body2.var_name_ls[1]]]: + var2arg[body2.var_name_ls[0]] = body2_arg1 + grounded_rules[rule_idx].add(tuple(sorted(var2arg.items()))) + # Collect head atoms derived by grounded formulas + self.grounded_obs = dict() + self.grounded_hid = dict() + self.grounded_hid_score = dict() + for rule_idx in range(len(self.dataset.rule_ls)): + rule = self.dataset.rule_ls[rule_idx] + for var2arg in grounded_rules[rule_idx]: + var2arg = dict(var2arg) + head_atom = rule.atom_ls[-1] + assert not head_atom.neg # head atom + pred = head_atom.pred_name + args = (var2arg[head_atom.var_name_ls[0]], var2arg[head_atom.var_name_ls[1]]) + if args in pred_fact_set[pred]: + if (pred, args) not in self.grounded_obs: + self.grounded_obs[(pred, args)] = [] + self.grounded_obs[(pred, args)].append(rule_idx) + else: + if (pred, args) not in self.grounded_hid: + self.grounded_hid[(pred, args)] = [] + self.grounded_hid[(pred, args)].append(rule_idx) + + tqdm.write('observed: %d, hidden: %d' % (len(self.grounded_obs), len(self.grounded_hid))) + + # Aggregate atoms by predicates for fast inference + pred_aggregated_hid = dict() + for (pred, args) in self.grounded_hid: + if pred not in pred_aggregated_hid: + pred_aggregated_hid[pred] = [] + if pred not in self.pred_aggregated_hid_args: + self.pred_aggregated_hid_args[pred] = [] + pred_aggregated_hid[pred].append((self.dataset.const2ind[args[0]], self.dataset.const2ind[args[1]])) + self.pred_aggregated_hid_args[pred].append(args) + self.pred_aggregated_hid_list = [[pred, pred_aggregated_hid[pred]] for pred in + sorted(pred_aggregated_hid.keys())] + + def train(self): + if self.args.load_method == 1: + for current_epoch in range(self.args.num_epochs): + num_batches = int(math.ceil(len(self.dataset.test_fact_ls) / self.args.batchsize)) + pbar = tqdm(total=num_batches) + acc_loss = 0.0 + cur_batch = 0 + + for samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r in \ + self.dataset.get_batch_by_q(self.args.batchsize): + + node_embeds = self.model.gcn_forward(self.dataset) + + loss = 0.0 + r_cnt = 0 + for ind, samples in enumerate(samples_by_r): + neg_mask = neg_mask_by_r[ind] + latent_mask = latent_mask_by_r[ind] + obs_var = obs_var_by_r[ind] + neg_var = neg_var_by_r[ind] + + if sum([len(e[1]) for e in neg_mask]) == 0: + continue + + potential, posterior_prob, obs_xent = self.model.posterior_forward( + [samples, neg_mask, latent_mask, + obs_var, neg_var], + node_embeds, fast_mode=True) + if self.args.no_entropy == 1: + entropy = 0 + else: + entropy = compute_entropy(posterior_prob) / self.args.entropy_temp + + loss += - (potential.sum() * self.dataset.rule_ls[ind].weight + entropy) / ( + potential.size(0) + 1e-6) + obs_xent + + r_cnt += 1 + + if r_cnt > 0: + loss /= r_cnt + acc_loss += loss.item() + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + pbar.update() + cur_batch += 1 + pbar.set_description( + 'Epoch %d, train loss: %.4f, lr: %.4g' % ( + current_epoch, acc_loss / cur_batch, get_lr(self.optimizer))) + # M-step: optimize the weights of logic rules + with torch.no_grad(): + 
posterior_prob = self.model.posterior_forward(self.pred_aggregated_hid_list, node_embeds, + fast_inference_mode=True) + for pred_i, (pred, var_ls) in enumerate(self.pred_aggregated_hid_list): + for var_i, var in enumerate(var_ls): + args = self.pred_aggregated_hid_args[pred][var_i] + self.grounded_hid_score[(pred, args)] = posterior_prob[pred_i][var_i] + + rule_weight_gradient = torch.zeros(len(self.dataset.rule_ls), device=self.args.device) + for (pred, args) in self.grounded_obs: + for rule_idx in set(self.grounded_obs[(pred, args)]): + rule_weight_gradient[rule_idx] += 1.0 - compute_MB_proba(self.dataset.rule_ls, + self.grounded_obs[(pred, args)]) + for (pred, args) in self.grounded_hid: + for rule_idx in set(self.grounded_hid[(pred, args)]): + target = self.grounded_hid_score[(pred, args)] + rule_weight_gradient[rule_idx] += target - compute_MB_proba(self.dataset.rule_ls, + self.grounded_hid[(pred, args)]) + + for rule_idx, rule in enumerate(self.dataset.rule_ls): + rule.weight += self.args.learning_rate_rule_weights * rule_weight_gradient[rule_idx] + # print(self.dataset.rule_ls[rule_idx].weight, end=' ') + pbar.close() + # validation + with torch.no_grad(): + node_embeds = self.model.gcn_forward(self.dataset) + + valid_loss = 0.0 + cnt_batch = 0 + for samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r in \ + self.dataset.get_batch_by_q(self.args.batchsize, validation=True): + loss = 0.0 + r_cnt = 0 + for ind, samples in enumerate(samples_by_r): + neg_mask = neg_mask_by_r[ind] + latent_mask = latent_mask_by_r[ind] + obs_var = obs_var_by_r[ind] + neg_var = neg_var_by_r[ind] + + if sum([len(e[1]) for e in neg_mask]) == 0: + continue + + valid_potential, valid_prob, valid_obs_xent = self.model.posterior_forward( + [samples, neg_mask, latent_mask, + obs_var, neg_var], + node_embeds, fast_mode=True) + + if self.args.no_entropy == 1: + valid_entropy = 0 + else: + valid_entropy = compute_entropy(valid_prob) / self.args.entropy_temp + + loss += - (valid_potential.sum() + valid_entropy) / ( + valid_potential.size(0) + 1e-6) + valid_obs_xent + + r_cnt += 1 + + if r_cnt > 0: + loss /= r_cnt + valid_loss += loss.item() + + cnt_batch += 1 + + tqdm.write('Epoch %d, valid loss: %.4f' % (current_epoch, valid_loss / cnt_batch)) + + should_stop = self.stopper.loss_step(valid_loss, self.model) + self.scheduler.step(valid_loss) + + is_current_best = self.stopper.counter == 0 + if is_current_best: + self.stopper.save_model(self.model) + + should_stop = should_stop or (current_epoch + 1 == self.args.num_epochs) + + if should_stop: + tqdm.write('Early stopping') + break + + # ======================= generate rank list ======================= + print("rank_list", current_epoch) + node_embeds = self.model.gcn_forward(self.dataset) + + pbar = tqdm(total=len(self.dataset.test_fact_ls)) + pbar.write('\n' + '*' * 10 + ' Evaluation ' + '*' * 10) + rrank = 0.0 + hits = 0.0 + cnt = 0 + + rrank_pred = dict([(pred_name, 0.0) for pred_name in self.kg.PRED_DICT]) + hits_pred = dict([(pred_name, 0.0) for pred_name in self.kg.PRED_DICT]) + cnt_pred = dict([(pred_name, 0.0) for pred_name in self.kg.PRED_DICT]) + for pred_name, X, invX, sample in gen_eval_query(self.dataset, const2ind=self.kg.ent2idx): + x_mat = np.array(X) + invx_mat = np.array(invX) + sample_mat = np.array(sample) + + tail_score, head_score, true_score = self.model.posterior_forward( + [pred_name, x_mat, invx_mat, sample_mat], + node_embeds, + batch_mode=True) + + rank = torch.sum(tail_score >= true_score).item() + 1 + rrank += 1.0 
/ rank + hits += 1 if rank <= 10 else 0 + + rrank_pred[pred_name] += 1.0 / rank + hits_pred[pred_name] += 1 if rank <= 10 else 0 + + rank = torch.sum(head_score >= true_score).item() + 1 + rrank += 1.0 / rank + hits += 1 if rank <= 10 else 0 + + rrank_pred[pred_name] += 1.0 / rank + hits_pred[pred_name] += 1 if rank <= 10 else 0 + + cnt_pred[pred_name] += 2 + cnt += 2 + + pbar.update() + pbar.close() + self.logger.info('\ncomplete:\n mmr %.4f\n' % (rrank / cnt) + 'hits %.4f\n' % (hits / cnt)) + for pred_name in self.kg.PRED_DICT: + if cnt_pred[pred_name] == 0: + continue + self.logger.info('mmr %s %.4f\n' % (pred_name, rrank_pred[pred_name] / cnt_pred[pred_name])) + self.logger.info('hits %s %.4f\n' % (pred_name, hits_pred[pred_name] / cnt_pred[pred_name])) + + + # for Kinship / UW-CSE / Cora data + elif self.args.load_method == 0: + for current_epoch in range(self.args.num_epochs): + pbar = tqdm(range(self.args.num_batches)) + acc_loss = 0.0 + + for k in pbar: + node_embeds = self.model.gcn_forward(self.dataset) + + batch_neg_mask, flat_list, batch_latent_var_inds, observed_rule_cnts, batch_observed_vars = self.dataset.get_batch_rnd( + observed_prob=self.args.observed_prob, + filter_latent=self.args.filter_latent == 1, + closed_world=self.args.closed_world == 1, + filter_observed=1) + + posterior_prob = self.model.posterior_forward(flat_list, node_embeds) + + if self.args.no_entropy == 1: + entropy = 0 + else: + entropy = compute_entropy(posterior_prob) / self.args.entropy_temp + + entropy = entropy.to('cpu') + posterior_prob = posterior_prob.to('cpu') + + potential = self.model.mln_forward(batch_neg_mask, batch_latent_var_inds, observed_rule_cnts, + posterior_prob, + flat_list, batch_observed_vars) + + self.optimizer.zero_grad() + + loss = - (potential + entropy) / self.args.batchsize + acc_loss += loss.item() + + loss.backward() + + self.optimizer.step() + + pbar.set_description('train loss: %.4f, lr: %.4g' % (acc_loss / (k + 1), get_lr(self.optimizer))) + + # test + node_embeds = self.model.gcn_forward(self.dataset) + with torch.no_grad(): + + posterior_prob = self.model.posterior_forward([(e[1], e[2]) for e in self.dataset.test_fact_ls], + node_embeds) + posterior_prob = posterior_prob.to('cpu') + + label = np.array([e[0] for e in self.dataset.test_fact_ls]) + test_log_prob = float( + np.sum(np.log(np.clip(np.abs((1 - label) - posterior_prob.numpy()), 1e-6, 1 - 1e-6)))) + auc_roc = roc_auc_score(label, posterior_prob.numpy()) + auc_pr = average_precision_score(label, posterior_prob.numpy()) + + self.logger.info( + 'Epoch: %d, train loss: %.4f, test auc-roc: %.4f, test auc-pr: %.4f, test log prob: %.4f' % ( + current_epoch, acc_loss / self.args.num_batches, auc_roc, auc_pr, test_log_prob)) + # tqdm.write(str(posterior_prob[:10])) + + # validation for early stop + valid_sample = [] + valid_label = [] + for pred_name in self.dataset.valid_dict_2: + for val, consts in self.dataset.valid_dict_2[pred_name]: + valid_sample.append((pred_name, consts)) + valid_label.append(val) + valid_label = np.array(valid_label) + + valid_prob = self.model.posterior_forward(valid_sample, node_embeds) + valid_prob = valid_prob.to('cpu') + + valid_log_prob = float( + np.sum(np.log(np.clip(np.abs((1 - valid_label) - valid_prob.numpy()), 1e-6, 1 - 1e-6)))) + + # tqdm.write('epoch: %d, valid log prob: %.4f' % (current_epoch, valid_log_prob)) + # + # should_stop = monitor.update(-valid_log_prob) + # scheduler.step(valid_log_prob) + # + # is_current_best = monitor.cnt == 0 + # if is_current_best: + # savepath = 
joinpath(self.args.exp_path, 'saved_model') + # os.makedirs(savepath, exist_ok=True) + # torch.save(gcn.state_dict(), joinpath(savepath, 'gcn.model')) + # torch.save(posterior_model.state_dict(), joinpath(savepath, 'posterior.model')) + # + # should_stop = should_stop or (current_epoch + 1 == self.args.num_epochs) + # + # if should_stop: + # tqdm.write('Early stopping') + # break + self.evaluate() + + def evaluate(self): + # evaluation after training + node_embeds = self.model.gcn_forward(self.dataset) + with torch.no_grad(): + posterior_prob = self.model.posterior_forward([(e[1], e[2]) for e in self.dataset.test_fact_ls], + node_embeds) + posterior_prob = posterior_prob.to('cpu') + + label = np.array([e[0] for e in self.dataset.test_fact_ls]) + test_log_prob = float( + np.sum(np.log(np.clip(np.abs((1 - label) - posterior_prob.numpy()), 1e-6, 1 - 1e-6)))) + auc_roc = roc_auc_score(label, posterior_prob.numpy()) + auc_pr = average_precision_score(label, posterior_prob.numpy()) + + self.logger.info( + 'test auc-roc: %.4f, test auc-pr: %.4f, test log prob: %.4f' % (auc_roc, auc_pr, test_log_prob)) + pass + + +def compute_entropy(posterior_prob): + eps = 1e-6 + posterior_prob.clamp_(eps, 1 - eps) + compl_prob = 1 - posterior_prob + entropy = -(posterior_prob * torch.log(posterior_prob) + compl_prob * torch.log(compl_prob)).sum() + return entropy + + +def compute_MB_proba(rule_ls, ls_rule_idx): + rule_idx_cnt = Counter(ls_rule_idx) + numerator = 0 + for rule_idx in rule_idx_cnt: + weight = rule_ls[rule_idx].weight + cnt = rule_idx_cnt[rule_idx] + numerator += math.exp(weight * cnt) + return numerator / (numerator + 1.0) + + +def get_lr(optimizer): + return optimizer.state_dict()['param_groups'][0]['lr'] + + +def gen_eval_query(dataset, const2ind=None, pickone=None): + const_ls = dataset.const_sort_dict['type'] + + toindex = lambda x: x + if const2ind is not None: + toindex = lambda x: const2ind[x] + + for val, pred_name, consts in dataset.test_fact_ls: + c1, c2 = toindex(consts[0]), toindex(consts[1]) + + if pickone is not None: + if pred_name != pickone: + continue + + X, invX = [], [] + for const in const_ls: + + if const not in dataset.ht_dict[pred_name][0][consts[0]]: + X.append([c1, toindex(const)]) + if const not in dataset.ht_dict[pred_name][1][consts[1]]: + invX.append([toindex(const), c2]) + + yield pred_name, X, invX, [[c1, c2]] + + +class EarlyStopMonitor: + + def __init__(self, patience): + self.patience = patience + self.cnt = 0 + self.cur_best = float('inf') + + def update(self, loss): + """ + + :param loss: + :return: + return True if patience exceeded + """ + if loss < self.cur_best: + self.cnt = 0 + self.cur_best = loss + else: + self.cnt += 1 + + if self.cnt >= self.patience: + return True + else: + return False + + def reset(self): + self.cnt = 0 + self.cur_best = float('inf') diff --git a/openhgnn/trainerflow/Grail_trainer.py b/openhgnn/trainerflow/Grail_trainer.py new file mode 100644 index 00000000..8f3c63be --- /dev/null +++ b/openhgnn/trainerflow/Grail_trainer.py @@ -0,0 +1,1065 @@ +import argparse +import copy +import dgl +import numpy as np +import torch +from tqdm import tqdm +import torch.nn.functional as F +from openhgnn.models import build_model +from . 
import BaseFlow, register_flow +from ..tasks import build_task +import random +import torch.optim as optim +import torch.nn as nn +import time + +import multiprocessing as mp +import scipy.sparse as ssp +import json +import networkx as nx +from ..utils.Grail_utils import collate_dgl,move_batch_to_device_dgl, ssp_multigraph_to_dgl +params = None + +@register_flow('Grail_trainer') +class GrailTrainer(BaseFlow): + def __init__(self,args): + super(GrailTrainer, self).__init__(args) + #self.train_hg = self.task.get_train() + self.trainset = self.task.dataset.train + self.valid = self.task.dataset.valid + self.args.num_rels = self.trainset.num_rels + self.args.aug_num_rels = self.trainset.aug_num_rels + self.args.inp_dim = self.trainset.n_feat_dim + self.args.collate_fn = collate_dgl + self.args.move_batch_to_device = move_batch_to_device_dgl + self.params = self.args + self.params.adj_list=[] + self.params.dgl_adj_list=[] + self.params.triplets=[] + self.params.entity2id=[] + self.params.relation2id=[] + self.params.id2entity=[] + self.params.id2relation = [] + # Log the max label value to save it in the model. This will be used to cap the labels generated on test set. + self.args.max_label_value = self.trainset.max_n_label + + self.model = build_model(self.model).build_model_from_args(self.args, self.task.dataset.relation2id).to( + self.device) + self.updates_counter = 0 + model_params = list(self.model.parameters()) + #logging.info('Total number of parameters: %d' % sum(map(lambda x: x.numel(), model_params))) + + if self.args.optimizer == "SGD": + self.optimizer = optim.SGD(model_params, lr=self.args.lr, momentum=self.args.momentum, + weight_decay=self.args.l2) + if self.args.optimizer == "Adam": + self.optimizer = optim.Adam(model_params, lr=self.args.lr, weight_decay=self.args.l2) + + self.criterion = nn.MarginRankingLoss(self.args.margin, reduction='sum') + + self.reset_training_state() + #graph_classifier = initialize_model(params, dgl_model, params.load_model) + + self.logger.info(f"Device: {args.device}") + self.logger.info( + f"Input dim : {args.inp_dim}, # Relations : {args.num_rels}, # Augmented relations : {args.aug_num_rels}") + + self.args.save_path = os.path.dirname(os.path.abspath('__file__')) + '/openhgnn/output/' + self.model_name + self.valid_evaluator = Evaluator(self.args, self.model, self.valid) + + #self.save_path = os.path.dirname(os.path.abspath('__file__')) + '/openhgnn/output/' + self.model_name + + #trainer = Trainer(params, graph_classifier, train, valid_evaluator) + + self.logger.info('Starting training with full batch...') + + + def reset_training_state(self): + self.best_metric = 0 + self.last_metric = 0 + self.not_improved_count = 0 + + def train_epoch(self): + total_loss = 0 + all_preds = [] + all_labels = [] + all_scores = [] + + dataloader = DataLoader(self.trainset, batch_size=self.args.batch_size, shuffle=True, + num_workers=self.args.num_workers, collate_fn=self.args.collate_fn) + self.model.train() # 将模型设置为训练模式 + model_params = list(self.model.parameters()) + torch.multiprocessing.set_sharing_strategy('file_system') + for b_idx, batch in enumerate(dataloader): + data_pos, targets_pos, data_neg, targets_neg = self.args.move_batch_to_device(batch, self.args.device) + # print(batch) + self.optimizer.zero_grad() + score_pos = self.model(data_pos) + score_neg = self.model(data_neg) + loss = self.criterion(score_pos.squeeze(), score_neg.view(len(score_pos), -1).mean(dim=1), + torch.Tensor([1]).to(device=self.args.device)) + + # print(score_pos, score_neg, 
loss) + loss.backward() + self.optimizer.step() + self.updates_counter += 1 + #print(self.updates_counter) + + with torch.no_grad(): + all_scores += score_pos.squeeze().detach().cpu().tolist() + score_neg.squeeze().detach().cpu().tolist() + all_labels += targets_pos.tolist() + targets_neg.tolist() + total_loss += loss + + if self.valid_evaluator and self.args.eval_every_iter and self.updates_counter % self.args.eval_every_iter == 0: + tic = time.time() + result = self.valid_evaluator.eval() + self.logger.info('\nPerformance:' + str(result) + 'in ' + str(time.time() - tic)) + + if result['auc'] >= self.best_metric: + self.save_classifier() + self.best_metric = result['auc'] + self.not_improved_count = 0 + + else: + self.not_improved_count += 1 + if self.not_improved_count > self.args.early_stop: + self.logger.info( + f"Validation performance didn\'t improve for {self.args.early_stop} epochs. Training stops.") + break + self.last_metric = result['auc'] + + + auc = metrics.roc_auc_score(all_labels, all_scores) + auc_pr = metrics.average_precision_score(all_labels, all_scores) + + weight_norm = sum(map(lambda x: torch.norm(x), model_params)) + + return total_loss, auc, auc_pr, weight_norm + + + def train(self): + self.reset_training_state() + + for epoch in range(1, 1): + time_start = time.time() + loss, auc, auc_pr, weight_norm = self.train_epoch() + time_elapsed = time.time() - time_start + self.logger.info( + f'Epoch {epoch} with loss: {loss}, training auc: {auc}, training auc_pr: {auc_pr}, best validation AUC: {self.best_metric}, weight_norm: {weight_norm} in {time_elapsed}') + if epoch % self.args.save_every == 0: + torch.save(self.model, self.args.save_path + '/Grail_chk.pth') + self.params.model_path = self.args.save_path + '/Grail_chk.pth' + self.params.file_paths = { + 'graph': os.path.join(f'./openhgnn/dataset/data/{self.args.dataset}_ind/train.txt'), + 'links': os.path.join(f'./openhgnn/dataset/data/{self.args.dataset}_ind/test.txt') + } + global params + params = self.params + eval_rank(self.logger) + return + + def save_classifier(self): + #save_path = os.path.dirname(os.path.abspath('__file__')) + '/openhgnn/output/' + self.model_name + torch.save(self.model, self.args.save_path + '/best.pth') + self.logger.info('Better models found w.r.t accuracy. Saved it!') + + def process_files(self,files, saved_relation2id, add_traspose_rels): + ''' + files: Dictionary map of file paths to read the triplets from. + saved_relation2id: Saved relation2id (mostly passed from a trained model) which can be used to map relations to pre-defined indices and filter out the unknown ones. + ''' + entity2id = {} + relation2id = saved_relation2id + + triplets = {} + + ent = 0 + rel = 0 + + for file_type, file_path in files.items(): + + data = [] + with open(file_path) as f: + file_data = [line.split() for line in f.read().split('\n')[:-1]] + + for triplet in file_data: + if triplet[0] not in entity2id: + entity2id[triplet[0]] = ent + ent += 1 + if triplet[2] not in entity2id: + entity2id[triplet[2]] = ent + ent += 1 + + # Save the triplets corresponding to only the known relations + if triplet[1] in saved_relation2id: + data.append([entity2id[triplet[0]], entity2id[triplet[2]], saved_relation2id[triplet[1]]]) + + triplets[file_type] = np.array(data) + + id2entity = {v: k for k, v in entity2id.items()} + id2relation = {v: k for k, v in relation2id.items()} + + # Construct the list of adjacency matrix each corresponding to eeach relation. Note that this is constructed only from the train data. 
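+        # adj_list[i] is a (num_entities x num_entities) CSC matrix with a 1 at (head, tail)
+        # for every triplet of relation i in the 'graph' split.
+        # Hypothetical example: with entity2id = {'a': 0, 'b': 1} and a single triplet
+        # (a, r0, b) of relation 0, adj_list[0] has its only nonzero entry at position (0, 1).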
+ adj_list = [] + for i in range(len(saved_relation2id)): + idx = np.argwhere(triplets['graph'][:, 2] == i) + adj_list.append(ssp.csc_matrix((np.ones(len(idx), dtype=np.uint8), ( + triplets['graph'][:, 0][idx].squeeze(1), triplets['graph'][:, 1][idx].squeeze(1))), + shape=(len(entity2id), len(entity2id)))) + + # Add transpose matrices to handle both directions of relations. + adj_list_aug = adj_list + if add_traspose_rels: + adj_list_t = [adj.T for adj in adj_list] + adj_list_aug = adj_list + adj_list_t + + dgl_adj_list = self.ssp_multigraph_to_dgl(adj_list_aug) + + return adj_list, dgl_adj_list, triplets, entity2id, relation2id, id2entity, id2relation + ''' + def intialize_worker(model, adj_list, dgl_adj_list, id2entity, params, node_features, kge_entity2id): + global model_, adj_list_, dgl_adj_list_, id2entity_, params_, node_features_, kge_entity2id_ + model_, adj_list_, dgl_adj_list_, id2entity_, params_, node_features_, kge_entity2id_ = model, adj_list, dgl_adj_list, id2entity, params, node_features, kge_entity2id + ''' + def get_neg_samples_replacing_head_tail(self,test_links, adj_list, num_samples=50): + + n, r = adj_list[0].shape[0], len(adj_list) + heads, tails, rels = test_links[:, 0], test_links[:, 1], test_links[:, 2] + + neg_triplets = [] + for i, (head, tail, rel) in enumerate(zip(heads, tails, rels)): + neg_triplet = {'head': [[], 0], 'tail': [[], 0]} + neg_triplet['head'][0].append([head, tail, rel]) + while len(neg_triplet['head'][0]) < num_samples: + neg_head = head + neg_tail = np.random.choice(n) + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_triplet['head'][0].append([neg_head, neg_tail, rel]) + + neg_triplet['tail'][0].append([head, tail, rel]) + while len(neg_triplet['tail'][0]) < num_samples: + neg_head = np.random.choice(n) + neg_tail = tail + # neg_head, neg_tail, rel = np.random.choice(n), np.random.choice(n), np.random.choice(r) + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_triplet['tail'][0].append([neg_head, neg_tail, rel]) + + neg_triplet['head'][0] = np.array(neg_triplet['head'][0]) + neg_triplet['tail'][0] = np.array(neg_triplet['tail'][0]) + + neg_triplets.append(neg_triplet) + + return neg_triplets + + def get_neg_samples_replacing_head_tail_all(self,test_links, adj_list): + + n, r = adj_list[0].shape[0], len(adj_list) + heads, tails, rels = test_links[:, 0], test_links[:, 1], test_links[:, 2] + + neg_triplets = [] + print('sampling negative triplets...') + for i, (head, tail, rel) in tqdm(enumerate(zip(heads, tails, rels)), total=len(heads)): + neg_triplet = {'head': [[], 0], 'tail': [[], 0]} + neg_triplet['head'][0].append([head, tail, rel]) + for neg_tail in range(n): + neg_head = head + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_triplet['head'][0].append([neg_head, neg_tail, rel]) + + neg_triplet['tail'][0].append([head, tail, rel]) + for neg_head in range(n): + neg_tail = tail + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_triplet['tail'][0].append([neg_head, neg_tail, rel]) + + neg_triplet['head'][0] = np.array(neg_triplet['head'][0]) + neg_triplet['tail'][0] = np.array(neg_triplet['tail'][0]) + + neg_triplets.append(neg_triplet) + + return neg_triplets + + def get_neg_samples_replacing_head_tail_from_ruleN(self,ruleN_pred_path, entity2id, saved_relation2id): + with open(ruleN_pred_path) as f: + pred_data = [line.split() for line in f.read().split('\n')[:-1]] + + neg_triplets = [] + for i in range(len(pred_data) // 3): + 
neg_triplet = {'head': [[], 10000], 'tail': [[], 10000]} + if pred_data[3 * i][1] in saved_relation2id: + head, rel, tail = entity2id[pred_data[3 * i][0]], saved_relation2id[pred_data[3 * i][1]], entity2id[ + pred_data[3 * i][2]] + for j, new_head in enumerate(pred_data[3 * i + 1][1::2]): + neg_triplet['head'][0].append([entity2id[new_head], tail, rel]) + if entity2id[new_head] == head: + neg_triplet['head'][1] = j + for j, new_tail in enumerate(pred_data[3 * i + 2][1::2]): + neg_triplet['tail'][0].append([head, entity2id[new_tail], rel]) + if entity2id[new_tail] == tail: + neg_triplet['tail'][1] = j + + neg_triplet['head'][0] = np.array(neg_triplet['head'][0]) + neg_triplet['tail'][0] = np.array(neg_triplet['tail'][0]) + + neg_triplets.append(neg_triplet) + + return neg_triplets + + def incidence_matrix(self,adj_list): + ''' + adj_list: List of sparse adjacency matrices + ''' + + rows, cols, dats = [], [], [] + dim = adj_list[0].shape + for adj in adj_list: + adjcoo = adj.tocoo() + rows += adjcoo.row.tolist() + cols += adjcoo.col.tolist() + dats += adjcoo.data.tolist() + row = np.array(rows) + col = np.array(cols) + data = np.array(dats) + return ssp.csc_matrix((data, (row, col)), shape=dim) + + def _bfs_relational(self,adj, roots, max_nodes_per_hop=None): + """ + BFS for graphs with multiple edge types. Returns list of level sets. + Each entry in list corresponds to relation specified by adj_list. + Modified from dgl.contrib.data.knowledge_graph to node accomodate sampling + """ + visited = set() + current_lvl = set(roots) + + next_lvl = set() + + while current_lvl: + + for v in current_lvl: + visited.add(v) + + next_lvl = self._get_neighbors(adj, current_lvl) + next_lvl -= visited # set difference + + if max_nodes_per_hop and max_nodes_per_hop < len(next_lvl): + next_lvl = set(random.sample(next_lvl, max_nodes_per_hop)) + + yield next_lvl + + current_lvl = set.union(next_lvl) + + def _get_neighbors(self,adj, nodes): + """Takes a set of nodes and a graph adjacency matrix and returns a set of neighbors. 
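+        Neighbors are found by building a sparse one-hot row vector over the current
+        frontier and multiplying it with the adjacency matrix; the nonzero column
+        indices of the product are the neighbor ids.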
+        Directly copied from dgl.contrib.data.knowledge_graph"""
+        sp_nodes = self._sp_row_vec_from_idx_list(list(nodes), adj.shape[1])
+        sp_neighbors = sp_nodes.dot(adj)
+        neighbors = set(ssp.find(sp_neighbors)[1])  # convert to set of indices
+        return neighbors
+
+    def _sp_row_vec_from_idx_list(self,idx_list, dim):
+        """Create sparse vector of dimensionality dim from a list of indices."""
+        shape = (1, dim)
+        data = np.ones(len(idx_list))
+        row_ind = np.zeros(len(idx_list))
+        col_ind = list(idx_list)
+        return ssp.csr_matrix((data, (row_ind, col_ind)), shape=shape)
+
+    def get_neighbor_nodes(self,roots, adj, h=1, max_nodes_per_hop=None):
+        bfs_generator = self._bfs_relational(adj, roots, max_nodes_per_hop)
+        lvls = list()
+        for _ in range(h):
+            try:
+                lvls.append(next(bfs_generator))
+            except StopIteration:
+                pass
+        return set().union(*lvls)
+
+    def subgraph_extraction_labeling(self,ind, rel, A_list, h=1, enclosing_sub_graph=False, max_nodes_per_hop=None,
+                                     node_information=None, max_node_label_value=None):
+        # extract the h-hop enclosing subgraphs around link 'ind'
+        A_incidence = self.incidence_matrix(A_list)
+        A_incidence += A_incidence.T
+
+        # could pack these two into a function
+        root1_nei = self.get_neighbor_nodes(set([ind[0]]), A_incidence, h, max_nodes_per_hop)
+        root2_nei = self.get_neighbor_nodes(set([ind[1]]), A_incidence, h, max_nodes_per_hop)
+
+        subgraph_nei_nodes_int = root1_nei.intersection(root2_nei)
+        subgraph_nei_nodes_un = root1_nei.union(root2_nei)
+
+        # Extract subgraph | Roots being in the front is essential for labelling and the model to work properly.
+        if enclosing_sub_graph:
+            subgraph_nodes = list(ind) + list(subgraph_nei_nodes_int)
+        else:
+            subgraph_nodes = list(ind) + list(subgraph_nei_nodes_un)
+
+        subgraph = [adj[subgraph_nodes, :][:, subgraph_nodes] for adj in A_list]
+
+        labels, enclosing_subgraph_nodes = self.node_label_new(self.incidence_matrix(subgraph), max_distance=h)
+
+        pruned_subgraph_nodes = np.array(subgraph_nodes)[enclosing_subgraph_nodes].tolist()
+        pruned_labels = labels[enclosing_subgraph_nodes]
+
+        if max_node_label_value is not None:
+            pruned_labels = np.array([np.minimum(label, max_node_label_value).tolist() for label in pruned_labels])
+
+        return pruned_subgraph_nodes, pruned_labels
+
+    def remove_nodes(self,A_incidence, nodes):
+        idxs_wo_nodes = list(set(range(A_incidence.shape[1])) - set(nodes))
+        return A_incidence[idxs_wo_nodes, :][:, idxs_wo_nodes]
+
+    def node_label_new(self,subgraph, max_distance=1):
+        # an implementation of the proposed double-radius node labeling (DRNL)
+        roots = [0, 1]
+        sgs_single_root = [self.remove_nodes(subgraph, [root]) for root in roots]
+        dist_to_roots = [
+            np.clip(ssp.csgraph.dijkstra(sg, indices=[0], directed=False, unweighted=True, limit=1e6)[:, 1:], 0, 1e7)
+            for r, sg in enumerate(sgs_single_root)]
+        dist_to_roots = np.array(list(zip(dist_to_roots[0][0], dist_to_roots[1][0])), dtype=int)
+
+        # dist_to_roots[np.abs(dist_to_roots) > 1e6] = 0
+        # dist_to_roots = dist_to_roots + 1
+        target_node_labels = np.array([[0, 1], [1, 0]])
+        labels = np.concatenate((target_node_labels, dist_to_roots)) if dist_to_roots.size else target_node_labels
+
+        enclosing_subgraph_nodes = np.where(np.max(labels, axis=1) <= max_distance)[0]
+        # print(len(enclosing_subgraph_nodes))
+        return labels, enclosing_subgraph_nodes
+
+    def ssp_multigraph_to_dgl(self,graph, n_feats=None):
+        """
+        Converting ssp multigraph (i.e. list of adjs) to dgl multigraph.
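+        Each adjacency matrix contributes the edges of one relation; the relation index
+        is kept on the resulting graph as the edge attribute 'type'.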
+ """ + + g_nx = nx.MultiDiGraph() + g_nx.add_nodes_from(list(range(graph[0].shape[0]))) + # Add edges + for rel, adj in enumerate(graph): + # Convert adjacency matrix to tuples for nx0 + nx_triplets = [] + for src, dst in list(zip(adj.tocoo().row, adj.tocoo().col)): + nx_triplets.append((src, dst, {'type': rel})) + g_nx.add_edges_from(nx_triplets) + + # make dgl graph + g_dgl = dgl.DGLGraph(multigraph=True) + g_dgl = dgl.from_networkx(g_nx, edge_attrs=['type']) + # add node features + if n_feats is not None: + g_dgl.ndata['feat'] = torch.tensor(n_feats) + + return g_dgl + + def prepare_features(self,subgraph, n_labels, max_n_label, n_feats=None): + # One hot encode the node label feature and concat to n_featsure + n_nodes = subgraph.number_of_nodes() + label_feats = np.zeros((n_nodes, max_n_label[0] + 1 + max_n_label[1] + 1)) + label_feats[np.arange(n_nodes), n_labels[:, 0]] = 1 + label_feats[np.arange(n_nodes), max_n_label[0] + 1 + n_labels[:, 1]] = 1 + n_feats = np.concatenate((label_feats, n_feats), axis=1) if n_feats is not None else label_feats + subgraph.ndata['feat'] = torch.FloatTensor(n_feats) + + head_id = np.argwhere([label[0] == 0 and label[1] == 1 for label in n_labels]) + tail_id = np.argwhere([label[0] == 1 and label[1] == 0 for label in n_labels]) + n_ids = np.zeros(n_nodes) + n_ids[head_id] = 1 # head + n_ids[tail_id] = 2 # tail + subgraph.ndata['id'] = torch.FloatTensor(n_ids) + + return subgraph + + def get_subgraphs(self, all_links, adj_list, dgl_adj_list, max_node_label_value, id2entity, node_features=None, + kge_entity2id=None): + # dgl_adj_list = ssp_multigraph_to_dgl(adj_list) + + subgraphs = [] + r_labels = [] + + for link in all_links: + head, tail, rel = link[0], link[1], link[2] + nodes, node_labels = self.subgraph_extraction_labeling((head, tail), rel, adj_list, h=self.args.hop, + enclosing_sub_graph=self.args.enclosing_sub_graph, + max_node_label_value=max_node_label_value) + + subgraph = dgl_adj_list.subgraph(nodes) + subgraph.edata['type'] = dgl_adj_list.edata['type'][dgl_adj_list.subgraph(nodes).edata[dgl.EID]] + subgraph.edata['label'] = torch.tensor(rel * np.ones(subgraph.edata['type'].shape), dtype=torch.long) + + try: + edges_btw_roots = subgraph.edge_ids(0, 1) + edges_btw_roots = torch.tensor([edges_btw_roots]) + except: + edges_btw_roots = torch.tensor([]) + edges_btw_roots = edges_btw_roots.numpy() + rel_link = np.nonzero(subgraph.edata['type'][edges_btw_roots] == rel) + + if rel_link.squeeze().nelement() == 0: + # subgraph.add_edge(0, 1, {'type': torch.tensor([rel]), 'label': torch.tensor([rel])}) + subgraph = dgl.add_edges(subgraph, 0, 1) + subgraph.edata['type'][-1] = torch.tensor(rel).type(torch.LongTensor) + subgraph.edata['label'][-1] = torch.tensor(rel).type(torch.LongTensor) + + kge_nodes = [kge_entity2id[id2entity[n]] for n in nodes] if kge_entity2id else None + n_feats = node_features[kge_nodes] if node_features is not None else None + subgraph = self.prepare_features(subgraph, node_labels, max_node_label_value, n_feats) + + subgraphs.append(subgraph) + r_labels.append(rel) + + batched_graph = dgl.batch(subgraphs) + r_labels = torch.LongTensor(r_labels) + + return (batched_graph, r_labels) + + def get_rank(self,neg_links): + head_neg_links = neg_links['head'][0] + head_target_id = neg_links['head'][1] + + if head_target_id != 10000: + data = self.get_subgraphs(head_neg_links, self.params.adj_list, self.params.dgl_adj_list, self.model.gnn.max_label_value, self.params.id2entity, + None, None) + data = (data[0].to(self.device), 
data[1].to(self.device)) + head_scores = self.model(data).cpu().squeeze(1).detach().numpy() + head_rank = np.argwhere(np.argsort(head_scores)[::-1] == head_target_id) + 1 + else: + head_scores = np.array([]) + head_rank = 10000 + + tail_neg_links = neg_links['tail'][0] + tail_target_id = neg_links['tail'][1] + + if tail_target_id != 10000: + data = self.get_subgraphs(tail_neg_links, self.params.adj_list, self.params.dgl_adj_list, self.model.gnn.max_label_value, self.params.id2entity, + self.trainset.node_features, self.trainset.kge_entity2id) + data = (data[0].to(self.device), data[1].to(self.device)) + tail_scores = self.model(data).cpu().squeeze(1).detach().numpy() + tail_rank = np.argwhere(np.argsort(tail_scores)[::-1] == tail_target_id) + 1 + else: + tail_scores = np.array([]) + tail_rank = 10000 + + return head_scores, head_rank, tail_scores, tail_rank + + def save_to_file(self,neg_triplets, id2entity, id2relation): + + with open(os.path.join( + self.params.save_path,'ranking_head.txt'), + "w") as f: + for neg_triplet in neg_triplets: + for s, o, r in neg_triplet['head'][0]: + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o]]) + '\n') + + with open(os.path.join( + self.params.save_path,'ranking_tail.txt'), + "w") as f: + for neg_triplet in neg_triplets: + for s, o, r in neg_triplet['tail'][0]: + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o]]) + '\n') + + def save_score_to_file(self,neg_triplets, all_head_scores, all_tail_scores, id2entity, id2relation): + + with open(os.path.join( + self.params.save_path,'grail_ranking_head_predictions.txt'), + "w") as f: + for i, neg_triplet in enumerate(neg_triplets): + for [s, o, r], head_score in zip(neg_triplet['head'][0], all_head_scores[50 * i:50 * (i + 1)]): + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o], str(head_score)]) + '\n') + + with open(os.path.join( + self.params.save_path,'grail_ranking_tail_predictions.txt'), + "w") as f: + for i, neg_triplet in enumerate(neg_triplets): + for [s, o, r], tail_score in zip(neg_triplet['tail'][0], all_tail_scores[50 * i:50 * (i + 1)]): + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o], str(tail_score)]) + '\n') + + def save_score_to_file_from_ruleN(self, neg_triplets, all_head_scores, all_tail_scores, id2entity, id2relation): + + with open(os.path.join( + self.params.save_path,'grail_ruleN_ranking_head_predictions.txt'), + "w") as f: + for i, neg_triplet in enumerate(neg_triplets): + for [s, o, r], head_score in zip(neg_triplet['head'][0], all_head_scores[50 * i:50 * (i + 1)]): + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o], str(head_score)]) + '\n') + + with open(os.path.join( + self.params.save_path,'grail_ruleN_ranking_tail_predictions.txt'), + "w") as f: + for i, neg_triplet in enumerate(neg_triplets): + for [s, o, r], tail_score in zip(neg_triplet['tail'][0], all_tail_scores[50 * i:50 * (i + 1)]): + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o], str(tail_score)]) + '\n') + + + + + + + +import os +from sklearn import metrics +from torch.utils.data import DataLoader +class Evaluator(): + def __init__(self, params, graph_classifier, data): + self.params = params + self.graph_classifier = graph_classifier + self.data = data + + def eval(self, save=False): + pos_scores = [] + pos_labels = [] + neg_scores = [] + neg_labels = [] + dataloader = DataLoader(self.data, batch_size=self.params.batch_size, shuffle=False, num_workers=self.params.num_workers, collate_fn=self.params.collate_fn) + + 
self.graph_classifier.eval() + with torch.no_grad(): + for b_idx, batch in enumerate(dataloader): + + data_pos, targets_pos, data_neg, targets_neg = self.params.move_batch_to_device(batch, self.params.device) + # print([self.data.id2relation[r.item()] for r in data_pos[1]]) + # pdb.set_trace() + score_pos = self.graph_classifier(data_pos) + score_neg = self.graph_classifier(data_neg) + + # preds += torch.argmax(logits.detach().cpu(), dim=1).tolist() + pos_scores += score_pos.squeeze(1).detach().cpu().tolist() + neg_scores += score_neg.squeeze(1).detach().cpu().tolist() + pos_labels += targets_pos.tolist() + neg_labels += targets_neg.tolist() + + # acc = metrics.accuracy_score(labels, preds) + auc = metrics.roc_auc_score(pos_labels + neg_labels, pos_scores + neg_scores) + auc_pr = metrics.average_precision_score(pos_labels + neg_labels, pos_scores + neg_scores) + + if save: + pos_test_triplets_path = os.path.join(self.params.save_path, 'data/{}/{}.txt'.format(self.params.dataset, self.data.file_name)) + with open(pos_test_triplets_path) as f: + pos_triplets = [line.split() for line in f.read().split('\n')[:-1]] + pos_file_path = os.path.join(self.params.save_path, 'data/{}/grail_{}_predictions.txt'.format(self.params.dataset, self.data.file_name)) + with open(pos_file_path, "w") as f: + for ([s, r, o], score) in zip(pos_triplets, pos_scores): + f.write('\t'.join([s, r, o, str(score)]) + '\n') + + neg_test_triplets_path = os.path.join(self.params.save_path, 'data/{}/neg_{}_0.txt'.format(self.params.dataset, self.data.file_name)) + with open(neg_test_triplets_path) as f: + neg_triplets = [line.split() for line in f.read().split('\n')[:-1]] + neg_file_path = os.path.join(self.params.save_path, 'data/{}/grail_neg_{}_{}_predictions.txt'.format(self.params.dataset, self.data.file_name, self.params.constrained_neg_prob)) + with open(neg_file_path, "w") as f: + for ([s, r, o], score) in zip(neg_triplets, neg_scores): + f.write('\t'.join([s, r, o, str(score)]) + '\n') + + return {'auc': auc, 'auc_pr': auc_pr} + +def process_files(files, saved_relation2id, add_traspose_rels): + ''' + files: Dictionary map of file paths to read the triplets from. + saved_relation2id: Saved relation2id (mostly passed from a trained model) which can be used to map relations to pre-defined indices and filter out the unknown ones. + ''' + entity2id = {} + relation2id = saved_relation2id + + triplets = {} + + ent = 0 + rel = 0 + + for file_type, file_path in files.items(): + + data = [] + with open(file_path) as f: + file_data = [line.split() for line in f.read().split('\n')[:-1]] + + for triplet in file_data: + if triplet[0] not in entity2id: + entity2id[triplet[0]] = ent + ent += 1 + if triplet[2] not in entity2id: + entity2id[triplet[2]] = ent + ent += 1 + + # Save the triplets corresponding to only the known relations + if triplet[1] in saved_relation2id: + data.append([entity2id[triplet[0]], entity2id[triplet[2]], saved_relation2id[triplet[1]]]) + + triplets[file_type] = np.array(data) + + id2entity = {v: k for k, v in entity2id.items()} + id2relation = {v: k for k, v in relation2id.items()} + + # Construct the list of adjacency matrix each corresponding to eeach relation. Note that this is constructed only from the train data. 
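+    # One sparse adjacency matrix per known relation is built from the 'graph' split;
+    # adj_list is indexed by relation id, while dgl_adj_list is built from the augmented
+    # list, which also contains transposed copies when add_traspose_rels is set so that
+    # each inverse direction gets its own edge type.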
+ adj_list = [] + for i in range(len(saved_relation2id)): + idx = np.argwhere(triplets['graph'][:, 2] == i) + adj_list.append(ssp.csc_matrix((np.ones(len(idx), dtype=np.uint8), (triplets['graph'][:, 0][idx].squeeze(1), triplets['graph'][:, 1][idx].squeeze(1))), shape=(len(entity2id), len(entity2id)))) + + # Add transpose matrices to handle both directions of relations. + adj_list_aug = adj_list + if add_traspose_rels: + adj_list_t = [adj.T for adj in adj_list] + adj_list_aug = adj_list + adj_list_t + + dgl_adj_list = ssp_multigraph_to_dgl(adj_list_aug) + + return adj_list, dgl_adj_list, triplets, entity2id, relation2id, id2entity, id2relation + + +def intialize_worker(model, adj_list, dgl_adj_list, id2entity, params, node_features, kge_entity2id): + global model_, adj_list_, dgl_adj_list_, id2entity_, params_, node_features_, kge_entity2id_ + model_, adj_list_, dgl_adj_list_, id2entity_, params_, node_features_, kge_entity2id_ = model, adj_list, dgl_adj_list, id2entity, params, node_features, kge_entity2id + + +def get_neg_samples_replacing_head_tail(test_links, adj_list, num_samples=50): + + n, r = adj_list[0].shape[0], len(adj_list) + heads, tails, rels = test_links[:, 0], test_links[:, 1], test_links[:, 2] + + neg_triplets = [] + for i, (head, tail, rel) in enumerate(zip(heads, tails, rels)): + neg_triplet = {'head': [[], 0], 'tail': [[], 0]} + neg_triplet['head'][0].append([head, tail, rel]) + while len(neg_triplet['head'][0]) < num_samples: + neg_head = head + neg_tail = np.random.choice(n) + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_triplet['head'][0].append([neg_head, neg_tail, rel]) + + neg_triplet['tail'][0].append([head, tail, rel]) + while len(neg_triplet['tail'][0]) < num_samples: + neg_head = np.random.choice(n) + neg_tail = tail + # neg_head, neg_tail, rel = np.random.choice(n), np.random.choice(n), np.random.choice(r) + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_triplet['tail'][0].append([neg_head, neg_tail, rel]) + + neg_triplet['head'][0] = np.array(neg_triplet['head'][0]) + neg_triplet['tail'][0] = np.array(neg_triplet['tail'][0]) + + neg_triplets.append(neg_triplet) + + return neg_triplets + + +def get_neg_samples_replacing_head_tail_all(test_links, adj_list): + + n, r = adj_list[0].shape[0], len(adj_list) + heads, tails, rels = test_links[:, 0], test_links[:, 1], test_links[:, 2] + + neg_triplets = [] + print('sampling negative triplets...') + for i, (head, tail, rel) in tqdm(enumerate(zip(heads, tails, rels)), total=len(heads)): + neg_triplet = {'head': [[], 0], 'tail': [[], 0]} + neg_triplet['head'][0].append([head, tail, rel]) + for neg_tail in range(n): + neg_head = head + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_triplet['head'][0].append([neg_head, neg_tail, rel]) + + neg_triplet['tail'][0].append([head, tail, rel]) + for neg_head in range(n): + neg_tail = tail + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_triplet['tail'][0].append([neg_head, neg_tail, rel]) + + neg_triplet['head'][0] = np.array(neg_triplet['head'][0]) + neg_triplet['tail'][0] = np.array(neg_triplet['tail'][0]) + + neg_triplets.append(neg_triplet) + + return neg_triplets + + +def get_neg_samples_replacing_head_tail_from_ruleN(ruleN_pred_path, entity2id, saved_relation2id): + with open(ruleN_pred_path) as f: + pred_data = [line.split() for line in f.read().split('\n')[:-1]] + + neg_triplets = [] + for i in range(len(pred_data) // 3): + neg_triplet = {'head': [[], 
10000], 'tail': [[], 10000]} + if pred_data[3 * i][1] in saved_relation2id: + head, rel, tail = entity2id[pred_data[3 * i][0]], saved_relation2id[pred_data[3 * i][1]], entity2id[pred_data[3 * i][2]] + for j, new_head in enumerate(pred_data[3 * i + 1][1::2]): + neg_triplet['head'][0].append([entity2id[new_head], tail, rel]) + if entity2id[new_head] == head: + neg_triplet['head'][1] = j + for j, new_tail in enumerate(pred_data[3 * i + 2][1::2]): + neg_triplet['tail'][0].append([head, entity2id[new_tail], rel]) + if entity2id[new_tail] == tail: + neg_triplet['tail'][1] = j + + neg_triplet['head'][0] = np.array(neg_triplet['head'][0]) + neg_triplet['tail'][0] = np.array(neg_triplet['tail'][0]) + + neg_triplets.append(neg_triplet) + + return neg_triplets + + +def incidence_matrix(adj_list): + ''' + adj_list: List of sparse adjacency matrices + ''' + + rows, cols, dats = [], [], [] + dim = adj_list[0].shape + for adj in adj_list: + adjcoo = adj.tocoo() + rows += adjcoo.row.tolist() + cols += adjcoo.col.tolist() + dats += adjcoo.data.tolist() + row = np.array(rows) + col = np.array(cols) + data = np.array(dats) + return ssp.csc_matrix((data, (row, col)), shape=dim) + + +def _bfs_relational(adj, roots, max_nodes_per_hop=None): + """ + BFS for graphs with multiple edge types. Returns list of level sets. + Each entry in list corresponds to relation specified by adj_list. + Modified from dgl.contrib.data.knowledge_graph to node accomodate sampling + """ + visited = set() + current_lvl = set(roots) + + next_lvl = set() + + while current_lvl: + + for v in current_lvl: + visited.add(v) + + next_lvl = _get_neighbors(adj, current_lvl) + next_lvl -= visited # set difference + + if max_nodes_per_hop and max_nodes_per_hop < len(next_lvl): + next_lvl = set(random.sample(next_lvl, max_nodes_per_hop)) + + yield next_lvl + + current_lvl = set.union(next_lvl) + + +def _get_neighbors(adj, nodes): + """Takes a set of nodes and a graph adjacency matrix and returns a set of neighbors. + Directly copied from dgl.contrib.data.knowledge_graph""" + sp_nodes = _sp_row_vec_from_idx_list(list(nodes), adj.shape[1]) + sp_neighbors = sp_nodes.dot(adj) + neighbors = set(ssp.find(sp_neighbors)[1]) # convert to set of indices + return neighbors + + +def _sp_row_vec_from_idx_list(idx_list, dim): + """Create sparse vector of dimensionality dim from a list of indices.""" + shape = (1, dim) + data = np.ones(len(idx_list)) + row_ind = np.zeros(len(idx_list)) + col_ind = list(idx_list) + return ssp.csr_matrix((data, (row_ind, col_ind)), shape=shape) + + +def get_neighbor_nodes(roots, adj, h=1, max_nodes_per_hop=None): + bfs_generator = _bfs_relational(adj, roots, max_nodes_per_hop) + lvls = list() + for _ in range(h): + try: + lvls.append(next(bfs_generator)) + except StopIteration: + pass + return set().union(*lvls) + + +def subgraph_extraction_labeling(ind, rel, A_list, h=1, enclosing_sub_graph=False, max_nodes_per_hop=None, node_information=None, max_node_label_value=None): + # extract the h-hop enclosing subgraphs around link 'ind' + A_incidence = incidence_matrix(A_list) + A_incidence += A_incidence.T + + # could pack these two into a function + root1_nei = get_neighbor_nodes(set([ind[0]]), A_incidence, h, max_nodes_per_hop) + root2_nei = get_neighbor_nodes(set([ind[1]]), A_incidence, h, max_nodes_per_hop) + + subgraph_nei_nodes_int = root1_nei.intersection(root2_nei) + subgraph_nei_nodes_un = root1_nei.union(root2_nei) + + # Extract subgraph | Roots being in the front is essential for labelling and the model to work properly. 
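+    # With enclosing_sub_graph the subgraph keeps only nodes within h hops of *both*
+    # endpoints (the intersection above); otherwise the union of the two neighborhoods
+    # is used. Putting the two roots first gives them node ids 0 and 1, which
+    # node_label_new relies on for the double-radius labels.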
+ if enclosing_sub_graph: + subgraph_nodes = list(ind) + list(subgraph_nei_nodes_int) + else: + subgraph_nodes = list(ind) + list(subgraph_nei_nodes_un) + + subgraph = [adj[subgraph_nodes, :][:, subgraph_nodes] for adj in A_list] + + labels, enclosing_subgraph_nodes = node_label_new(incidence_matrix(subgraph), max_distance=h) + + pruned_subgraph_nodes = np.array(subgraph_nodes)[enclosing_subgraph_nodes].tolist() + pruned_labels = labels[enclosing_subgraph_nodes] + + if max_node_label_value is not None: + pruned_labels = np.array([np.minimum(label, max_node_label_value).tolist() for label in pruned_labels]) + + return pruned_subgraph_nodes, pruned_labels + + +def remove_nodes(A_incidence, nodes): + idxs_wo_nodes = list(set(range(A_incidence.shape[1])) - set(nodes)) + return A_incidence[idxs_wo_nodes, :][:, idxs_wo_nodes] + + +def node_label_new(subgraph, max_distance=1): + # an implementation of the proposed double-radius node labeling (DRNd L) + roots = [0, 1] + sgs_single_root = [remove_nodes(subgraph, [root]) for root in roots] + dist_to_roots = [np.clip(ssp.csgraph.dijkstra(sg, indices=[0], directed=False, unweighted=True, limit=1e6)[:, 1:], 0, 1e7) for r, sg in enumerate(sgs_single_root)] + dist_to_roots = np.array(list(zip(dist_to_roots[0][0], dist_to_roots[1][0])), dtype=int) + + # dist_to_roots[np.abs(dist_to_roots) > 1e6] = 0 + # dist_to_roots = dist_to_roots + 1 + target_node_labels = np.array([[0, 1], [1, 0]]) + labels = np.concatenate((target_node_labels, dist_to_roots)) if dist_to_roots.size else target_node_labels + + enclosing_subgraph_nodes = np.where(np.max(labels, axis=1) <= max_distance)[0] + # print(len(enclosing_subgraph_nodes)) + return labels, enclosing_subgraph_nodes + + + + +def prepare_features(subgraph, n_labels, max_n_label, n_feats=None): + # One hot encode the node label feature and concat to n_featsure + n_nodes = subgraph.number_of_nodes() + label_feats = np.zeros((n_nodes, max_n_label[0] + 1 + max_n_label[1] + 1)) + label_feats[np.arange(n_nodes), n_labels[:, 0]] = 1 + label_feats[np.arange(n_nodes), max_n_label[0] + 1 + n_labels[:, 1]] = 1 + n_feats = np.concatenate((label_feats, n_feats), axis=1) if n_feats is not None else label_feats + subgraph.ndata['feat'] = torch.FloatTensor(n_feats) + + head_id = np.argwhere([label[0] == 0 and label[1] == 1 for label in n_labels]) + tail_id = np.argwhere([label[0] == 1 and label[1] == 0 for label in n_labels]) + n_ids = np.zeros(n_nodes) + n_ids[head_id] = 1 # head + n_ids[tail_id] = 2 # tail + subgraph.ndata['id'] = torch.FloatTensor(n_ids) + + return subgraph + + +def get_subgraphs(all_links, adj_list, dgl_adj_list, max_node_label_value, id2entity, node_features=None, kge_entity2id=None): + # dgl_adj_list = ssp_multigraph_to_dgl(adj_list) + + subgraphs = [] + r_labels = [] + + for link in all_links: + head, tail, rel = link[0], link[1], link[2] + nodes, node_labels = subgraph_extraction_labeling((head, tail), rel, adj_list, h=params_.hop, enclosing_sub_graph=params.enclosing_sub_graph, max_node_label_value=max_node_label_value) + + subgraph = dgl_adj_list.subgraph(nodes) + subgraph.edata['type'] = dgl_adj_list.edata['type'][dgl_adj_list.subgraph(nodes).edata[dgl.EID] ] + subgraph.edata['label'] = torch.tensor(rel * np.ones(subgraph.edata['type'].shape), dtype=torch.long) + + # edges_btw_roots = subgraph.edge_id(0, 1) + try: + edges_btw_roots = subgraph.edge_ids(0, 1) + edges_btw_roots = torch.tensor([edges_btw_roots]) + except: + edges_btw_roots = torch.tensor([]) + edges_btw_roots = edges_btw_roots.numpy() + 
rel_link = np.nonzero(subgraph.edata['type'][edges_btw_roots] == rel) + + if rel_link.squeeze().nelement() == 0: + # subgraph.add_edge(0, 1, {'type': torch.tensor([rel]), 'label': torch.tensor([rel])}) + # subgraph.add_edge(0, 1) + subgraph = dgl.add_edges(subgraph, 0, 1) + subgraph.edata['type'][-1] = torch.tensor(rel).type(torch.LongTensor) + subgraph.edata['label'][-1] = torch.tensor(rel).type(torch.LongTensor) + + + kge_nodes = [kge_entity2id[id2entity[n]] for n in nodes] if kge_entity2id else None + n_feats = node_features[kge_nodes] if node_features is not None else None + subgraph = prepare_features(subgraph, node_labels, max_node_label_value, n_feats) + + subgraphs.append(subgraph) + r_labels.append(rel) + + batched_graph = dgl.batch(subgraphs) + r_labels = torch.LongTensor(r_labels) + + return (batched_graph, r_labels) + + +def get_rank(neg_links): + head_neg_links = neg_links['head'][0] + head_target_id = neg_links['head'][1] + + if head_target_id != 10000: + data = get_subgraphs(head_neg_links, adj_list_, dgl_adj_list_, model_.gnn.max_label_value, id2entity_, node_features_, kge_entity2id_) + head_scores = model_(data).squeeze(1).detach().numpy() + head_rank = np.argwhere(np.argsort(head_scores)[::-1] == head_target_id) + 1 + else: + head_scores = np.array([]) + head_rank = 10000 + + tail_neg_links = neg_links['tail'][0] + tail_target_id = neg_links['tail'][1] + + if tail_target_id != 10000: + data = get_subgraphs(tail_neg_links, adj_list_, dgl_adj_list_, model_.gnn.max_label_value, id2entity_, node_features_, kge_entity2id_) + tail_scores = model_(data).squeeze(1).detach().numpy() + tail_rank = np.argwhere(np.argsort(tail_scores)[::-1] == tail_target_id) + 1 + else: + tail_scores = np.array([]) + tail_rank = 10000 + + return head_scores, head_rank, tail_scores, tail_rank + + + +def eval_rank(logger): + # print(params.file_paths) + model = torch.load(params.model_path, map_location='cpu') + + adj_list, dgl_adj_list, triplets, entity2id, relation2id, id2entity, id2relation = process_files(params.file_paths, model.relation2id, params.add_traspose_rels) + + node_features, kge_entity2id = None, None + + if params.mode == 'sample': + neg_triplets = get_neg_samples_replacing_head_tail(triplets['links'], adj_list) + elif params.mode == 'all': + neg_triplets = get_neg_samples_replacing_head_tail_all(triplets['links'], adj_list) + elif params.mode == 'ruleN': + neg_triplets = get_neg_samples_replacing_head_tail_from_ruleN(params.ruleN_pred_path, entity2id, relation2id) + + ranks = [] + all_head_scores = [] + all_tail_scores = [] + intialize_worker(model, adj_list, dgl_adj_list, id2entity, params, node_features, kge_entity2id) + # with mp.Pool(processes=None, initializer=intialize_worker, initargs=(model, adj_list, dgl_adj_list, id2entity, params, node_features, kge_entity2id)) as p: + intialize_worker(model, adj_list, dgl_adj_list, id2entity, params, node_features, kge_entity2id) + for link in tqdm(neg_triplets, total=len(neg_triplets)): + head_scores, head_rank, tail_scores, tail_rank = get_rank(link) + ranks.append(head_rank) + ranks.append(tail_rank) + + all_head_scores += head_scores.tolist() + all_tail_scores += tail_scores.tolist() + + + + + isHit1List = [x for x in ranks if x <= 1] + isHit5List = [x for x in ranks if x <= 5] + isHit10List = [x for x in ranks if x <= 10] + hits_1 = len(isHit1List) / len(ranks) + hits_5 = len(isHit5List) / len(ranks) + hits_10 = len(isHit10List) / len(ranks) + + mrr = np.mean(1 / np.array(ranks)) + + logger.info(f'MRR | Hits@1 | Hits@5 | Hits@10 : 
{mrr} | {hits_1} | {hits_5} | {hits_10}') \ No newline at end of file diff --git a/openhgnn/trainerflow/Ingram_trainer.py b/openhgnn/trainerflow/Ingram_trainer.py new file mode 100644 index 00000000..d123f6d0 --- /dev/null +++ b/openhgnn/trainerflow/Ingram_trainer.py @@ -0,0 +1,141 @@ +from scipy.sparse import csr_matrix +import numpy as np +import torch +import dgl +from tqdm import tqdm +import random +from ..models import build_model +from . import BaseFlow, register_flow +from ..tasks import build_task +import os +from ..utils import Ingram_utils + + +def evaluate(my_model, target, epoch, init_emb_ent, init_emb_rel, relation_triplets): + with torch.no_grad(): + my_model.eval() + + msg = torch.tensor(target.msg_triplets).cuda() + sup = torch.tensor(target.sup_triplets).cuda() + + emb_ent, emb_rel = my_model(init_emb_ent, init_emb_rel, msg, relation_triplets) + + head_ranks = [] + tail_ranks = [] + ranks = [] + for triplet in tqdm(sup): + triplet = triplet.unsqueeze(dim=0) + head_corrupt = triplet.repeat(target.num_ent, 3) + head_corrupt[:, 0] = torch.arange(end=target.num_ent) + head_scores = my_model.score(emb_ent, emb_rel, head_corrupt) + head_filters = target.filter_dict[('_', int(triplet[0, 1].item()), int(triplet[0, 2].item()))] + head_rank = Ingram_utils.get_rank(triplet, head_scores, head_filters, target=0) + tail_corrupt = triplet.repeat(target.num_ent, 3) + tail_corrupt[:, 2] = torch.arange(end=target.num_ent) + tail_scores = my_model.score(emb_ent, emb_rel, tail_corrupt) + tail_filters = target.filter_dict[(int(triplet[0, 0].item()), int(triplet[0, 1].item()), '_')] + tail_rank = Ingram_utils.get_rank(triplet, tail_scores, tail_filters, target=2) + ranks.append(head_rank) + head_ranks.append(head_rank) + ranks.append(tail_rank) + tail_ranks.append(tail_rank) + + print("--------LP--------") + mr, mrr, hit10, hit3, hit1 = Ingram_utils.get_metrics(ranks) + print(f"MR: {mr:.1f}") + print(f"MRR: {mrr:.3f}") + print(f"Hits@10: {hit10:.3f}") + print(f"Hits@1: {hit1:.3f}") + + +@register_flow("Ingram_trainer") +class Ingram_Trainer(BaseFlow): + """ingram flows.""" + OMP_NUM_THREADS = 8 + torch.manual_seed(0) + random.seed(0) + np.random.seed(0) + torch.autograd.set_detect_anomaly(True) + torch.backends.cudnn.benchmark = True + torch.set_num_threads(8) + torch.cuda.empty_cache() + + def __init__(self, args): + super().__init__(args) + self.args = args + self.output_dir = args.output_dir + self.model_name = args.model + self.device = args.device + self.task = build_task(args) + self.model = build_model(self.model_name).build_model_from_args( + self.args).model + print("build_model_finish") + torch.cuda.set_device(args.device) + self.model = self.model.to(self.device) + self.loss_fn = torch.nn.MarginRankingLoss(margin=args.margin, reduction='mean') + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr) + self.validation_epoch = args.validation_epoch + self.num_epoch = args.num_epoch + self.num_neg = args.num_neg + self.num_bin = args.num_bin + self.d_e = args.d_e + self.d_r = args.d_r + + def train(self): + my_model = self.model + train = self.task.train_dataloader + valid = self.task.valid_dataloader + test = self.task.test_dataloader + pbar = tqdm(range(self.num_epoch)) + valid_epochs = self.args.validation_epoch + total_loss = 0 + file_format = f"lr_{self.args.lr}_dim_{self.args.d_e}_{self.args.d_r}" + \ + f"_bin_{self.args.num_bin}_total_{self.args.num_epoch}_every_{self.args.validation_epoch}" + \ + f"_neg_{self.args.num_neg}_layer_{self.args.nle}_{self.args.nlr}" 
+ \ + f"_hid_{self.args.hdr_e}_{self.args.hdr_r}" + \ + f"_head_{self.args.num_head}_margin_{self.args.margin}" + for epoch in pbar: + self.optimizer.zero_grad() + msg, sup = train.split_transductive(0.75) + init_emb_ent, init_emb_rel, relation_triplets = Ingram_utils.initialize(train, msg, self.d_e, self.d_r, + self.num_bin) + msg = torch.tensor(msg).cuda() + sup = torch.tensor(sup).cuda() + emb_ent, emb_rel = my_model(init_emb_ent, init_emb_rel, msg, relation_triplets) + pos_scores = my_model.score(emb_ent, emb_rel, sup) + neg_scores = my_model.score(emb_ent, emb_rel, + Ingram_utils.generate_neg(sup, train.num_ent, num_neg=self.num_neg)) + loss = self.loss_fn(pos_scores.repeat(self.num_neg), neg_scores, torch.ones_like(neg_scores)) + + loss.backward() + torch.nn.utils.clip_grad_norm_(my_model.parameters(), 0.1, error_if_nonfinite=False) + self.optimizer.step() + total_loss += loss.item() + pbar.set_description(f"loss {loss.item()}") + + if ((epoch + 1) % valid_epochs) == 0: + print("Validation") + my_model.eval() + val_init_emb_ent, val_init_emb_rel, val_relation_triplets = Ingram_utils.initialize(valid, + valid.msg_triplets, \ + self.d_e, self.d_r, + self.num_bin) + + evaluate(my_model, valid, epoch, val_init_emb_ent, val_init_emb_rel, val_relation_triplets) + path_ckp = self.output_dir + print(path_ckp) + folder = os.path.exists(path_ckp) + + if not folder: # 判断是否存在文件夹如果不存在则创建为文件夹 + os.makedirs(path_ckp) # makedirs 创建文件时如果路径不存在会创建这个路径 + print("--- new folder... ---") + + else: + print("--- There is this folder! ---") + torch.save({'model_state_dict': my_model.state_dict(), \ + 'optimizer_state_dict': self.optimizer.state_dict(), \ + 'inf_emb_ent': val_init_emb_ent, \ + 'inf_emb_rel': val_init_emb_rel}, \ + path_ckp + f"/{file_format}_{epoch + 1}.ckpt") + + my_model.train() diff --git a/openhgnn/trainerflow/LTE_trainer.py b/openhgnn/trainerflow/LTE_trainer.py new file mode 100644 index 00000000..6693123e --- /dev/null +++ b/openhgnn/trainerflow/LTE_trainer.py @@ -0,0 +1,329 @@ + +from . 
import BaseFlow, register_flow +from ..models import build_model + + +@register_flow("LTE_trainer") +class LTETrainer(BaseFlow): + + def __init__(self, args): + print("2") + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) + device = torch.device('cuda:0') + self.device = device + args.device= device + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + runner = Runner(args) + runner.fit() + + + def train(self): + print("1") + + +import os +import argparse +import time +import logging +from pprint import pprint +import numpy as np +import random +from pathlib import Path +import torch +from torch.utils.data import DataLoader +import dgl +from ..utils.lte_knowledge_graph import load_data + +from ..utils.lte_data_set import TrainDataset, TestDataset +from ..utils.lte_process_data import process + + + + + +class Runner(object): + def __init__(self, params): + params.embed_dim=None + params.r_ops="" + self.p = params + self.prj_path = Path(__file__).parent.resolve() + self.p.dataset_name=self.p.data + self.data = load_data(self.p.dataset_name) + self.num_ent, self.train_data, self.valid_data, self.test_data, self.num_rels = self.data.num_nodes, self.data.train, self.data.valid, self.data.test, self.data.num_rels + self.triplets = process({'train': self.train_data, 'valid': self.valid_data, 'test': self.test_data}, + self.num_rels) + + self.p.embed_dim = self.p.k_w * \ + self.p.k_h if self.p.embed_dim is None else self.p.embed_dim # output dim of gnn + self.data_iter = self.get_data_iter() + + if self.p.gpu >= 0: + self.g = self.build_graph().to(self.p.device) + else: + self.g = self.build_graph() + self.edge_type, self.edge_norm = self.get_edge_dir_and_norm() + self.model = self.get_model() + self.optimizer = torch.optim.Adam( + self.model.parameters(), lr=self.p.lr, weight_decay=self.p.l2) + self.best_val_mrr, self.best_epoch, self.best_val_results = 0., 0., {} + os.makedirs('./logs', exist_ok=True) + self.logger = logging.getLogger(__name__) + pprint(vars(self.p)) + + def fit(self): + save_root = self.prj_path / 'checkpoints' + + if not save_root.exists(): + save_root.mkdir() + save_path = save_root / (self.p.name + '.pt') + + if self.p.restore: + self.load_model(save_path) + self.logger.info('Successfully Loaded previous model') + + for epoch in range(self.p.max_epochs): + start_time = time.time() + train_loss = self.train() + val_results = self.evaluate('valid') + if val_results['mrr'] > self.best_val_mrr: + self.best_val_results = val_results + self.best_val_mrr = val_results['mrr'] + self.best_epoch = epoch + # self.save_model(save_path) + print(f"hits@1 = {val_results['hits@1']:.5}") + print(f"hits@3 = {val_results['hits@3']:.5}") + print(f"hits@10 = {val_results['hits@10']:.5}") + print( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid MRR: {val_results['mrr']:.5}, Best Valid MRR: {self.best_val_mrr:.5}, Cost: {time.time() - start_time:.2f}s") + self.logger.info( + f"[Epoch {epoch}]: Training Loss: {train_loss:.5}, Valid MRR: {val_results['mrr']:.5}, Best Valid MRR: {self.best_val_mrr:.5}, Cost: {time.time() - start_time:.2f}s") + self.logger.info(vars(self.p)) + # self.load_model(save_path) + self.logger.info( + f'Loading best model in {self.best_epoch} epoch, Evaluating on Test data') + start = time.time() + test_results = self.evaluate('test') + end = time.time() + self.logger.info( + f"MRR: Tail {test_results['left_mrr']:.5}, Head {test_results['right_mrr']:.5}, Avg {test_results['mrr']:.5}") + 
self.logger.info( + f"MR: Tail {test_results['left_mr']:.5}, Head {test_results['right_mr']:.5}, Avg {test_results['mr']:.5}") + self.logger.info(f"hits@1 = {test_results['hits@1']:.5}") + self.logger.info(f"hits@3 = {test_results['hits@3']:.5}") + self.logger.info(f"hits@10 = {test_results['hits@10']:.5}") + self.logger.info("time ={}".format(end-start)) + + def train(self): + self.model.train() + losses = [] + train_iter = self.data_iter['train'] + for step, (triplets, labels) in enumerate(train_iter): + if self.p.gpu >= 0: + triplets, labels = triplets.to(self.p.device), labels.to(self.p.device) + subj, rel = triplets[:, 0], triplets[:, 1] + # print(subj) + pred = self.model(self.g, subj, rel) # [batch_size, num_ent] + loss = self.model.calc_loss(pred, labels) + # print(loss) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + losses.append(loss.item()) + + loss = np.mean(losses) + return loss + + def evaluate(self, split): + """ + Function to evaluate the model on validation or test set + :param split: valid or test, set which data-set to evaluate on + :return: results['mr']: Average of ranks_left and ranks_right + results['mrr']: Mean Reciprocal Rank + results['hits@k']: Probability of getting the correct prediction in top-k ranks based on predicted score + results['left_mrr'], results['left_mr'], results['right_mrr'], results['right_mr'] + results['left_hits@k'], results['right_hits@k'] + """ + + def get_combined_results(left, right): + results = dict() + assert left['count'] == right['count'] + count = float(left['count']) + results['left_mr'] = round(left['mr'] / count, 5) + results['left_mrr'] = round(left['mrr'] / count, 5) + results['right_mr'] = round(right['mr'] / count, 5) + results['right_mrr'] = round(right['mrr'] / count, 5) + results['mr'] = round((left['mr'] + right['mr']) / (2 * count), 5) + results['mrr'] = round( + (left['mrr'] + right['mrr']) / (2 * count), 5) + for k in [1, 3, 10]: + results[f'left_hits@{k}'] = round(left[f'hits@{k}'] / count, 5) + results[f'right_hits@{k}'] = round( + right[f'hits@{k}'] / count, 5) + results[f'hits@{k}'] = round( + (results[f'left_hits@{k}'] + results[f'right_hits@{k}']) / 2, 5) + return results + + self.model.eval() + left_result = self.predict(split, 'tail') + right_result = self.predict(split, 'head') + res = get_combined_results(left_result, right_result) + return res + + def predict(self, split='valid', mode='tail'): + """ + Function to run model evaluation for a given mode + :param split: valid or test, set which data-set to evaluate on + :param mode: head or tail + :return: results['mr']: Sum of ranks + results['mrr']: Sum of Reciprocal Rank + results['hits@k']: counts of getting the correct prediction in top-k ranks based on predicted score + results['count']: number of total predictions + """ + with torch.no_grad(): + results = dict() + test_iter = self.data_iter[f'{split}_{mode}'] + for step, (triplets, labels) in enumerate(test_iter): + triplets, labels = triplets.to(self.p.device), labels.to(self.p.device) + subj, rel, obj = triplets[:, 0], triplets[:, 1], triplets[:, 2] + pred = self.model(self.g, subj, rel) + b_range = torch.arange(pred.shape[0], device=self.p.device) + # [batch_size, 1], get the predictive score of obj + target_pred = pred[b_range, obj] + # label=>-1000000, not label=>pred, filter out other objects with same sub&rel pair + pred = torch.where( + labels.bool(), -torch.ones_like(pred) * 10000000, pred) + # copy predictive score of obj to new pred + pred[b_range, obj] = target_pred + 
ranks = 1 + torch.argsort(torch.argsort(pred, dim=1, descending=True), dim=1, descending=False)[ + b_range, obj] # get the rank of each (sub, rel, obj) + ranks = ranks.float() + results['count'] = torch.numel( + ranks) + results.get('count', 0) # number of predictions + results['mr'] = torch.sum(ranks).item() + results.get('mr', 0) + results['mrr'] = torch.sum( + 1.0 / ranks).item() + results.get('mrr', 0) + + for k in [1, 3, 10]: + results[f'hits@{k}'] = torch.numel( + ranks[ranks <= k]) + results.get(f'hits@{k}', 0) + return results + + def save_model(self, path): + """ + Function to save a model. It saves the model parameters, best validation scores, + best epoch corresponding to best validation, state of the optimizer and all arguments for the run. + :param path: path where the model is saved + :return: + """ + state = { + 'model': self.model.state_dict(), + 'best_val': self.best_val_results, + 'best_epoch': self.best_epoch, + 'optimizer': self.optimizer.state_dict(), + 'args': vars(self.p) + } + torch.save(state, path) + + def load_model(self, path): + """ + Function to load a saved model + :param path: path where model is loaded + :return: + """ + state = torch.load(path) + self.best_val_results = state['best_val'] + self.best_val_mrr = self.best_val_results['mrr'] + self.best_epoch = state['best_epoch'] + self.model.load_state_dict(state['model']) + self.optimizer.load_state_dict(state['optimizer']) + + def build_graph(self): + g = dgl.DGLGraph() + g.add_nodes(self.num_ent) + + if not self.p.rat: + g.add_edges(self.train_data[:, 0], self.train_data[:, 2]) + g.add_edges(self.train_data[:, 2], self.train_data[:, 0]) + else: + if self.p.ss > 0: + sampleSize = self.p.ss + else: + sampleSize = self.num_ent - 1 + g.add_edges(self.train_data[:, 0], np.random.randint( + low=0, high=sampleSize, size=self.train_data[:, 2].shape)) + g.add_edges(self.train_data[:, 2], np.random.randint( + low=0, high=sampleSize, size=self.train_data[:, 0].shape)) + return g + + def get_data_iter(self): + """ + get data loader for train, valid and test section + :return: dict + """ + + def get_data_loader(dataset_class, split): + return DataLoader( + dataset_class(self.triplets[split], self.num_ent, self.p), + batch_size=self.p.batch_size, + shuffle=True, + num_workers=self.p.num_workers + ) + + return { + 'train': get_data_loader(TrainDataset, 'train'), + 'valid_head': get_data_loader(TestDataset, 'valid_head'), + 'valid_tail': get_data_loader(TestDataset, 'valid_tail'), + 'test_head': get_data_loader(TestDataset, 'test_head'), + 'test_tail': get_data_loader(TestDataset, 'test_tail') + } + + def get_edge_dir_and_norm(self): + """ + :return: edge_type: indicates type of each edge: [E] + """ + in_deg = self.g.in_degrees(range(self.g.number_of_nodes())).float() + norm = in_deg ** -0.5 + norm[torch.isinf(norm).bool()] = 0 + self.g.ndata['xxx'] = norm + self.g.apply_edges( + lambda edges: {'xxx': edges.dst['xxx'] * edges.src['xxx']}) + if self.p.gpu >= 0: + norm = self.g.edata.pop('xxx').squeeze().to(self.p.device) + edge_type = torch.tensor(np.concatenate( + [self.train_data[:, 1], self.train_data[:, 1] + self.num_rels])).to(self.p.device) + else: + norm = self.g.edata.pop('xxx').squeeze() + edge_type = torch.tensor(np.concatenate( + [self.train_data[:, 1], self.train_data[:, 1] + self.num_rels])) + return edge_type, norm + + def get_model(self): + args = self.p + args.num_ents = self.num_ent + args.num_rels= self.num_rels + if self.p.n_layer > 0: + if self.p.score_func.lower() == 'transe': + model = 
build_model(args.model_name_GCN).build_model_from_args( + args).model + # model = GCN_TransE(args) + else: + raise KeyError( + f'score function {self.p.score_func} not recognized.') + else: + if self.p.score_func.lower() == 'transe': + model = build_model(args.model_name).build_model_from_args( + args).model + # model = TransE(self.num_ent, self.num_rels, params=self.p) + else: + raise NotImplementedError + + if self.p.gpu >= 0: + model.to(self.p.device) + return model + + + diff --git a/openhgnn/trainerflow/NBF_trainer.py b/openhgnn/trainerflow/NBF_trainer.py new file mode 100644 index 00000000..bb7c7312 --- /dev/null +++ b/openhgnn/trainerflow/NBF_trainer.py @@ -0,0 +1,275 @@ +from ..models import build_model +from . import BaseFlow, register_flow +from ..tasks import build_task +from ..utils import extract_embed, EarlyStopping +import os +import sys +import math +import torch +from torch import nn +from torch.nn import functional as F +from torch import distributed as dist +from torch.utils import data as torch_data +import logging +from ..models import NBF +separator = ">" * 30 +line = "-" * 30 + +def train_and_validate(args, model, train_data, valid_data, filtered_data=None): + if args.num_epoch == 0: + return + + world_size = get_world_size() + rank = get_rank() + + train_triplets = torch.cat([train_data.target_edge_index, train_data.target_edge_type.unsqueeze(0)]).t() + sampler = torch_data.DistributedSampler(train_triplets, world_size, rank) + train_loader = torch_data.DataLoader(train_triplets, args.batch_size, sampler=sampler) + + optimizer = ( + torch.optim.Adam(model.parameters(), lr=args.lr)) + + if world_size > 1: + parallel_model = nn.parallel.DistributedDataParallel(model, device_ids=[args.device]) + else: + parallel_model = model + + step = math.ceil(args.num_epoch / 10) + best_result = float("-inf") + best_epoch = -1 + + batch_id = 0 + for i in range(0, args.num_epoch, step): + parallel_model.train() + for epoch in range(i, min(args.num_epoch, i + step)): + if get_rank() == 0: + logger.warning(separator) + logger.warning("Epoch %d begin" % epoch) + + losses = [] + sampler.set_epoch(epoch) + for batch in train_loader: + batch = NBF.negative_sampling(train_data, batch, args.num_negative, + strict=args.strict_negative) + pred = parallel_model(train_data, batch)#forward + target = torch.zeros_like(pred) + target[:, 0] = 1 + loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none") + neg_weight = torch.ones_like(pred) + if args.adversarial_temperature > 0: + with torch.no_grad(): + neg_weight[:, 1:] = F.softmax(pred[:, 1:] / args.adversarial_temperature, dim=-1) + else: + neg_weight[:, 1:] = 1 / args.num_negative + loss = (loss * neg_weight).sum(dim=-1) / neg_weight.sum(dim=-1) + loss = loss.mean() + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + if get_rank() == 0 and batch_id % args.log_interval == 0: + logger.warning(separator) + logger.warning("binary cross entropy: %g" % loss) + losses.append(loss.item()) + batch_id += 1 + + if get_rank() == 0: + avg_loss = sum(losses) / len(losses) + logger.warning(separator) + logger.warning("Epoch %d end" % epoch) + logger.warning(line) + logger.warning("average binary cross entropy: %g" % avg_loss) + + epoch = min(args.num_epoch, i + step) + if rank == 0: + logger.warning("Save checkpoint to model_epoch_%d.pth" % epoch) + state = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict() + } + torch.save(state, "model_epoch_%d.pth" % epoch) + synchronize() + + if rank == 0: + 
logger.warning(separator) + logger.warning("Evaluate on valid") + + + result = test(args, model, valid_data, filtered_data=filtered_data) + if result > best_result: + best_result = result + best_epoch = epoch + + if rank == 0: + logger.warning("Load checkpoint from model_epoch_%d.pth" % best_epoch) + state = torch.load("model_epoch_%d.pth" % best_epoch, map_location=args.device) + model.load_state_dict(state["model"]) + synchronize() + +@torch.no_grad() +def test(args, model, test_data, filtered_data=None): + world_size = get_world_size() + rank = get_rank() + + test_triplets = torch.cat([test_data.target_edge_index, test_data.target_edge_type.unsqueeze(0)]).t() + sampler = torch_data.DistributedSampler(test_triplets, world_size, rank) + test_loader = torch_data.DataLoader(test_triplets, args.batch_size, sampler=sampler) + + model.eval() + rankings = [] + num_negatives = [] + for batch in test_loader: + t_batch, h_batch = NBF.all_negative(test_data, batch) + t_pred = model(test_data, t_batch) + h_pred = model(test_data, h_batch) + + if filtered_data is None: + t_mask, h_mask = NBF.strict_negative_mask(test_data, batch) + else: + t_mask, h_mask = NBF.strict_negative_mask(filtered_data, batch) + pos_h_index, pos_t_index, pos_r_index = batch.t() + t_ranking = NBF.compute_ranking(t_pred, pos_t_index, t_mask) + h_ranking = NBF.compute_ranking(h_pred, pos_h_index, h_mask) + num_t_negative = t_mask.sum(dim=-1) + num_h_negative = h_mask.sum(dim=-1) + + rankings += [t_ranking, h_ranking] + num_negatives += [num_t_negative, num_h_negative] + + ranking = torch.cat(rankings) + num_negative = torch.cat(num_negatives) + all_size = torch.zeros(world_size, dtype=torch.long, device=args.device) + all_size[rank] = len(ranking) + if world_size > 1: + dist.all_reduce(all_size, op=dist.ReduceOp.SUM) + cum_size = all_size.cumsum(0) + all_ranking = torch.zeros(all_size.sum(), dtype=torch.long, device=args.device) + all_ranking[cum_size[rank] - all_size[rank]: cum_size[rank]] = ranking + all_num_negative = torch.zeros(all_size.sum(), dtype=torch.long, device=args.device) + all_num_negative[cum_size[rank] - all_size[rank]: cum_size[rank]] = num_negative + if world_size > 1: + dist.all_reduce(all_ranking, op=dist.ReduceOp.SUM) + dist.all_reduce(all_num_negative, op=dist.ReduceOp.SUM) + + if rank == 0: + for metric in args.metric: + if metric == "mr": + score = all_ranking.float().mean() + elif metric == "mrr": + score = (1 / all_ranking.float()).mean() + elif metric.startswith("hits@"): + values = metric[5:].split("_") + threshold = int(values[0]) + if len(values) > 1: + num_sample = int(values[1]) + # unbiased estimation + fp_rate = (all_ranking - 1).float() / all_num_negative + score = 0 + for i in range(threshold): + # choose i false positive from num_sample - 1 negatives + num_comb = math.factorial(num_sample - 1) / \ + math.factorial(i) / math.factorial(num_sample - i - 1) + score += num_comb * (fp_rate ** i) * ((1 - fp_rate) ** (num_sample - i - 1)) + score = score.mean() + else: + score = (all_ranking <= threshold).float().mean() + logger.warning("%s: %g" % (metric, score)) + mrr = (1 / all_ranking.float()).mean() + + return mrr + +logger = logging.getLogger(__file__) + +def get_rank(): # get random seed + if dist.is_initialized(): + return dist.get_rank() + if "RANK" in os.environ: + return int(os.environ["RANK"]) + return 0 + +def get_world_size(): + if dist.is_initialized(): + return dist.get_world_size() + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) + return 1 + +def 
synchronize(): + if get_world_size() > 1: + dist.barrier() + + + + +@register_flow("NBF_trainer") +class NBF_trainer(BaseFlow): + + """ + NGF_trainer + + """ + + def __init__(self, args): # args == self.config + args.task = args.model +"_" +args.task # task: NBF_link_prediction + self.args = args + self.model_name = args.model + self.device = args.device + self.hg = None + + self.task = build_task(args) + self.dataset = self.task.dataset.dataset # train_data,valid_data,test_data + + + # Build the model. + self.args.num_relation = self.dataset.num_relations + + self.model = args.model + self.model = build_model(self.model).build_model_from_args(self.args, self.hg) + + + + def train(self): + + filtered_data = None + self.model = self.model.to(self.device) + + train_data, valid_data, test_data = self.dataset.train_data , self.dataset.valid_data , self.dataset.test_data + + + train_data.edge_index = train_data.edge_index.to(self.device) + train_data.edge_type = train_data.edge_type.to(self.device) + train_data.target_edge_index = train_data.target_edge_index.to(self.device) + train_data.target_edge_type = train_data.target_edge_type.to(self.device) + + valid_data.edge_index = valid_data.edge_index.to(self.device) + valid_data.edge_type = valid_data.edge_type.to(self.device) + valid_data.target_edge_index = valid_data.target_edge_index.to(self.device) + valid_data.target_edge_type = valid_data.target_edge_type.to(self.device) + + + test_data.edge_index = test_data.edge_index.to(self.device) + test_data.edge_type = test_data.edge_type.to(self.device) + test_data.target_edge_index = test_data.target_edge_index.to(self.device) + test_data.target_edge_type = test_data.target_edge_type.to(self.device) + + train_and_validate(self.args, self.model, train_data, valid_data, filtered_data=filtered_data) + + if get_rank() == 0: + logger.warning(separator) + logger.warning("Evaluate on valid") + + test(self.args, self.model, valid_data, filtered_data=filtered_data) + + if get_rank() == 0: + logger.warning(separator) + logger.warning("Evaluate on test") + + test(self.args, self.model, test_data, filtered_data=filtered_data) + + + + + + + diff --git a/openhgnn/trainerflow/RedGNNT_trainer.py b/openhgnn/trainerflow/RedGNNT_trainer.py new file mode 100644 index 00000000..9a07ee96 --- /dev/null +++ b/openhgnn/trainerflow/RedGNNT_trainer.py @@ -0,0 +1,157 @@ +import torch as th +from tqdm import tqdm +from . import BaseFlow, register_flow +from ..models import build_model +from ..utils import EarlyStopping +from scipy.sparse import csr_matrix +from ..sampler.TransX_sampler import TransX_Sampler +import numpy as np +from torch.optim.lr_scheduler import ExponentialLR +from scipy.stats import rankdata +from openhgnn.tasks import build_task + +def cal_performance(ranks): + mrr = (1. 
/ ranks).sum() / len(ranks) + h_1 = sum(ranks<=1) * 1.0 / len(ranks) + h_10 = sum(ranks<=10) * 1.0 / len(ranks) + return mrr, h_1, h_10 + +def cal_ranks(scores, labels, filters): + scores = scores - np.min(scores, axis=1, keepdims=True) + 1e-8 + full_rank = rankdata(-scores, method='average', axis=1) + filter_scores = scores * filters + filter_rank = rankdata(-filter_scores, method='min', axis=1) + ranks = (full_rank - filter_rank + 1) * labels # get the ranks of multiple answering entities simultaneously + ranks = ranks[np.nonzero(ranks)] + return list(ranks) + + +@register_flow("RedGNNT_trainer") +class RedGNNTTrainer(BaseFlow): + """RedGNN flows.""" + + def __init__(self, args): + super(RedGNNTTrainer, self).__init__(args) + + self.args = args + self.model_name = args.model + self.device = args.device + self.task = build_task(args) + self.batch_size = args.batch_size + self.n_tbatch = args.n_tbatch + self.max_epoch = args.max_epoch + + self.loader = self.task.dataset + self.n_ent = self.loader.n_ent + self.n_rel = self.loader.n_rel + + self.n_train = self.loader.n_train + self.n_valid = self.loader.n_valid + self.n_test = self.loader.n_test + + self.smooth = 1e-5 + + self.model = build_model(self.model).build_model_from_args(self.args, self.task.dataset) + self.model = self.model.to(self.device) + + self.optimizer = self.candidate_optimizer[args.optimizer](self.model.parameters(), + lr=args.lr, weight_decay=args.weight_decay) + self.scheduler = ExponentialLR(self.optimizer, args.decay_rate) + self.stopper = EarlyStopping(args.patience, self._checkpoint) + + def train(self): + for epoch in range(self.max_epoch): + mrr, out_str = self.train_batch() + if epoch % self.evaluate_interval == 0: + self.logger.info("[Evaluation metric] " + out_str) # out test result + early_stop = self.stopper.loss_step(-mrr, self.model) # less is better + if early_stop: + self.logger.train_info(f'Early Stop!\tEpoch:{epoch:03d}.') + break + + + def train_batch(self): + epoch_loss = 0 + + batch_size = self.batch_size + n_batch = self.loader.n_train // batch_size + (self.loader.n_train % batch_size > 0) + + self.model.train() + for i in range(n_batch): + start = i * batch_size + end = min(self.n_train, (i + 1) * batch_size) + batch_idx = np.arange(start, end) + triple = self.loader.get_batch(batch_idx) + + self.model.zero_grad() + scores = self.model(triple[:, 0], triple[:, 1]) + pos_scores = scores[[th.arange(len(scores)).to(self.device), th.LongTensor(triple[:, 2]).to(self.device)]] + max_n = th.max(scores, 1, keepdim=True)[0] + loss = th.sum(- pos_scores + max_n + th.log(th.sum(th.exp(scores - max_n), 1))) + loss.backward() + self.optimizer.step() + + # avoid NaN + for p in self.model.parameters(): + X = p.data.clone() + flag = X != X + X[flag] = np.random.random() + p.data.copy_(X) + epoch_loss += loss.item() + self.scheduler.step() + self.loader.shuffle_train() + valid_mrr, out_str = self.evaluate() + return valid_mrr, out_str + + def evaluate(self, ): + batch_size = self.n_tbatch + + n_data = self.n_valid + n_batch = n_data // batch_size + (n_data % batch_size > 0) + ranking = [] + self.model.eval() + for i in range(n_batch): + start = i * batch_size + end = min(n_data, (i + 1) * batch_size) + batch_idx = np.arange(start, end) + subs, rels, objs = self.loader.get_batch(batch_idx, data='valid') + scores = self.model(subs, rels, mode='valid').data.cpu().numpy() + filters = [] + for i in range(len(subs)): + filt = self.loader.filters[(subs[i], rels[i])] + filt_1hot = np.zeros((self.n_ent,)) + 
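# Note: `filt` collects every entity recorded by the loader as a true answer for the
# query (subs[i], rels[i]); cal_ranks ranks all entities with rankdata and then removes
# the rank positions occupied by these filtered answers (full_rank - filter_rank + 1),
# so other known answers do not penalise the entity being evaluated (filtered ranking).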
filt_1hot[np.array(filt)] = 1 + filters.append(filt_1hot) + + filters = np.array(filters) + ranks = cal_ranks(scores, objs, filters) + ranking += ranks + ranking = np.array(ranking) + v_mrr, v_h1, v_h10 = cal_performance(ranking) + + n_data = self.n_test + n_batch = n_data // batch_size + (n_data % batch_size > 0) + ranking = [] + self.model.eval() + for i in range(n_batch): + start = i * batch_size + end = min(n_data, (i + 1) * batch_size) + batch_idx = np.arange(start, end) + subs, rels, objs = self.loader.get_batch(batch_idx, data='test') + scores = self.model(subs, rels, mode='test').data.cpu().numpy() + filters = [] + for i in range(len(subs)): + filt = self.loader.filters[(subs[i], rels[i])] + filt_1hot = np.zeros((self.n_ent,)) + filt_1hot[np.array(filt)] = 1 + filters.append(filt_1hot) + + filters = np.array(filters) + ranks = cal_ranks(scores, objs, filters) + ranking += ranks + ranking = np.array(ranking) + t_mrr, t_h1, t_h10 = cal_performance(ranking) + + out_str = '[VALID] MRR:%.4f H@1:%.4f H@10:%.4f\t [TEST] MRR:%.4f H@1:%.4f H@10:%.4f \t' % ( + v_mrr, v_h1, v_h10, t_mrr, t_h1, t_h10) + return v_mrr, out_str diff --git a/openhgnn/trainerflow/RedGNN_trainer.py b/openhgnn/trainerflow/RedGNN_trainer.py new file mode 100644 index 00000000..15bb6b47 --- /dev/null +++ b/openhgnn/trainerflow/RedGNN_trainer.py @@ -0,0 +1,155 @@ +import torch as th +from tqdm import tqdm +from . import BaseFlow, register_flow +from ..models import build_model +from ..utils import EarlyStopping +from scipy.sparse import csr_matrix +from ..sampler.TransX_sampler import TransX_Sampler +import numpy as np +from torch.optim.lr_scheduler import ExponentialLR +from scipy.stats import rankdata +from openhgnn.tasks import build_task + +def cal_performance(ranks): + mrr = (1. 
/ ranks).sum() / len(ranks) + h_1 = sum(ranks<=1) * 1.0 / len(ranks) + h_10 = sum(ranks<=10) * 1.0 / len(ranks) + return mrr, h_1, h_10 + +def cal_ranks(scores, labels, filters): + scores = scores - np.min(scores, axis=1, keepdims=True) + 1e-8 + full_rank = rankdata(-scores, method='average', axis=1) + filter_scores = scores * filters + filter_rank = rankdata(-filter_scores, method='min', axis=1) + ranks = (full_rank - filter_rank + 1) * labels # get the ranks of multiple answering entities simultaneously + ranks = ranks[np.nonzero(ranks)] + return list(ranks) + + +@register_flow("RedGNN_trainer") +class RedGNNTrainer(BaseFlow): + """RedGNN flows.""" + + def __init__(self, args): + super(RedGNNTrainer, self).__init__(args) + + self.args = args + self.model_name = args.model + self.device = args.device + self.task = build_task(args) + self.batch_size = args.batch_size + self.max_epoch = args.max_epoch + + self.loader = self.task.dataset + self.n_ent = self.loader.n_ent + self.n_ent_ind = self.loader.n_ent_ind + + self.n_train = self.loader.n_train + self.n_valid = self.loader.n_valid + self.n_test = self.loader.n_test + + self.model = build_model(self.model).build_model_from_args(self.args, self.task.dataset) + self.model = self.model.to(self.device) + + self.optimizer = self.candidate_optimizer[args.optimizer](self.model.parameters(), + lr=args.lr, weight_decay=args.weight_decay) + self.scheduler = ExponentialLR(self.optimizer, args.decay_rate) + self.stopper = EarlyStopping(args.patience, self._checkpoint) + + def train(self): + for epoch in range(self.max_epoch): + mrr, out_str = self.train_batch() + if epoch % self.evaluate_interval == 0: + self.logger.info("[Evaluation metric] " + out_str) # out test result + early_stop = self.stopper.loss_step(-mrr, self.model) # less is better + if early_stop: + self.logger.train_info(f'Early Stop!\tEpoch:{epoch:03d}.') + break + + + def train_batch(self): + epoch_loss = 0 + + batch_size = self.batch_size + n_batch = self.n_train // batch_size + (self.n_train % batch_size > 0) + + self.model.train() + for i in range(n_batch): + start = i * batch_size + end = min(self.n_train, (i + 1) * batch_size) + batch_idx = np.arange(start, end) + triple = self.loader.get_batch(batch_idx) + + self.model.zero_grad() + scores = self.model(triple[:, 0], triple[:, 1]) + pos_scores = scores[[th.arange(len(scores)).to(self.device), th.LongTensor(triple[:, 2]).to(self.device)]] + max_n = th.max(scores, 1, keepdim=True)[0] + loss = th.sum(- pos_scores + max_n + th.log(th.sum(th.exp(scores - max_n), 1))) + loss.backward() + self.optimizer.step() + + # avoid NaN + for p in self.model.parameters(): + X = p.data.clone() + flag = X != X + X[flag] = np.random.random() + p.data.copy_(X) + epoch_loss += loss.item() + self.scheduler.step() + + valid_mrr, out_str = self.evaluate() + return valid_mrr, out_str + + + def evaluate(self, ): + batch_size = self.batch_size + + n_data = self.n_valid + n_batch = n_data // batch_size + (n_data % batch_size > 0) + ranking = [] + self.model.eval() + for i in range(n_batch): + start = i * batch_size + end = min(n_data, (i + 1) * batch_size) + batch_idx = np.arange(start, end) + subs, rels, objs = self.loader.get_batch(batch_idx, data='valid') + scores = self.model(subs, rels).data.cpu().numpy() + filters = [] + for i in range(len(subs)): + filt = self.loader.val_filters[(subs[i], rels[i])] + filt_1hot = np.zeros((self.n_ent,)) + filt_1hot[np.array(filt)] = 1 + filters.append(filt_1hot) + + filters = np.array(filters) + ranks = 
cal_ranks(scores, objs, filters) + ranking += ranks + ranking = np.array(ranking) + v_mrr, v_h1, v_h10 = cal_performance(ranking) + + n_data = self.n_test + n_batch = n_data // batch_size + (n_data % batch_size > 0) + ranking = [] + self.model.eval() + for i in range(n_batch): + start = i * batch_size + end = min(n_data, (i + 1) * batch_size) + batch_idx = np.arange(start, end) + subs, rels, objs = self.loader.get_batch(batch_idx, data='test') + scores = self.model(subs, rels, 'inductive').data.cpu().numpy() + filters = [] + for i in range(len(subs)): + filt = self.loader.tst_filters[(subs[i], rels[i])] + filt_1hot = np.zeros((self.n_ent_ind,)) + filt_1hot[np.array(filt)] = 1 + filters.append(filt_1hot) + + filters = np.array(filters) + ranks = cal_ranks(scores, objs, filters) + ranking += ranks + ranking = np.array(ranking) + t_mrr, t_h1, t_h10 = cal_performance(ranking) + + out_str = '[VALID] MRR:%.4f H@1:%.4f H@10:%.4f\t [TEST] MRR:%.4f H@1:%.4f H@10:%.4f' % ( + v_mrr, v_h1, v_h10, t_mrr, t_h1, t_h10) + return v_mrr, out_str diff --git a/openhgnn/trainerflow/SACN_trainer.py b/openhgnn/trainerflow/SACN_trainer.py new file mode 100644 index 00000000..a194ee55 --- /dev/null +++ b/openhgnn/trainerflow/SACN_trainer.py @@ -0,0 +1,251 @@ +import os +import random +import logging +import argparse +import math +import dgl +import numpy as np +import time +from numpy.random.mtrand import set_state +import pandas +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import xavier_normal_ +from dgl import function as fn +from torch.utils.data import DataLoader +from ..utils.wgcn_data import load_data + +from ..utils.wgcn_evaluation_dgl import ranking_and_hits +from ..utils.wgcn_utils import EarlyStopping +from ..utils.wgcn_batch_prepare import TrainBatchPrepare, EvalBatchPrepare +from . import BaseFlow, register_flow +from ..models import build_model +from . 
import BaseFlow, register_flow +from ..tasks import build_task + + + + + + +@register_flow("SACN_trainer") +class SACNTrainer(BaseFlow): + def __init__(self, args): + self.args = args + self.args.dataset_name=args.dataset + self.model_name=args.model + args.model_name=args.model + os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) + seed = args.seed + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + self.task = build_task(args) + # self.model = build_model(self.model_name).build_model_from_args( + # self.args).model + print(args) + main(args) + def train(self): + print("111") + +def process_triplets(triplets, all_dict, num_rels): + """ + process triples, store the id of all entities corresponding to (head, rel) + and (tail, rel_reverse) into dict + """ + data_dict = {} + for i in range(triplets.shape[0]): + e1, rel, e2 = triplets[i] + rel_reverse = rel + num_rels + + if (e1, rel) not in data_dict: + data_dict[(e1, rel)] = set() + if (e2, rel_reverse) not in data_dict: + data_dict[(e2, rel_reverse)] = set() + + if (e1, rel) not in all_dict: + all_dict[(e1, rel)] = set() + if (e2, rel_reverse) not in all_dict: + all_dict[(e2, rel_reverse)] = set() + + all_dict[(e1, rel)].add(e2) + all_dict[(e2, rel_reverse)].add(e1) + + data_dict[(e1, rel)].add(e2) + data_dict[(e2, rel_reverse)].add(e1) + + return data_dict + + +def preprocess_data(train_data, valid_data, test_data, num_rels): + all_dict = {} + + train_dict = process_triplets(train_data, all_dict, num_rels) + valid_dict = process_triplets(valid_data, all_dict, num_rels) + test_dict = process_triplets(test_data, all_dict, num_rels) + + return train_dict, valid_dict, test_dict, all_dict + + +def main(args): + os.makedirs('./logs', exist_ok=True) + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(os.path.join( + 'logs', args.decoder+"_"+args.name)), + logging.StreamHandler() + ]) + logger = logging.getLogger(__name__) + # load graph data + data = load_data(args.dataset_data) + num_nodes = data.num_nodes + train_data = data.train + valid_data = data.valid + test_data = data.test + num_rels = data.num_rels + + save_path = 'checkpoints/' + os.makedirs(save_path, exist_ok=True) + stopper = EarlyStopping( + save_path=save_path, model_name=args.decoder+"_"+args.name, patience=args.patience) + + # check cuda + if args.gpu >= 0: + device = torch.device('cuda') + else: + device = torch.device('cpu') + args.num_entities=num_nodes + args.num_relations=num_rels * 2 + 1 + # create model + model = build_model(args.model_name).build_model_from_args( + args).model + # model = WGCN(num_entities=num_nodes, + # num_relations=num_rels * 2 + 1, args=args) + + # build graph + g = dgl.graph([]) + g.add_nodes(num_nodes) + src, rel, dst = train_data.transpose() + # add reverse edges, reverse relation id is between [num_rels, 2*num_rels) + src, dst = np.concatenate((src, dst)), np.concatenate((dst, src)) + rel = np.concatenate((rel, rel + num_rels)) + # get new train_data with reverse relation + train_data_new = np.stack((src, rel, dst)).transpose() + + # unique train data by (h,r) + train_data_new_pandas = pandas.DataFrame(train_data_new) + train_data_new_pandas = train_data_new_pandas.drop_duplicates([0, 1]) + train_data_unique = np.asarray(train_data_new_pandas) + + if not args.wni: + if args.rat: + if args.ss > 0: + high = args.ss + else: + high = num_nodes + g.add_edges(src, np.random.randint( + low=0, high=high, 
size=dst.shape)) + else: + g.add_edges(src, dst) + + # add graph self loop + if not args.wsi: + g.add_edges(g.nodes(), g.nodes()) + # add self loop relation type, self loop relation's id is 2*num_rels. + if args.wni: + rel = np.ones([num_nodes]) * num_rels * 2 + else: + rel = np.concatenate((rel, np.ones([num_nodes]) * num_rels * 2)) + print(g) + entity_id = torch.LongTensor([i for i in range(num_nodes)]) + + model = model.to(device) + g = g.to(device) + all_rel = torch.LongTensor(rel).to(device) + entity_id = entity_id.to(device) + + # optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + + # process the triples and get all tails corresponding to (h,r) + # here valid_dict and test_dict are not used. + train_dict, valid_dict, test_dict, all_dict = preprocess_data( + train_data, valid_data, test_data, num_rels) + + train_batch_prepare = TrainBatchPrepare(train_dict, num_nodes) + + # eval needs to use all the data in train_data, valid_data and test_data + eval_batch_prepare = EvalBatchPrepare(all_dict, num_rels) + + train_dataloader = DataLoader( + dataset=train_data_unique, + batch_size=args.batch_size, + collate_fn=train_batch_prepare.get_batch, + shuffle=True, + drop_last=False, + num_workers=args.num_workers) + + valid_dataloader = DataLoader( + dataset=valid_data, + batch_size=args.batch_size, + collate_fn=eval_batch_prepare.get_batch, + shuffle=False, + drop_last=False, + num_workers=args.num_workers) + + test_dataloader = DataLoader( + dataset=test_data, + batch_size=args.batch_size, + collate_fn=eval_batch_prepare.get_batch, + shuffle=False, + drop_last=False, + num_workers=args.num_workers) + + # training loop + print("start training...") + for epoch in range(args.n_epochs): + model.train() + epoch_start_time = time.time() + for step, batch_tuple in enumerate(train_dataloader): + e1_batch, rel_batch, labels_one_hot = batch_tuple + e1_batch = e1_batch.to(device) + rel_batch = rel_batch.to(device) + labels_one_hot = labels_one_hot.to(device) + labels_one_hot = ((1.0 - 0.1) * labels_one_hot) + \ + (1.0 / labels_one_hot.size(1)) + + pred = model.forward(g, all_rel, e1_batch, rel_batch, entity_id) + optimizer.zero_grad() + loss = model.loss(pred, labels_one_hot) + loss.backward() + optimizer.step() + + logger.info("epoch : {}".format(epoch)) + logger.info("epoch time: {:.4f}".format( + time.time() - epoch_start_time)) + logger.info("loss: {}".format(loss.data)) + + model.eval() + if epoch % args.eval_every == 0: + with torch.no_grad(): + val_mrr = ranking_and_hits( + g, all_rel, model, valid_dataloader, 'dev_evaluation', entity_id, device, logger) + if stopper.step(val_mrr, model): + break + + print("training done") + model.load_state_dict(torch.load(os.path.join( + save_path, args.decoder+"_"+args.name+'.pt'))) + ranking_and_hits(g, all_rel, model, test_dataloader, + 'test_evaluation', entity_id, device, logger) + + + + + + + diff --git a/openhgnn/trainerflow/__init__.py b/openhgnn/trainerflow/__init__.py index f3a117da..a811460a 100644 --- a/openhgnn/trainerflow/__init__.py +++ b/openhgnn/trainerflow/__init__.py @@ -80,6 +80,18 @@ def build_flow(args, flow_name): 'DSSL_trainer': 'openhgnn.trainerflow.DSSL_trainer', 'hgcltrainer': 'openhgnn.trainerflow.hgcl_trainer', 'lightGCN_trainer': 'openhgnn.trainerflow.lightGCN_trainer', + 'Grail_trainer': 'openhgnn.trainerflow.Grail_trainer', + 'ComPILE_trainer': 'openhgnn.trainerflow.ComPILE_trainer', + 'AdapropT_trainer': 'openhgnn.trainerflow.AdapropT_trainer', + 'AdapropI_trainer':'openhgnn.trainerflow.AdapropI_trainer', 
+ 'LTE_trainer': 'openhgnn.trainerflow.LTE_trainer', + 'SACN_trainer': 'openhgnn.trainerflow.SACN_trainer', + 'ExpressGNN_trainer': 'openhgnn.trainerflow.ExpressGNN_trainer', + 'NBF_trainer':'openhgnn.trainerflow.NBF_trainer', + 'Ingram_Trainer' : 'openhgnn.trainerflow.Ingram_trainer', + 'DisenKGAT_trainer':'openhgnn.trainerflow.DisenKGAT_trainer', + 'RedGNN_trainer': 'openhgnn.trainerflow.RedGNN_trainer', + 'RedGNNT_trainer': 'openhgnn.trainerflow.RedGNNT_trainer', } from .hgcl_trainer import HGCLtrainer @@ -108,6 +120,17 @@ def build_flow(args, flow_name): from .node_classification_ac import NodeClassificationAC from .DSSL_trainer import DSSL_trainer from .lightGCN_trainer import lightGCNTrainer +from .Grail_trainer import GrailTrainer +from .ComPILE_trainer import ComPILETrainer +from .AdapropT_trainer import AdapropTTrainer +from .AdapropI_trainer import AdapropITrainer +from .LTE_trainer import LTETrainer +from .SACN_trainer import SACNTrainer +from .ExpressGNN_trainer import ExpressGNNTrainer +from .NBF_trainer import * +from .Ingram_trainer import Ingram_Trainer +from .DisenKGAT_trainer import * +from .RedGNNT_trainer import RedGNNTTrainer __all__ = [ 'BaseFlow', @@ -135,5 +158,13 @@ def build_flow(args, flow_name): 'DSSL_trainer', 'HGCLtrainer', 'lightGCNTrainer', + 'GrailTrainer', + 'ComPILETrainer', + 'AdapropTTrainer', + 'AdapropITrainer', + 'LTETrainer', + 'SACNTrainer', + 'ExpressGNNTrainer', + 'Ingram_trainer', ] classes = __all__ diff --git a/openhgnn/trainerflow/base_flow.py b/openhgnn/trainerflow/base_flow.py index de1cee28..ab981962 100644 --- a/openhgnn/trainerflow/base_flow.py +++ b/openhgnn/trainerflow/base_flow.py @@ -63,7 +63,7 @@ def __init__(self, args): self.max_epoch = args.max_epoch self.optimizer = None - if self.model_name == "MeiREC": + if self.model_name in ["MeiREC", "ExpressGNN", "Ingram", "RedGNN", "RedGNNT", "AdapropI", "AdapropT", "Grail", "ComPILE", "DisenKGAT"]: return if self.args.use_uva: diff --git a/openhgnn/utils/AdapropI_utils.py b/openhgnn/utils/AdapropI_utils.py new file mode 100644 index 00000000..e0957033 --- /dev/null +++ b/openhgnn/utils/AdapropI_utils.py @@ -0,0 +1,87 @@ +import random +import numpy as np +from scipy.stats import rankdata +import subprocess +import logging +import math + + +def cal_ranks(scores, labels, filters): + scores = scores - np.min(scores, axis=1, keepdims=True) + full_rank = rankdata(-scores, method='ordinal', axis=1) + filter_scores = scores * filters + filter_rank = rankdata(-filter_scores, method='ordinal', axis=1) + ranks = (full_rank - filter_rank + 1) * labels + ranks = ranks[np.nonzero(ranks)] + return list(ranks) + + +def cal_performance(ranks, masks): + mrr = (1. 
/ ranks).sum() / len(ranks) + m_r = sum(ranks) * 1.0 / len(ranks) + h_1 = sum(ranks <= 1) * 1.0 / len(ranks) + h_3 = sum(ranks <= 3) * 1.0 / len(ranks) + h_10 = sum(ranks <= 10) * 1.0 / len(ranks) + h_10_50 = [] + for i, rank in enumerate(ranks): + num_sample = 50 + threshold = 10 + score = 0 + fp_rate = (rank - 1) / masks[i] + for i in range(threshold): + num_comb = math.factorial(num_sample) / math.factorial(i) / math.factorial(num_sample - i) + score += num_comb * (fp_rate ** i) * ((1 - fp_rate) ** (num_sample - i)) + h_10_50.append(score) + h_10_50 = np.mean(h_10_50) + + return mrr, m_r, h_1, h_3, h_10, h_10_50 + + +def select_gpu(): + nvidia_info = subprocess.run('nvidia-smi', stdout=subprocess.PIPE) + gpu_info = False + gpu_info_line = 0 + proc_info = False + gpu_mem = [] + gpu_occupied = set() + i = 0 + for line in nvidia_info.stdout.split(b'\n'): + line = line.decode().strip() + + if gpu_info: + gpu_info_line += 1 + if line == '': + gpu_info = False + continue + if not ('RTX' in line or 'GTX' in line or ''): + try: + mem_info = line.split('|')[2] + used_mem_mb = int(mem_info.strip().split()[0][:-3]) + gpu_mem.append(used_mem_mb) + except: + continue + if proc_info: + if line == '| No running processes found |': + continue + if line == '+-----------------------------------------------------------------------------+': + proc_info = False + continue + proc_gpu = int(line.split()[1]) + gpu_occupied.add(proc_gpu) + + if line == '|===============================+======================+======================|': + gpu_info = True + if line == '|=============================================================================|': + proc_info = True + i += 1 + + for i in range(0, len(gpu_mem)): + if i not in gpu_occupied: + # print('Automatically selected GPU Np.{} because it is vacant.'.format(i)) + return i + + for i in range(0, len(gpu_mem)): + if gpu_mem[i] == min(gpu_mem): + # print('All GPUs are occupied. Automatically selected GPU No.{} because it has the most free memory.'.format(i)) + return i + diff --git a/openhgnn/utils/Adaprop_utils.py b/openhgnn/utils/Adaprop_utils.py new file mode 100644 index 00000000..ea91e81d --- /dev/null +++ b/openhgnn/utils/Adaprop_utils.py @@ -0,0 +1,28 @@ +import numpy as np +from scipy.stats import rankdata +import os + +def checkPath(path): + if not os.path.exists(path): + os.mkdir(path) + return + +def cal_ranks(scores, labels, filters): + scores = scores - np.min(scores, axis=1, keepdims=True) + 1e-8 + full_rank = rankdata(-scores, method='average', axis=1) + filter_scores = scores * filters + filter_rank = rankdata(-filter_scores, method='min', axis=1) + ranks = (full_rank - filter_rank + 1) * labels + ranks = ranks[np.nonzero(ranks)] + return list(ranks) + +def cal_performance(ranks): + mrr = (1. 
/ ranks).sum() / len(ranks) + h_1 = sum(ranks<=1) * 1.0 / len(ranks) + h_10 = sum(ranks<=10) * 1.0 / len(ranks) + return mrr, h_1, h_10 + +def uniqueWithoutSort(a): + indexes = np.unique(a, return_index=True)[1] + res = [a[index] for index in sorted(indexes)] + return res diff --git a/openhgnn/utils/Grail_utils.py b/openhgnn/utils/Grail_utils.py new file mode 100644 index 00000000..c9d7918e --- /dev/null +++ b/openhgnn/utils/Grail_utils.py @@ -0,0 +1,523 @@ +import os +import pdb +import numpy as np +from scipy.sparse import csc_matrix +import struct +import logging +import random +import pickle as pkl +import pdb +from tqdm import tqdm +import lmdb +import multiprocessing as mp +import numpy as np +import scipy.io as sio +import scipy.sparse as ssp +import sys +import torch +from scipy.special import softmax +import json +import pickle +import networkx as nx +import dgl + + + +def process_files(files, saved_relation2id=None): + ''' + files: Dictionary map of file paths to read the triplets from. + saved_relation2id: Saved relation2id (mostly passed from a trained model) which can be used to map relations to pre-defined indices and filter out the unknown ones. + ''' + entity2id = {} + relation2id = {} if saved_relation2id is None else saved_relation2id + + triplets = {} + + ent = 0 + rel = 0 + + for file_type, file_path in files.items(): + + data = [] + with open(file_path) as f: + file_data = [line.split() for line in f.read().split('\n')[:-1]] + + for triplet in file_data: + if triplet[0] not in entity2id: + entity2id[triplet[0]] = ent + ent += 1 + if triplet[2] not in entity2id: + entity2id[triplet[2]] = ent + ent += 1 + if not saved_relation2id and triplet[1] not in relation2id: + relation2id[triplet[1]] = rel + rel += 1 + + # Save the triplets corresponding to only the known relations + if triplet[1] in relation2id: + data.append([entity2id[triplet[0]], entity2id[triplet[2]], relation2id[triplet[1]]]) + + triplets[file_type] = np.array(data) + + id2entity = {v: k for k, v in entity2id.items()} + id2relation = {v: k for k, v in relation2id.items()} + + # Construct the list of adjacency matrix each corresponding to eeach relation. Note that this is constructed only from the train data. 
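For concreteness, a toy version (not part of the patch) of the per-relation adjacency list built below: one scipy CSC matrix per relation over the entity set, indexed as adj_list[rel][head, tail] and filled from training triples only.

import numpy as np
from scipy.sparse import csc_matrix

train = np.array([[0, 1, 0], [1, 2, 0], [0, 2, 1]])  # (head, tail, relation) ids
num_ent, num_rel = 3, 2
adj_list = []
for r in range(num_rel):
    idx = np.argwhere(train[:, 2] == r)
    adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8),
                                (train[:, 0][idx].squeeze(1), train[:, 1][idx].squeeze(1))),
                               shape=(num_ent, num_ent)))
print(adj_list[0].toarray())  # relation 0 links 0 -> 1 and 1 -> 2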
+ adj_list = [] + for i in range(len(relation2id)): + idx = np.argwhere(triplets['train'][:, 2] == i) + adj_list.append(csc_matrix((np.ones(len(idx), dtype=np.uint8), (triplets['train'][:, 0][idx].squeeze(1), triplets['train'][:, 1][idx].squeeze(1))), shape=(len(entity2id), len(entity2id)))) + + return adj_list, triplets, entity2id, relation2id, id2entity, id2relation + + +def save_to_file(directory, file_name, triplets, id2entity, id2relation): + file_path = os.path.join(directory, file_name) + with open(file_path, "w") as f: + for s, o, r in triplets: + f.write('\t'.join([id2entity[s], id2relation[r], id2entity[o]]) + '\n') + +def generate_subgraph_datasets(params, data_path, splits=['train', 'valid'], saved_relation2id=None, max_label_value=None): + + testing = 'test' in splits + adj_list, triplets, entity2id, relation2id, id2entity, id2relation = process_files(params.file_paths, saved_relation2id) + + # plot_rel_dist(adj_list, os.path.join(params.main_dir, f'data/{params.dataset}/rel_dist.png')) + + #data_path = f'data/{params.dataset}/relation2id.json' + if not os.path.isdir(data_path) and not testing: + with open(data_path, 'w') as f: + json.dump(relation2id, f) + + graphs = {} + + for split_name in splits: + graphs[split_name] = {'triplets': triplets[split_name], 'max_size': params.max_links} + + # Sample train and valid/test links + for split_name, split in graphs.items(): + logging.info(f"Sampling negative links for {split_name}") + split['pos'], split['neg'] = sample_neg(adj_list, split['triplets'], params.num_neg_samples_per_link, max_size=split['max_size'], constrained_neg_prob=params.constrained_neg_prob) + + if testing: + directory = os.path.join(params.main_dir, 'data/{}/'.format(params.dataset)) + save_to_file(directory, f'neg_{params.test_file}_{params.constrained_neg_prob}.txt', graphs['test']['neg'], id2entity, id2relation) + + links2subgraphs(adj_list, graphs, params, max_label_value) + + +def get_kge_embeddings(dataset, kge_model): + + path = './experiments/kge_baselines/{}_{}'.format(kge_model, dataset) + node_features = np.load(os.path.join(path, 'entity_embedding.npy')) + with open(os.path.join(path, 'id2entity.json')) as json_file: + kge_id2entity = json.load(json_file) + kge_entity2id = {v: int(k) for k, v in kge_id2entity.items()} + + return node_features, kge_entity2id +def links2subgraphs(A, graphs, params, max_label_value=None): + ''' + extract enclosing subgraphs, write map mode + named dbs + ''' + max_n_label = {'value': np.array([0, 0])} + subgraph_sizes = [] + enc_ratios = [] + num_pruned_nodes = [] + + BYTES_PER_DATUM = get_average_subgraph_size(100, list(graphs.values())[0]['pos'], A, params) * 1.5 + links_length = 0 + for split_name, split in graphs.items(): + links_length += (len(split['pos']) + len(split['neg'])) * 2 + map_size = links_length * BYTES_PER_DATUM + map_size = int(map_size)+1 + env = lmdb.open(params.db_path, map_size=map_size, max_dbs=6) + + def extraction_helper(A, links, g_labels, split_env): + + with env.begin(write=True, db=split_env) as txn: + txn.put('num_graphs'.encode(), (len(links)).to_bytes(int.bit_length(len(links)), byteorder='little')) + + with mp.Pool(processes=None, initializer=intialize_worker, initargs=(A, params, max_label_value)) as p: + args_ = zip(range(len(links)), links, g_labels) + for (str_id, datum) in tqdm(p.imap(extract_save_subgraph, args_), total=len(links)): + max_n_label['value'] = np.maximum(np.max(datum['n_labels'], axis=0), max_n_label['value']) + subgraph_sizes.append(datum['subgraph_size']) + 
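# subgraph_size, enc_ratio and num_pruned_nodes are per-link statistics; their
# mean/min/max/std are written to LMDB further below as dataset-level metadata.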
enc_ratios.append(datum['enc_ratio']) + num_pruned_nodes.append(datum['num_pruned_nodes']) + + with env.begin(write=True, db=split_env) as txn: + txn.put(str_id, serialize(datum)) + + for split_name, split in graphs.items(): + logging.info(f"Extracting enclosing subgraphs for positive links in {split_name} set") + labels = np.ones(len(split['pos'])) + db_name_pos = split_name + '_pos' + print(db_name_pos) + split_env = env.open_db(db_name_pos.encode()) + extraction_helper(A, split['pos'], labels, split_env) + + logging.info(f"Extracting enclosing subgraphs for negative links in {split_name} set") + labels = np.zeros(len(split['neg'])) + db_name_neg = split_name + '_neg' + print(db_name_neg) + split_env = env.open_db(db_name_neg.encode()) + extraction_helper(A, split['neg'], labels, split_env) + + max_n_label['value'] = max_label_value if max_label_value is not None else max_n_label['value'] + + with env.begin(write=True) as txn: + bit_len_label_sub = int.bit_length(int(max_n_label['value'][0])) + bit_len_label_obj = int.bit_length(int(max_n_label['value'][1])) + txn.put('max_n_label_sub'.encode(), (int(max_n_label['value'][0])).to_bytes(bit_len_label_sub, byteorder='little')) + txn.put('max_n_label_obj'.encode(), (int(max_n_label['value'][1])).to_bytes(bit_len_label_obj, byteorder='little')) + + txn.put('avg_subgraph_size'.encode(), struct.pack('f', float(np.mean(subgraph_sizes)))) + txn.put('min_subgraph_size'.encode(), struct.pack('f', float(np.min(subgraph_sizes)))) + txn.put('max_subgraph_size'.encode(), struct.pack('f', float(np.max(subgraph_sizes)))) + txn.put('std_subgraph_size'.encode(), struct.pack('f', float(np.std(subgraph_sizes)))) + + txn.put('avg_enc_ratio'.encode(), struct.pack('f', float(np.mean(enc_ratios)))) + txn.put('min_enc_ratio'.encode(), struct.pack('f', float(np.min(enc_ratios)))) + txn.put('max_enc_ratio'.encode(), struct.pack('f', float(np.max(enc_ratios)))) + txn.put('std_enc_ratio'.encode(), struct.pack('f', float(np.std(enc_ratios)))) + + txn.put('avg_num_pruned_nodes'.encode(), struct.pack('f', float(np.mean(num_pruned_nodes)))) + txn.put('min_num_pruned_nodes'.encode(), struct.pack('f', float(np.min(num_pruned_nodes)))) + txn.put('max_num_pruned_nodes'.encode(), struct.pack('f', float(np.max(num_pruned_nodes)))) + txn.put('std_num_pruned_nodes'.encode(), struct.pack('f', float(np.std(num_pruned_nodes)))) + + +def get_average_subgraph_size(sample_size, links, A, params): + total_size = 0 + for (n1, n2, r_label) in links[np.random.choice(len(links), sample_size)]: + nodes, n_labels, subgraph_size, enc_ratio, num_pruned_nodes = subgraph_extraction_labeling((n1, n2), r_label, A, params.hop, params.enclosing_sub_graph, params.max_nodes_per_hop) + datum = {'nodes': nodes, 'r_label': r_label, 'g_label': 0, 'n_labels': n_labels, 'subgraph_size': subgraph_size, 'enc_ratio': enc_ratio, 'num_pruned_nodes': num_pruned_nodes} + total_size += len(serialize(datum)) + return total_size / sample_size + + +def intialize_worker(A, params, max_label_value): + global A_, params_, max_label_value_ + A_, params_, max_label_value_ = A, params, max_label_value + + +def extract_save_subgraph(args_): + idx, (n1, n2, r_label), g_label = args_ + nodes, n_labels, subgraph_size, enc_ratio, num_pruned_nodes = subgraph_extraction_labeling((n1, n2), r_label, A_, params_.hop, params_.enclosing_sub_graph, params_.max_nodes_per_hop) + + # max_label_value_ is to set the maximum possible value of node label while doing double-radius labelling. 
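For reference, a simplified sketch (not part of the patch) of the double-radius labels computed by node_label below: every subgraph node is tagged with its distance to each of the two roots, and the roots themselves get the fixed labels (0, 1) and (1, 0). The real implementation additionally drops the opposite root before measuring distances and prunes nodes whose label exceeds the hop limit.

import numpy as np
import scipy.sparse as ssp

# path graph 0 - 1 - 2 - 3, with roots 0 and 3
adj = ssp.csr_matrix(np.array([[0, 1, 0, 0],
                               [1, 0, 1, 0],
                               [0, 1, 0, 1],
                               [0, 0, 1, 0]]))
dist = ssp.csgraph.dijkstra(adj, indices=[0, 3], directed=False, unweighted=True)
labels = np.stack([dist[0], dist[1]], axis=1).astype(int)
labels[0], labels[3] = [0, 1], [1, 0]  # fixed labels for the two roots
print(labels)  # node 1 -> (1, 2), node 2 -> (2, 1)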
+ if max_label_value_ is not None: + n_labels = np.array([np.minimum(label, max_label_value_).tolist() for label in n_labels]) + + datum = {'nodes': nodes, 'r_label': r_label, 'g_label': g_label, 'n_labels': n_labels, 'subgraph_size': subgraph_size, 'enc_ratio': enc_ratio, 'num_pruned_nodes': num_pruned_nodes} + str_id = '{:08}'.format(idx).encode('ascii') + + return (str_id, datum) + + +def get_neighbor_nodes(roots, adj, h=1, max_nodes_per_hop=None): + bfs_generator = _bfs_relational(adj, roots, max_nodes_per_hop) + lvls = list() + for _ in range(h): + try: + lvls.append(next(bfs_generator)) + except StopIteration: + pass + return set().union(*lvls) + + +def subgraph_extraction_labeling(ind, rel, A_list, h=1, enclosing_sub_graph=False, max_nodes_per_hop=None, max_node_label_value=None): + # extract the h-hop enclosing subgraphs around link 'ind' + A_incidence = incidence_matrix(A_list) + A_incidence += A_incidence.T + + root1_nei = get_neighbor_nodes(set([ind[0]]), A_incidence, h, max_nodes_per_hop) + root2_nei = get_neighbor_nodes(set([ind[1]]), A_incidence, h, max_nodes_per_hop) + + subgraph_nei_nodes_int = root1_nei.intersection(root2_nei) + subgraph_nei_nodes_un = root1_nei.union(root2_nei) + + # Extract subgraph | Roots being in the front is essential for labelling and the model to work properly. + if enclosing_sub_graph: + subgraph_nodes = list(ind) + list(subgraph_nei_nodes_int) + else: + subgraph_nodes = list(ind) + list(subgraph_nei_nodes_un) + + subgraph = [adj[subgraph_nodes, :][:, subgraph_nodes] for adj in A_list] + + labels, enclosing_subgraph_nodes = node_label(incidence_matrix(subgraph), max_distance=h) + + pruned_subgraph_nodes = np.array(subgraph_nodes)[enclosing_subgraph_nodes].tolist() + pruned_labels = labels[enclosing_subgraph_nodes] + # pruned_subgraph_nodes = subgraph_nodes + # pruned_labels = labels + + if max_node_label_value is not None: + pruned_labels = np.array([np.minimum(label, max_node_label_value).tolist() for label in pruned_labels]) + + subgraph_size = len(pruned_subgraph_nodes) + enc_ratio = len(subgraph_nei_nodes_int) / (len(subgraph_nei_nodes_un) + 1e-3) + num_pruned_nodes = len(subgraph_nodes) - len(pruned_subgraph_nodes) + + return pruned_subgraph_nodes, pruned_labels, subgraph_size, enc_ratio, num_pruned_nodes + + +def node_label(subgraph, max_distance=1): + # implementation of the node labeling scheme described in the paper + roots = [0, 1] + sgs_single_root = [remove_nodes(subgraph, [root]) for root in roots] + dist_to_roots = [np.clip(ssp.csgraph.dijkstra(sg, indices=[0], directed=False, unweighted=True, limit=1e6)[:, 1:], 0, 1e7) for r, sg in enumerate(sgs_single_root)] + dist_to_roots = np.array(list(zip(dist_to_roots[0][0], dist_to_roots[1][0])), dtype=int) + + target_node_labels = np.array([[0, 1], [1, 0]]) + labels = np.concatenate((target_node_labels, dist_to_roots)) if dist_to_roots.size else target_node_labels + + enclosing_subgraph_nodes = np.where(np.max(labels, axis=1) <= max_distance)[0] + return labels, enclosing_subgraph_nodes + +def sample_neg(adj_list, edges, num_neg_samples_per_link=1, max_size=1000000, constrained_neg_prob=0): + pos_edges = edges + neg_edges = [] + + # if max_size is set, randomly sample train links + if max_size < len(pos_edges): + perm = np.random.permutation(len(pos_edges))[:max_size] + pos_edges = pos_edges[perm] + + # sample negative links for train/test + n, r = adj_list[0].shape[0], len(adj_list) + + # distribution of edges across reelations + theta = 0.001 + edge_count = get_edge_count(adj_list) + 
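# The next three lines turn the per-relation triple counts into a softmax
# distribution over relations (rel_dist); it is not consumed by the sampling
# loop below, which corrupts heads/tails uniformly or from the per-relation
# head/tail candidate lists.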
rel_dist = np.zeros(edge_count.shape) + idx = np.nonzero(edge_count) + rel_dist[idx] = softmax(theta * edge_count[idx]) + + # possible head and tails for each relation + valid_heads = [adj.tocoo().row.tolist() for adj in adj_list] + valid_tails = [adj.tocoo().col.tolist() for adj in adj_list] + + pbar = tqdm(total=len(pos_edges)) + while len(neg_edges) < num_neg_samples_per_link * len(pos_edges): + neg_head, neg_tail, rel = pos_edges[pbar.n % len(pos_edges)][0], pos_edges[pbar.n % len(pos_edges)][1], pos_edges[pbar.n % len(pos_edges)][2] + if np.random.uniform() < constrained_neg_prob: + if np.random.uniform() < 0.5: + neg_head = np.random.choice(valid_heads[rel]) + else: + neg_tail = np.random.choice(valid_tails[rel]) + else: + if np.random.uniform() < 0.5: + neg_head = np.random.choice(n) + else: + neg_tail = np.random.choice(n) + + if neg_head != neg_tail and adj_list[rel][neg_head, neg_tail] == 0: + neg_edges.append([neg_head, neg_tail, rel]) + pbar.update(1) + + pbar.close() + + neg_edges = np.array(neg_edges) + return pos_edges, neg_edges + +def serialize(data): + data_tuple = tuple(data.values()) + return pickle.dumps(data_tuple) + + +def deserialize(data): + data_tuple = pickle.loads(data) + keys = ('nodes', 'r_label', 'g_label', 'n_label') + return dict(zip(keys, data_tuple)) + + +def get_edge_count(adj_list): + count = [] + for adj in adj_list: + count.append(len(adj.tocoo().row.tolist())) + return np.array(count) + + +def incidence_matrix(adj_list): + ''' + adj_list: List of sparse adjacency matrices + ''' + + rows, cols, dats = [], [], [] + dim = adj_list[0].shape + for adj in adj_list: + adjcoo = adj.tocoo() + rows += adjcoo.row.tolist() + cols += adjcoo.col.tolist() + dats += adjcoo.data.tolist() + row = np.array(rows) + col = np.array(cols) + data = np.array(dats) + return ssp.csc_matrix((data, (row, col)), shape=dim) + + +def remove_nodes(A_incidence, nodes): + idxs_wo_nodes = list(set(range(A_incidence.shape[1])) - set(nodes)) + return A_incidence[idxs_wo_nodes, :][:, idxs_wo_nodes] + + +def ssp_to_torch(A, device, dense=False): + ''' + A : Sparse adjacency matrix + ''' + idx = torch.LongTensor([A.tocoo().row, A.tocoo().col]) + dat = torch.FloatTensor(A.tocoo().data) + A = torch.sparse.FloatTensor(idx, dat, torch.Size([A.shape[0], A.shape[1]])).to(device=device) + return A + + +def ssp_multigraph_to_dgl(graph, n_feats=None): + """ + Converting ssp multigraph (i.e. list of adjs) to dgl multigraph. 
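Each relation's adjacency matrix contributes its (row, col) pairs as directed edges, the relation index is stored in the edge attribute 'type', and optional node features are attached as g.ndata['feat'].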
+ """ + + g_nx = nx.MultiDiGraph() + g_nx.add_nodes_from(list(range(graph[0].shape[0]))) + # Add edges + for rel, adj in enumerate(graph): + # Convert adjacency matrix to tuples for nx0 + nx_triplets = [] + for src, dst in list(zip(adj.tocoo().row, adj.tocoo().col)): + nx_triplets.append((src, dst, {'type': rel})) + g_nx.add_edges_from(nx_triplets) + + # make dgl graph + #g_dgl = dgl.DGLGraph(multigraph=True) + #g_dgl.from_networkx(g_nx, edge_attrs=['type']) + g_dgl = dgl.from_networkx(g_nx,edge_attrs=['type']) + # add node features + if n_feats is not None: + g_dgl.ndata['feat'] = torch.tensor(n_feats) + + return g_dgl + + +def collate_dgl(samples): + # The input `samples` is a list of pairs + graphs_pos, g_labels_pos, r_labels_pos, graphs_negs, g_labels_negs, r_labels_negs = map(list, zip(*samples)) + batched_graph_pos = dgl.batch(graphs_pos) + + graphs_neg = [item for sublist in graphs_negs for item in sublist] + g_labels_neg = [item for sublist in g_labels_negs for item in sublist] + r_labels_neg = [item for sublist in r_labels_negs for item in sublist] + + batched_graph_neg = dgl.batch(graphs_neg) + return (batched_graph_pos, r_labels_pos), g_labels_pos, (batched_graph_neg, r_labels_neg), g_labels_neg + + +def collate_dgl2(samples): + # The input `samples` is a list of pairs + graphs_pos, g_labels_pos, r_labels_pos, graphs_negs, g_labels_negs, r_labels_negs = map(list, zip(*samples)) + + # graphs_pos = [item for sublist in graphs_pos for item in sublist] + # g_labels_pos = [item for sublist in g_labels_pos for item in sublist] + # r_labels_pos = [item for sublist in r_labels_pos for item in sublist] + + # batched_graph_pos = dgl.batch(graphs_pos) + + graphs_neg = [item for sublist in graphs_negs for item in sublist] + g_labels_neg = [item for sublist in g_labels_negs for item in sublist] + r_labels_neg = [item for sublist in r_labels_negs for item in sublist] + # print('neg for each pos ', len(graphs_neg)) + # batched_graph_neg = dgl.batch(graphs_neg) + return (graphs_pos, r_labels_pos), g_labels_pos, (graphs_neg, r_labels_neg), g_labels_neg + +def move_batch_to_device_dgl(batch, device): + ((g_dgl_pos, r_labels_pos), targets_pos, (g_dgl_neg, r_labels_neg), targets_neg) = batch + + targets_pos = torch.LongTensor(targets_pos).to(device=device) + r_labels_pos = torch.LongTensor(r_labels_pos).to(device=device) + + targets_neg = torch.LongTensor(targets_neg).to(device=device) + r_labels_neg = torch.LongTensor(r_labels_neg).to(device=device) + + g_dgl_pos = send_graph_to_device(g_dgl_pos, device) + g_dgl_neg = send_graph_to_device(g_dgl_neg, device) + + return ((g_dgl_pos, r_labels_pos), targets_pos, (g_dgl_neg, r_labels_neg), targets_neg) + + +def send_graph_to_device(g, device): + # nodes + labels = g.node_attr_schemes() + g = g.to(device) + for l in labels.keys(): + g.ndata[l] = g.ndata.pop(l).to(device) + + # edges + labels = g.edge_attr_schemes() + for l in labels.keys(): + g.edata[l] = g.edata.pop(l).to(device) + return g + +# The following three functions are modified from networks source codes to +# accomodate diameter and radius for dirercted graphs + + +def eccentricity(G): + e = {} + for n in G.nbunch_iter(): + length = nx.single_source_shortest_path_length(G, n) + e[n] = max(length.values()) + return e + + +def radius(G): + e = eccentricity(G) + e = np.where(np.array(list(e.values())) > 0, list(e.values()), np.inf) + return min(e) + + +def diameter(G): + e = eccentricity(G) + return max(e.values()) + + +def _bfs_relational(adj, roots, max_nodes_per_hop=None): + """ + BFS for 
graphs. + Modified from dgl.contrib.data.knowledge_graph to accomodate node sampling + """ + visited = set() + current_lvl = set(roots) + + next_lvl = set() + + while current_lvl: + + for v in current_lvl: + visited.add(v) + + next_lvl = _get_neighbors(adj, current_lvl) + next_lvl -= visited # set difference + + if max_nodes_per_hop and max_nodes_per_hop < len(next_lvl): + next_lvl = set(random.sample(next_lvl, max_nodes_per_hop)) + + yield next_lvl + + current_lvl = set.union(next_lvl) + + +def _get_neighbors(adj, nodes): + """Takes a set of nodes and a graph adjacency matrix and returns a set of neighbors. + Directly copied from dgl.contrib.data.knowledge_graph""" + sp_nodes = _sp_row_vec_from_idx_list(list(nodes), adj.shape[1]) + sp_neighbors = sp_nodes.dot(adj) + neighbors = set(ssp.find(sp_neighbors)[1]) # convert to set of indices + return neighbors + + +def _sp_row_vec_from_idx_list(idx_list, dim): + """Create sparse vector of dimensionality dim from a list of indices.""" + shape = (1, dim) + data = np.ones(len(idx_list)) + row_ind = np.zeros(len(idx_list)) + col_ind = list(idx_list) + return ssp.csr_matrix((data, (row_ind, col_ind)), shape=shape) diff --git a/openhgnn/utils/Ingram_utils.py b/openhgnn/utils/Ingram_utils.py new file mode 100644 index 00000000..05b5a5aa --- /dev/null +++ b/openhgnn/utils/Ingram_utils.py @@ -0,0 +1,85 @@ +import torch +from scipy.sparse import csr_matrix +import numpy as np +import dgl + + +def initialize(target, msg, d_e, d_r, B): + init_emb_ent = torch.zeros((target.num_ent, d_e)).cuda() + init_emb_rel = torch.zeros((2 * target.num_rel, d_r)).cuda() + gain = torch.nn.init.calculate_gain('relu') + torch.nn.init.xavier_normal_(init_emb_ent, gain=gain) + torch.nn.init.xavier_normal_(init_emb_rel, gain=gain) + relation_triplets = generate_relation_triplets(msg, target.num_ent, target.num_rel, B) + relation_triplets = torch.tensor(relation_triplets).cuda() + return init_emb_ent, init_emb_rel, relation_triplets + + +def create_relation_graph(triplet, num_ent, num_rel): + ind_h = triplet[:, :2] + ind_t = triplet[:, 1:] + E_h = csr_matrix((np.ones(len(ind_h)), (ind_h[:, 0], ind_h[:, 1])), shape=(num_ent, 2 * num_rel)) + E_t = csr_matrix((np.ones(len(ind_t)), (ind_t[:, 1], ind_t[:, 0])), shape=(num_ent, 2 * num_rel)) + diag_vals_h = E_h.sum(axis=1).A1 + diag_vals_h[diag_vals_h != 0] = 1 / (diag_vals_h[diag_vals_h != 0] ** 2) + diag_vals_t = E_t.sum(axis=1).A1 + diag_vals_t[diag_vals_t != 0] = 1 / (diag_vals_t[diag_vals_t != 0] ** 2) + D_h_inv = csr_matrix((diag_vals_h, (np.arange(num_ent), np.arange(num_ent))), shape=(num_ent, num_ent)) + D_t_inv = csr_matrix((diag_vals_t, (np.arange(num_ent), np.arange(num_ent))), shape=(num_ent, num_ent)) + A_h = E_h.transpose() @ D_h_inv @ E_h + A_t = E_t.transpose() @ D_t_inv @ E_t + return A_h + A_t + + +def get_rank(triplet, scores, filters, target=0): + thres = scores[triplet[0, target]].item() + scores[filters] = thres - 1 + rank = (scores > thres).sum() + (scores == thres).sum() // 2 + 1 + return rank.item() + + +def get_metrics(rank): + rank = np.array(rank, dtype=np.int) + mr = np.mean(rank) + mrr = np.mean(1 / rank) + hit10 = np.sum(rank < 11) / len(rank) + hit3 = np.sum(rank < 4) / len(rank) + hit1 = np.sum(rank < 2) / len(rank) + return mr, mrr, hit10, hit3, hit1 + + +def generate_neg(triplets, num_ent, num_neg=1): + neg_triplets = triplets.unsqueeze(dim=1).repeat(1, num_neg, 1) + rand_result = torch.rand((len(triplets), num_neg)).cuda() + perturb_head = rand_result < 0.5 + perturb_tail = rand_result >= 0.5 + 
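# The corruption below samples a replacement id from [0, num_ent - 2] and then
# shifts any draw that is >= the original head/tail id up by one, so a sampled
# negative can never collide with the entity it replaces (no resampling needed).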
rand_idxs = torch.randint(low=0, high=num_ent - 1, size=(len(triplets), num_neg)).cuda() + rand_idxs[perturb_head] += rand_idxs[perturb_head] >= neg_triplets[:, :, 0][perturb_head] + rand_idxs[perturb_tail] += rand_idxs[perturb_tail] >= neg_triplets[:, :, 2][perturb_tail] + neg_triplets[:, :, 0][perturb_head] = rand_idxs[perturb_head] + neg_triplets[:, :, 2][perturb_tail] = rand_idxs[perturb_tail] + neg_triplets = torch.cat(torch.split(neg_triplets, 1, dim=1), dim=0).squeeze(dim=1) + return neg_triplets + + +def generate_relation_triplets(triplet, num_ent, num_rel, B): + A = create_relation_graph(triplet, num_ent, num_rel) + A_sparse = csr_matrix(A) + G_rel = dgl.from_scipy(A_sparse) # 这里用dgl.from_scipy()函数创建一个图对象 + G_rel.edata['weight'] = torch.from_numpy(A.data) # 这里用A.data获取稀疏矩阵的非零元素,作为边的权重 + relation_triplets = get_relation_triplets(G_rel, B) + return relation_triplets + + +def get_relation_triplets(G_rel, B): + src, dst = G_rel.edges() # 获取边的源节点和目标节点 + w = G_rel.edata['weight'] # 获取边的权重 + nnz = len(w) # 获取边的数量 + temp = torch.argsort(-w) # 对边的权重进行降序排序,并返回排序后的索引 + weight_ranks = torch.empty_like(temp) # 创建一个空的张量,用于存储权重的排名 + weight_ranks[temp] = torch.arange(nnz) + 1 # 根据排序后的索引,给每个权重赋予一个排名 + rk = torch.floor(weight_ranks / nnz * B) - 1 # 把权重的排名映射到一个区间[0, B-1] + rk = rk.int() # 把映射后的权重转换为整数 + relation_triplets = torch.stack([src, dst, rk], dim=1) # 按列拼接三个张量 + relation_triplets = relation_triplets.numpy() # 把张量转换为numpy数组 + return relation_triplets diff --git a/openhgnn/utils/best_config.py b/openhgnn/utils/best_config.py index 9d44aed8..e22b4afb 100644 --- a/openhgnn/utils/best_config.py +++ b/openhgnn/utils/best_config.py @@ -418,6 +418,59 @@ 'wn18':{ 'lr': 0.5, 'weight_decay': 0.0001, 'ent_dim': 400, 'rel_dim': 400, 'neg_size': 98, 'margin': 200, 'batch_size': 100, 'patience':5, 'valid_percent':1, 'test_percent': 1 } + }, + 'RedGNN': { + 'general': { + + }, + 'WN18RR_v1':{ + 'batch_size':100, 'hidden_dim' : 64, 'lr': 0.005, 'weight_decay': 0.0002, 'decay_rate': 0.991, 'attn_dim':5, + 'act': 'idd', 'n_layer': 5 + }, + 'fb237_v1': { + 'batch_size':20, 'hidden_dim' : 32, 'lr': 0.0092, 'weight_decay': 0.0003, 'decay_rate': 0.994, 'attn_dim':5, + 'act': 'relu', 'n_layer': 3 + }, + 'nell_v1': { + 'batch_size':10, 'hidden_dim' : 48, 'lr': 0.0021, 'weight_decay': 0.000189, 'decay_rate': 0.9937, 'attn_dim':5, + 'act': 'relu', 'n_layer': 5 + }, + 'WN18RR_v2':{ + 'batch_size':20, 'hidden_dim' : 48, 'lr': 0.0016, 'weight_decay': 0.0004, 'decay_rate': 0.994, 'attn_dim':3, + 'act': 'relu', 'n_layer': 5 + }, + 'fb237_v2': { + 'batch_size':10, 'hidden_dim' : 48, 'lr': 0.0077, 'weight_decay': 0.0002, 'decay_rate': 0.993, 'attn_dim':5, + 'act': 'relu', 'n_layer': 3 + }, + 'nell_v2': { + 'batch_size':100, 'hidden_dim' : 48, 'lr': 0.0075, 'weight_decay': 0.000189, 'decay_rate': 0.9996, 'attn_dim':5, + 'act': 'relu', 'n_layer': 3 + }, + 'WN18RR_v3':{ + 'batch_size':20, 'hidden_dim' : 64, 'lr': 0.0014, 'weight_decay': 0.0004, 'decay_rate': 0.994, 'attn_dim':5, + 'act': 'tanh', 'n_layer': 5 + }, + 'fb237_v3': { + 'batch_size':30, 'hidden_dim' : 48, 'lr': 0.0006, 'weight_decay': 0.00023, 'decay_rate': 0.993, 'attn_dim':5, + 'act': 'relu', 'n_layer': 3 + }, + 'nell_v3': { + 'batch_size':10, 'hidden_dim' : 16, 'lr': 0.0008, 'weight_decay': 0.0004, 'decay_rate': 0.995, 'attn_dim':3, + 'act': 'relu', 'n_layer': 3 + }, + 'WN18RR_v4':{ + 'batch_size':10, 'hidden_dim' : 32, 'lr': 0.006, 'weight_decay': 0.000132, 'decay_rate': 0.991, 'attn_dim':5, + 'act': 'relu', 'n_layer': 5 + }, + 'fb237_v4': { + 'batch_size':20, 
'hidden_dim' : 48, 'lr': 0.0052, 'weight_decay': 0.000018, 'decay_rate': 0.999, 'attn_dim':5, + 'act': 'idd', 'n_layer': 5 + }, + 'nell_v4': { + 'batch_size':20, 'hidden_dim' : 16, 'lr': 0.0005, 'weight_decay': 0.000398, 'decay_rate': 1, 'attn_dim':3, + 'act': 'tanh', 'n_layer': 5 + }, } }, diff --git a/openhgnn/utils/lte_data_set.py b/openhgnn/utils/lte_data_set.py new file mode 100644 index 00000000..cf7d373a --- /dev/null +++ b/openhgnn/utils/lte_data_set.py @@ -0,0 +1,59 @@ +from torch.utils.data import Dataset +import numpy as np +import torch + + +class TrainDataset(Dataset): + def __init__(self, triplets, num_ent, params): + super(TrainDataset, self).__init__() + self.p = params + self.triplets = triplets + self.label_smooth = params.lbl_smooth + self.num_ent = num_ent + + def __len__(self): + return len(self.triplets) + + def __getitem__(self, item): + ele = self.triplets[item] + triple, label = torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label']) + label = self.get_label(label) + if self.label_smooth != 0.0: + label = (1.0 - self.label_smooth) * label + (1.0 / self.num_ent) + return triple, label + + def get_label(self, label): + """ + get label corresponding to a (sub, rel) pair + :param label: a list containing indices of objects corresponding to a (sub, rel) pair + :return: a tensor of shape [nun_ent] + """ + y = np.zeros([self.num_ent], dtype=np.float32) + y[label] = 1 + return torch.tensor(y, dtype=torch.float32) + + +class TestDataset(Dataset): + def __init__(self, triplets, num_ent, params): + super(TestDataset, self).__init__() + self.triplets = triplets + self.num_ent = num_ent + + def __len__(self): + return len(self.triplets) + + def __getitem__(self, item): + ele = self.triplets[item] + triple, label = torch.tensor(ele['triple'], dtype=torch.long), np.int32(ele['label']) + label = self.get_label(label) + return triple, label + + def get_label(self, label): + """ + get label corresponding to a (sub, rel) pair + :param label: a list containing indices of objects corresponding to a (sub, rel) pair + :return: a tensor of shape [nun_ent] + """ + y = np.zeros([self.num_ent], dtype=np.float32) + y[label] = 1 + return torch.tensor(y, dtype=torch.float32) diff --git a/openhgnn/utils/lte_knowledge_graph.py b/openhgnn/utils/lte_knowledge_graph.py new file mode 100644 index 00000000..9466c700 --- /dev/null +++ b/openhgnn/utils/lte_knowledge_graph.py @@ -0,0 +1,569 @@ +""" +based on the implementation in DGL +(https://github.com/dmlc/dgl/blob/master/python/dgl/contrib/data/knowledge_graph.py) +Knowledge graph dataset for Relational-GCN +Code adapted from authors' implementation of Relational-GCN +https://github.com/tkipf/relational-gcn +https://github.com/MichSchli/RelationPrediction +""" + +from __future__ import print_function +from __future__ import absolute_import + +import io + +import numpy as np +import scipy.sparse as sp +import os +import gzip +import rdflib as rdf +import pandas as pd +from collections import Counter +import requests +from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url +import zipfile +np.random.seed(123) + +_downlaod_prefix = _get_dgl_url('dataset/') + + +def load_data(dataset, bfs_level=3, relabel=False): + if dataset in ['wn18rr', 'FB15k-237', 'yago']: + return load_link(dataset) + else: + raise ValueError('Unknown dataset: {}'.format(dataset)) + + +class RGCNEntityDataset(object): + """RGCN Entity Classification dataset + + The dataset contains a graph depicting the connectivity of a knowledge + base. 
Currently, four knowledge bases from the + `RGCN paper `_ are supported: aifb, + mutag, bgs, and am. + + The original knowledge base is stored as an RDF file, and this class will + download and parse the RDF file, and performs preprocessing. + + An object of this class has 11 member attributes needed for entity + classification: + + num_nodes: int + number of entities of knowledge base + num_rels: int + number of relations (including reverse relation) of knowledge base + num_classes: int + number of classes/labels that of entities in knowledge base + edge_src: numpy.array + source node ids of all edges + edge_dst: numpy.array + destination node ids of all edges + edge_type: numpy.array + type of all edges + edge_norm: numpy.array + normalization factor of all edges + labels: numpy.array + labels of node entities + train_idx: numpy.array + ids of entities used for training + valid_idx: numpy.array + ids of entities used for validation + test_idx: numpy.array + ids of entities used for testing + + Usually, users don't need to directly use this class. Instead, DGL provides + wrapper function to load data (see example below). + When loading data, besides specifying dataset name, user can provide two + optional arguments: + + Parameters + ---------- + bfs_level: int + prune out nodes that are more than ``bfs_level`` hops away from + labeled nodes, i.e., nodes won't be touched during propagation. If set + to a number less or equal to 0, all nodes will be retained. + relabel: bool + After pruning, whether or not to relabel all nodes with consecutive + node ids + + Examples + -------- + Load aifb dataset, prune out nodes that are more than 3 hops away from + labeled nodes, and relabel the remaining nodes with consecutive ids + + >>> from dgl.contrib.data import load_data + >>> data = load_data(dataset='aifb', bfs_level=3, relabel=True) + + """ + + def __init__(self, name): + self.name = name + self.dir = get_download_dir() + tgz_path = os.path.join(self.dir, '{}.tgz'.format(self.name)) + download(_downlaod_prefix + '{}.tgz'.format(self.name), tgz_path) + self.dir = os.path.join(self.dir, self.name) + extract_archive(tgz_path, self.dir) + + def load(self, bfs_level=2, relabel=False): + self.num_nodes, edges, self.num_rels, self.labels, labeled_nodes_idx, self.train_idx, self.test_idx = _load_data( + self.name, self.dir) + + # bfs to reduce edges + if bfs_level > 0: + print("removing nodes that are more than {} hops away".format(bfs_level)) + row, col, edge_type = edges.transpose() + A = sp.csr_matrix((np.ones(len(row)), (row, col)), + shape=(self.num_nodes, self.num_nodes)) + bfs_generator = _bfs_relational(A, labeled_nodes_idx) + lvls = list() + lvls.append(set(labeled_nodes_idx)) + for _ in range(bfs_level): + lvls.append(next(bfs_generator)) + to_delete = list(set(range(self.num_nodes)) - set.union(*lvls)) + eid_to_delete = np.isin(row, to_delete) + np.isin(col, to_delete) + eid_to_keep = np.logical_not(eid_to_delete) + self.edge_src = row[eid_to_keep] + self.edge_dst = col[eid_to_keep] + self.edge_type = edge_type[eid_to_keep] + + if relabel: + uniq_nodes, edges = np.unique( + (self.edge_src, self.edge_dst), return_inverse=True) + self.edge_src, self.edge_dst = np.reshape(edges, (2, -1)) + node_map = np.zeros(self.num_nodes, dtype=int) + self.num_nodes = len(uniq_nodes) + node_map[uniq_nodes] = np.arange(self.num_nodes) + self.labels = self.labels[uniq_nodes] + self.train_idx = node_map[self.train_idx] + self.test_idx = node_map[self.test_idx] + print("{} nodes left".format(self.num_nodes)) + 
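+ # bfs_level <= 0: no pruning requested, so keep every edge and simply unpack the
+ # (src, dst, type) columns of the edge array.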
else: + self.edge_src, self.edge_dst, self.edge_type = edges.transpose() + + # normalize by dst degree + _, inverse_index, count = np.unique( + (self.edge_dst, self.edge_type), axis=1, return_inverse=True, return_counts=True) + degrees = count[inverse_index] + self.edge_norm = np.ones( + len(self.edge_dst), dtype=np.float32) / degrees.astype(np.float32) + + # convert to pytorch label format + self.num_classes = self.labels.shape[1] + self.labels = np.argmax(self.labels, axis=1) + + +class RGCNLinkDataset(object): + """RGCN link prediction dataset + + The dataset contains a graph depicting the connectivity of a knowledge + base. Currently, the knowledge bases from the + `RGCN paper `_ supported are + FB15k-237, FB15k, wn18 + + The original knowledge base is stored as an RDF file, and this class will + download and parse the RDF file, and performs preprocessing. + + An object of this class has 5 member attributes needed for link + prediction: + + num_nodes: int + number of entities of knowledge base + num_rels: int + number of relations (including reverse relation) of knowledge base + train: numpy.array + all relation triplets (src, rel, dst) for training + valid: numpy.array + all relation triplets (src, rel, dst) for validation + test: numpy.array + all relation triplets (src, rel, dst) for testing + + Usually, user don't need to directly use this class. Instead, DGL provides + wrapper function to load data (see example below). + + Examples + -------- + Load FB15k-237 dataset + + >>> from dgl.contrib.data import load_data + >>> data = load_data(dataset='FB15k-237') + + """ + + def __init__(self, name): + self.name = name + self.dir = './data' + path_ckp = os.path.join(self.dir, self.name) + self.dir = os.path.join(self.dir, self.name) + print(path_ckp) + folder = os.path.exists(path_ckp) + if not folder: # 判断是否存在文件夹如果不存在则创建为文件夹 + os.makedirs(path_ckp) # makedirs 创建文件时如果路径不存在会创建这个路径 + # 下载数据 + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/FB15k-237.zip" + response = requests.get(url) + with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: + myzip.extractall(path_ckp) + print("--- download data ---") + + else: + print("--- There is data! 
---") + + + # zip_path = os.path.join(self.dir, '{}.zip'.format(self.name)) + # self.dir = os.path.join(self.dir, self.name) + # extract_archive(zip_path, self.dir) + + def load(self): + entity_path = os.path.join(self.dir, 'entities.dict') + relation_path = os.path.join(self.dir, 'relations.dict') + train_path = os.path.join(self.dir, 'train.txt') + valid_path = os.path.join(self.dir, 'valid.txt') + test_path = os.path.join(self.dir, 'test.txt') + entity_dict = _read_dictionary(entity_path) + relation_dict = _read_dictionary(relation_path) + self.train = np.asarray(_read_triplets_as_list( + train_path, entity_dict, relation_dict)) + self.valid = np.asarray(_read_triplets_as_list( + valid_path, entity_dict, relation_dict)) + self.test = np.asarray(_read_triplets_as_list( + test_path, entity_dict, relation_dict)) + self.num_nodes = len(entity_dict) + print("# entities: {}".format(self.num_nodes)) + self.num_rels = len(relation_dict) + print("# relations: {}".format(self.num_rels)) + print("# edges: {}".format(len(self.train))) + + +def load_entity(dataset, bfs_level, relabel): + data = RGCNEntityDataset(dataset) + data.load(bfs_level, relabel) + return data + + +def load_link(dataset): + data = RGCNLinkDataset(dataset) + data.load() + return data + + +def _sp_row_vec_from_idx_list(idx_list, dim): + """Create sparse vector of dimensionality dim from a list of indices.""" + shape = (1, dim) + data = np.ones(len(idx_list)) + row_ind = np.zeros(len(idx_list)) + col_ind = list(idx_list) + return sp.csr_matrix((data, (row_ind, col_ind)), shape=shape) + + +def _get_neighbors(adj, nodes): + """Takes a set of nodes and a graph adjacency matrix and returns a set of neighbors.""" + sp_nodes = _sp_row_vec_from_idx_list(list(nodes), adj.shape[1]) + sp_neighbors = sp_nodes.dot(adj) + neighbors = set(sp.find(sp_neighbors)[1]) # convert to set of indices + return neighbors + + +def _bfs_relational(adj, roots): + """ + BFS for graphs with multiple edge types. Returns list of level sets. + Each entry in list corresponds to relation specified by adj_list. 
+ """ + visited = set() + current_lvl = set(roots) + + next_lvl = set() + + while current_lvl: + + for v in current_lvl: + visited.add(v) + + next_lvl = _get_neighbors(adj, current_lvl) + next_lvl -= visited # set difference + + yield next_lvl + + current_lvl = set.union(next_lvl) + + +class RDFReader(object): + __graph = None + __freq = {} + + def __init__(self, file): + + self.__graph = rdf.Graph() + + if file.endswith('nt.gz'): + with gzip.open(file, 'rb') as f: + self.__graph.parse(file=f, format='nt') + else: + self.__graph.parse(file, format=rdf.util.guess_format(file)) + + # See http://rdflib.readthedocs.io for the rdflib documentation + + self.__freq = Counter(self.__graph.predicates()) + + print("Graph loaded, frequencies counted.") + + def triples(self, relation=None): + for s, p, o in self.__graph.triples((None, relation, None)): + yield s, p, o + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.__graph.destroy("store") + self.__graph.close(True) + + def subjectSet(self): + return set(self.__graph.subjects()) + + def objectSet(self): + return set(self.__graph.objects()) + + def relationList(self): + """ + Returns a list of relations, ordered descending by frequency + :return: + """ + res = list(set(self.__graph.predicates())) + res.sort(key=lambda rel: - self.freq(rel)) + return res + + def __len__(self): + return len(self.__graph) + + def freq(self, rel): + if rel not in self.__freq: + return 0 + return self.__freq[rel] + + +def _load_sparse_csr(filename): + loader = np.load(filename) + return sp.csr_matrix((loader['data'], loader['indices'], loader['indptr']), + shape=loader['shape'], dtype=np.float32) + + +def _save_sparse_csr(filename, array): + np.savez(filename, data=array.data, indices=array.indices, + indptr=array.indptr, shape=array.shape) + + +def _load_data(dataset_str='aifb', dataset_path=None): + """ + + :param dataset_str: + :param rel_layers: + :param limit: If > 0, will only load this many adj. matrices + All adjacencies are preloaded and saved to disk, + but only a limited a then restored to memory. 
+ :return: + """ + + print('Loading dataset', dataset_str) + + graph_file = os.path.join( + dataset_path, '{}_stripped.nt.gz'.format(dataset_str)) + task_file = os.path.join(dataset_path, 'completeDataset.tsv') + train_file = os.path.join(dataset_path, 'trainingSet.tsv') + test_file = os.path.join(dataset_path, 'testSet.tsv') + if dataset_str == 'am': + label_header = 'label_cateogory' + nodes_header = 'proxy' + elif dataset_str == 'aifb': + label_header = 'label_affiliation' + nodes_header = 'person' + elif dataset_str == 'mutag': + label_header = 'label_mutagenic' + nodes_header = 'bond' + elif dataset_str == 'bgs': + label_header = 'label_lithogenesis' + nodes_header = 'rock' + else: + raise NameError('Dataset name not recognized: ' + dataset_str) + + edge_file = os.path.join(dataset_path, 'edges.npz') + labels_file = os.path.join(dataset_path, 'labels.npz') + train_idx_file = os.path.join(dataset_path, 'train_idx.npy') + test_idx_file = os.path.join(dataset_path, 'test_idx.npy') + # train_names_file = os.path.join(dataset_path, 'train_names.npy') + # test_names_file = os.path.join(dataset_path, 'test_names.npy') + # rel_dict_file = os.path.join(dataset_path, 'rel_dict.pkl') + # nodes_file = os.path.join(dataset_path, 'nodes.pkl') + + if os.path.isfile(edge_file) and os.path.isfile(labels_file) and \ + os.path.isfile(train_idx_file) and os.path.isfile(test_idx_file): + + # load precomputed adjacency matrix and labels + all_edges = np.load(edge_file) + num_node = all_edges['n'].item() + edge_list = all_edges['edges'] + num_rel = all_edges['nrel'].item() + + print('Number of nodes: ', num_node) + print('Number of edges: ', len(edge_list)) + print('Number of relations: ', num_rel) + + labels = _load_sparse_csr(labels_file) + labeled_nodes_idx = list(labels.nonzero()[0]) + + print('Number of classes: ', labels.shape[1]) + + train_idx = np.load(train_idx_file) + test_idx = np.load(test_idx_file) + + # train_names = np.load(train_names_file) + # test_names = np.load(test_names_file) + # relations_dict = pkl.load(open(rel_dict_file, 'rb')) + + else: + + # loading labels of nodes + labels_df = pd.read_csv(task_file, sep='\t', encoding='utf-8') + labels_train_df = pd.read_csv(train_file, sep='\t', encoding='utf8') + labels_test_df = pd.read_csv(test_file, sep='\t', encoding='utf8') + + with RDFReader(graph_file) as reader: + + relations = reader.relationList() + subjects = reader.subjectSet() + objects = reader.objectSet() + + nodes = list(subjects.union(objects)) + num_node = len(nodes) + num_rel = len(relations) + num_rel = 2 * num_rel + 1 # +1 is for self-relation + + assert num_node < np.iinfo(np.int32).max + print('Number of nodes: ', num_node) + print('Number of relations: ', num_rel) + + relations_dict = {rel: i for i, rel in enumerate(list(relations))} + nodes_dict = {node: i for i, node in enumerate(nodes)} + + edge_list = [] + # self relation + for i in range(num_node): + edge_list.append((i, i, 0)) + + for i, (s, p, o) in enumerate(reader.triples()): + src = nodes_dict[s] + dst = nodes_dict[o] + assert src < num_node and dst < num_node + rel = relations_dict[p] + # relation id 0 is self-relation, so others should start with 1 + edge_list.append((src, dst, 2 * rel + 1)) + # reverse relation + edge_list.append((dst, src, 2 * rel + 2)) + + # sort indices by destination + edge_list = sorted(edge_list, key=lambda x: (x[1], x[0], x[2])) + edge_list = np.asarray(edge_list, dtype=np.int) + print('Number of edges: ', len(edge_list)) + + np.savez(edge_file, edges=edge_list, n=np.asarray( + 
num_node), nrel=np.asarray(num_rel)) + + nodes_u_dict = {np.unicode(to_unicode(key)): val for key, val in + nodes_dict.items()} + + labels_set = set(labels_df[label_header].values.tolist()) + labels_dict = {lab: i for i, lab in enumerate(list(labels_set))} + + print('{} classes: {}'.format(len(labels_set), labels_set)) + + labels = sp.lil_matrix((num_node, len(labels_set))) + labeled_nodes_idx = [] + + print('Loading training set') + + train_idx = [] + train_names = [] + for nod, lab in zip(labels_train_df[nodes_header].values, + labels_train_df[label_header].values): + nod = np.unicode(to_unicode(nod)) # type: unicode + if nod in nodes_u_dict: + labeled_nodes_idx.append(nodes_u_dict[nod]) + label_idx = labels_dict[lab] + labels[labeled_nodes_idx[-1], label_idx] = 1 + train_idx.append(nodes_u_dict[nod]) + train_names.append(nod) + else: + print(u'Node not in dictionary, skipped: ', + nod.encode('utf-8', errors='replace')) + + print('Loading test set') + + test_idx = [] + test_names = [] + for nod, lab in zip(labels_test_df[nodes_header].values, + labels_test_df[label_header].values): + nod = np.unicode(to_unicode(nod)) + if nod in nodes_u_dict: + labeled_nodes_idx.append(nodes_u_dict[nod]) + label_idx = labels_dict[lab] + labels[labeled_nodes_idx[-1], label_idx] = 1 + test_idx.append(nodes_u_dict[nod]) + test_names.append(nod) + else: + print(u'Node not in dictionary, skipped: ', + nod.encode('utf-8', errors='replace')) + + labeled_nodes_idx = sorted(labeled_nodes_idx) + labels = labels.tocsr() + print('Number of classes: ', labels.shape[1]) + + _save_sparse_csr(labels_file, labels) + + np.save(train_idx_file, train_idx) + np.save(test_idx_file, test_idx) + + # np.save(train_names_file, train_names) + # np.save(test_names_file, test_names) + + # pkl.dump(relations_dict, open(rel_dict_file, 'wb')) + + # end if + + return num_node, edge_list, num_rel, labels, labeled_nodes_idx, train_idx, test_idx + + +def to_unicode(input): + # FIXME (lingfan): not sure about python 2 and 3 str compatibility + return str(input) + """ lingfan: comment out for now + if isinstance(input, unicode): + return input + elif isinstance(input, str): + return input.decode('utf-8', errors='replace') + return str(input).decode('utf-8', errors='replace') + """ + + +def _read_dictionary(filename): + d = {} + with open(filename, 'r+') as f: + for line in f: + line = line.strip().split('\t') + d[line[1]] = int(line[0]) + return d + + +def _read_triplets(filename): + with open(filename, 'r+') as f: + for line in f: + processed_line = line.strip().split('\t') + yield processed_line + + +def _read_triplets_as_list(filename, entity_dict, relation_dict): + l = [] + for triplet in _read_triplets(filename): + s = entity_dict[triplet[0]] + r = relation_dict[triplet[1]] + o = entity_dict[triplet[2]] + l.append([s, r, o]) + return l diff --git a/openhgnn/utils/lte_process_data.py b/openhgnn/utils/lte_process_data.py new file mode 100644 index 00000000..24700fd9 --- /dev/null +++ b/openhgnn/utils/lte_process_data.py @@ -0,0 +1,31 @@ +from collections import defaultdict as ddict + + +def process(dataset, num_rel): + """ + pre-process dataset + :param dataset: a dictionary containing 'train', 'valid' and 'test' data. 
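+ Each split is an iterable of (subj, rel, obj) integer id triples.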
+ :param num_rel: relation number + :return: + """ + sr2o = ddict(set) + for subj, rel, obj in dataset['train']: + sr2o[(subj, rel)].add(obj) + sr2o[(obj, rel + num_rel)].add(subj) + sr2o_train = {k: list(v) for k, v in sr2o.items()} + for split in ['valid', 'test']: + for subj, rel, obj in dataset[split]: + sr2o[(subj, rel)].add(obj) + sr2o[(obj, rel + num_rel)].add(subj) + sr2o_all = {k: list(v) for k, v in sr2o.items()} + triplets = ddict(list) + + for (subj, rel), obj in sr2o_train.items(): + triplets['train'].append({'triple': (subj, rel, -1), 'label': sr2o_train[(subj, rel)]}) + for split in ['valid', 'test']: + for subj, rel, obj in dataset[split]: + triplets[f"{split}_tail"].append({'triple': (subj, rel, obj), 'label': sr2o_all[(subj, rel)]}) + triplets[f"{split}_head"].append( + {'triple': (obj, rel + num_rel, subj), 'label': sr2o_all[(obj, rel + num_rel)]}) + triplets = dict(triplets) + return triplets diff --git a/openhgnn/utils/utils.py b/openhgnn/utils/utils.py index aa9e66cb..240b526b 100644 --- a/openhgnn/utils/utils.py +++ b/openhgnn/utils/utils.py @@ -8,7 +8,7 @@ import random from . import load_HIN, load_KG, load_OGB from .best_config import BEST_CONFIGS - +from typing import Optional, Tuple def sum_up_params(model): """ Count the model parameters """ @@ -491,3 +491,97 @@ def get_ntypes_from_canonical_etypes(canonical_etypes=None): ntypes.add(src) ntypes.add(dst) return ntypes + +def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand(other.size()) + return src + +def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + index = broadcast(index, src, dim) + if out is None: + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) + else: + return out.scatter_add_(dim, index, src) + + +def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return scatter_sum(src, index, dim, out, dim_size) + + +def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size) + + +def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> torch.Tensor: + out = scatter_sum(src, index, dim, out, dim_size) + dim_size = out.size(dim) + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.dim() + if index.dim() <= index_dim: + index_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter_sum(ones, index, index_dim, None, dim_size) + count[count < 1] = 1 + count = broadcast(count, out, dim) + if out.is_floating_point(): + out.true_divide_(count) + else: + out.div_(count, rounding_mode='floor') + return out + + +def scatter_min( + src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) 
-> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size) + + +def scatter_max( + src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size) + + +def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, + out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, + reduce: str = "sum") -> torch.Tensor: + if reduce == 'sum' or reduce == 'add': + return scatter_sum(src, index, dim, out, dim_size) + if reduce == 'mul': + return scatter_mul(src, index, dim, out, dim_size) + elif reduce == 'mean': + return scatter_mean(src, index, dim, out, dim_size) + elif reduce == 'min': + return scatter_min(src, index, dim, out, dim_size)[0] + elif reduce == 'max': + return scatter_max(src, index, dim, out, dim_size)[0] + else: + raise ValueError \ No newline at end of file diff --git a/openhgnn/utils/wgcn_batch_prepare.py b/openhgnn/utils/wgcn_batch_prepare.py new file mode 100644 index 00000000..10b6d6a8 --- /dev/null +++ b/openhgnn/utils/wgcn_batch_prepare.py @@ -0,0 +1,63 @@ +import torch +import numpy as np + +class EvalBatchPrepare(object): + def __init__(self, eval_dict, num_rels): + # eval_dict uses all the data in train, valid, and test + self.eval_dict = eval_dict + self.num_rels = num_rels + + def get_batch(self, batch_trip): + batch_trip = np.asarray(batch_trip) + e1_batch = batch_trip[:, 0] + rel_batch = batch_trip[:, 1] + e2_batch = batch_trip[:, 2] + # reversed relation id is `rel + num_rels` + rel_reverse_batch = rel_batch + self.num_rels + + head_to_multi_tail_list = [] + tail_to_multi_head_list = [] + keys1 = list(zip(e1_batch, rel_batch)) + keys2 = list(zip(e2_batch, rel_reverse_batch)) + # get (h,r)'s tails + for key in keys1: + cur_tail_id_list = list(self.eval_dict.get(key)) + head_to_multi_tail_list.append(np.asarray(cur_tail_id_list)) + # get (t,r_reverse)'s heads + for key in keys2: + cur_tail_id_list = list(self.eval_dict.get(key)) + tail_to_multi_head_list.append(np.asarray(cur_tail_id_list)) + + e1_batch = torch.from_numpy(e1_batch).reshape(-1, 1) + e2_batch = torch.from_numpy(e2_batch).reshape(-1, 1) + rel_batch = torch.from_numpy(rel_batch).reshape(-1, 1) + rel_reverse_batch = torch.from_numpy(rel_reverse_batch).reshape(-1, 1) + + return e1_batch, e2_batch, rel_batch, rel_reverse_batch, head_to_multi_tail_list, tail_to_multi_head_list + + +class TrainBatchPrepare(object): + def __init__(self, train_dict, num_nodes): + self.entity_num = num_nodes + self.train_dict = train_dict + + def get_batch(self, batch_trip): + # batch_trip shape is (batch_size, 3) + batch_trip = np.asarray(batch_trip) + e1_batch = batch_trip[:, 0] + rel_batch = batch_trip[:, 1] + keys = list(zip(e1_batch, rel_batch)) + + # get (h,r) corresponding tails, convert them to one-hot label + labels_one_hot = np.zeros((batch_trip.shape[0], self.entity_num), dtype=np.float32) + cur_row = 0 + for key in keys: + indices = list(self.train_dict.get(key)) + labels_one_hot[cur_row][indices] = 1 + cur_row += 1 + + e1_batch = torch.from_numpy(e1_batch).reshape(-1, 1) + rel_batch = torch.from_numpy(rel_batch).reshape(-1, 1) + labels_one_hot = torch.from_numpy(labels_one_hot) + + return e1_batch, rel_batch, labels_one_hot diff --git a/openhgnn/utils/wgcn_data.py b/openhgnn/utils/wgcn_data.py new file mode 100644 index 00000000..18fca91f --- /dev/null +++ 
b/openhgnn/utils/wgcn_data.py @@ -0,0 +1,564 @@ +""" +based on the implementation in DGL +(https://github.com/dmlc/dgl/blob/master/python/dgl/contrib/data/knowledge_graph.py) +""" + +""" Knowledge graph dataset for Relational-GCN +Code adapted from authors' implementation of Relational-GCN +https://github.com/tkipf/relational-gcn +https://github.com/MichSchli/RelationPrediction +""" + +import numpy as np +import scipy.sparse as sp +import os +import gzip +import rdflib as rdf +import pandas as pd +from collections import Counter +import requests +import zipfile +import io + +from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url + +np.random.seed(123) + +_downlaod_prefix = _get_dgl_url('dataset/') + + +def load_data(dataset, bfs_level=3, relabel=False): + if dataset in ['wn18rr', 'FB15k-237', 'yago']: + return load_link(dataset) + else: + raise ValueError('Unknown dataset: {}'.format(dataset)) + + +class RGCNEntityDataset(object): + """RGCN Entity Classification dataset + + The dataset contains a graph depicting the connectivity of a knowledge + base. Currently, four knowledge bases from the + `RGCN paper `_ are supported: aifb, + mutag, bgs, and am. + + The original knowledge base is stored as an RDF file, and this class will + download and parse the RDF file, and performs preprocessing. + + An object of this class has 11 member attributes needed for entity + classification: + + num_nodes: int + number of entities of knowledge base + num_rels: int + number of relations (including reverse relation) of knowledge base + num_classes: int + number of classes/labels that of entities in knowledge base + edge_src: numpy.array + source node ids of all edges + edge_dst: numpy.array + destination node ids of all edges + edge_type: numpy.array + type of all edges + edge_norm: numpy.array + normalization factor of all edges + labels: numpy.array + labels of node entities + train_idx: numpy.array + ids of entities used for training + valid_idx: numpy.array + ids of entities used for validation + test_idx: numpy.array + ids of entities used for testing + + Usually, users don't need to directly use this class. Instead, DGL provides + wrapper function to load data (see example below). + When loading data, besides specifying dataset name, user can provide two + optional arguments: + + Parameters + ---------- + bfs_level: int + prune out nodes that are more than ``bfs_level`` hops away from + labeled nodes, i.e., nodes won't be touched during propagation. If set + to a number less or equal to 0, all nodes will be retained. 
+ relabel: bool + After pruning, whether or not to relabel all nodes with consecutive + node ids + + Examples + -------- + Load aifb dataset, prune out nodes that are more than 3 hops away from + labeled nodes, and relabel the remaining nodes with consecutive ids + + >>> from dgl.contrib.data import load_data + >>> data = load_data(dataset='aifb', bfs_level=3, relabel=True) + + """ + + def __init__(self, name): + self.name = name + self.dir = get_download_dir() + tgz_path = os.path.join(self.dir, '{}.tgz'.format(self.name)) + download(_downlaod_prefix + '{}.tgz'.format(self.name), tgz_path) + self.dir = os.path.join(self.dir, self.name) + extract_archive(tgz_path, self.dir) + + def load(self, bfs_level=2, relabel=False): + self.num_nodes, edges, self.num_rels, self.labels, labeled_nodes_idx, self.train_idx, self.test_idx = _load_data( + self.name, self.dir) + + # bfs to reduce edges + if bfs_level > 0: + print("removing nodes that are more than {} hops away".format(bfs_level)) + row, col, edge_type = edges.transpose() + A = sp.csr_matrix((np.ones(len(row)), (row, col)), + shape=(self.num_nodes, self.num_nodes)) + bfs_generator = _bfs_relational(A, labeled_nodes_idx) + lvls = list() + lvls.append(set(labeled_nodes_idx)) + for _ in range(bfs_level): + lvls.append(next(bfs_generator)) + to_delete = list(set(range(self.num_nodes)) - set.union(*lvls)) + eid_to_delete = np.isin(row, to_delete) + np.isin(col, to_delete) + eid_to_keep = np.logical_not(eid_to_delete) + self.edge_src = row[eid_to_keep] + self.edge_dst = col[eid_to_keep] + self.edge_type = edge_type[eid_to_keep] + + if relabel: + uniq_nodes, edges = np.unique( + (self.edge_src, self.edge_dst), return_inverse=True) + self.edge_src, self.edge_dst = np.reshape(edges, (2, -1)) + node_map = np.zeros(self.num_nodes, dtype=int) + self.num_nodes = len(uniq_nodes) + node_map[uniq_nodes] = np.arange(self.num_nodes) + self.labels = self.labels[uniq_nodes] + self.train_idx = node_map[self.train_idx] + self.test_idx = node_map[self.test_idx] + print("{} nodes left".format(self.num_nodes)) + else: + self.edge_src, self.edge_dst, self.edge_type = edges.transpose() + + # normalize by dst degree + _, inverse_index, count = np.unique( + (self.edge_dst, self.edge_type), axis=1, return_inverse=True, return_counts=True) + degrees = count[inverse_index] + self.edge_norm = np.ones( + len(self.edge_dst), dtype=np.float32) / degrees.astype(np.float32) + + # convert to pytorch label format + self.num_classes = self.labels.shape[1] + self.labels = np.argmax(self.labels, axis=1) + + +class RGCNLinkDataset(object): + """RGCN link prediction dataset + + The dataset contains a graph depicting the connectivity of a knowledge + base. Currently, the knowledge bases from the + `RGCN paper `_ supported are + FB15k-237, FB15k, wn18 + + The original knowledge base is stored as an RDF file, and this class will + download and parse the RDF file, and performs preprocessing. + + An object of this class has 5 member attributes needed for link + prediction: + + num_nodes: int + number of entities of knowledge base + num_rels: int + number of relations (including reverse relation) of knowledge base + train: numpy.array + all relation triplets (src, rel, dst) for training + valid: numpy.array + all relation triplets (src, rel, dst) for validation + test: numpy.array + all relation triplets (src, rel, dst) for testing + + Usually, user don't need to directly use this class. Instead, DGL provides + wrapper function to load data (see example below). 
+ + Examples + -------- + Load FB15k-237 dataset + + >>> from dgl.contrib.data import load_data + >>> data = load_data(dataset='FB15k-237') + + """ + + def __init__(self, name): + self.name = name + self.dir = './data' + path_ckp = os.path.join(self.dir, self.name) + self.dir = os.path.join(self.dir, self.name) + print(path_ckp) + folder = os.path.exists(path_ckp) + if not folder: # 判断是否存在文件夹如果不存在则创建为文件夹 + os.makedirs(path_ckp) # makedirs 创建文件时如果路径不存在会创建这个路径 + # 下载数据 + url = "https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/openhgnn/FB15k-237.zip" + response = requests.get(url) + with zipfile.ZipFile(io.BytesIO(response.content)) as myzip: + myzip.extractall(path_ckp) + print("--- download data ---") + + else: + print("--- There is data! ---") + + def load(self): + entity_path = os.path.join(self.dir, 'entities.dict') + relation_path = os.path.join(self.dir, 'relations.dict') + train_path = os.path.join(self.dir, 'train.txt') + valid_path = os.path.join(self.dir, 'valid.txt') + test_path = os.path.join(self.dir, 'test.txt') + entity_dict = _read_dictionary(entity_path) + relation_dict = _read_dictionary(relation_path) + self.train = np.asarray(_read_triplets_as_list( + train_path, entity_dict, relation_dict)) + self.valid = np.asarray(_read_triplets_as_list( + valid_path, entity_dict, relation_dict)) + self.test = np.asarray(_read_triplets_as_list( + test_path, entity_dict, relation_dict)) + self.num_nodes = len(entity_dict) + print("# entities: {}".format(self.num_nodes)) + self.num_rels = len(relation_dict) + print("# relations: {}".format(self.num_rels)) + print("# edges: {}".format(len(self.train))) + + +def load_entity(dataset, bfs_level, relabel): + data = RGCNEntityDataset(dataset) + data.load(bfs_level, relabel) + return data + + +def load_link(dataset): + data = RGCNLinkDataset(dataset) + data.load() + return data + + +def _sp_row_vec_from_idx_list(idx_list, dim): + """Create sparse vector of dimensionality dim from a list of indices.""" + shape = (1, dim) + data = np.ones(len(idx_list)) + row_ind = np.zeros(len(idx_list)) + col_ind = list(idx_list) + return sp.csr_matrix((data, (row_ind, col_ind)), shape=shape) + + +def _get_neighbors(adj, nodes): + """Takes a set of nodes and a graph adjacency matrix and returns a set of neighbors.""" + sp_nodes = _sp_row_vec_from_idx_list(list(nodes), adj.shape[1]) + sp_neighbors = sp_nodes.dot(adj) + neighbors = set(sp.find(sp_neighbors)[1]) # convert to set of indices + return neighbors + + +def _bfs_relational(adj, roots): + """ + BFS for graphs with multiple edge types. Returns list of level sets. + Each entry in list corresponds to relation specified by adj_list. 
+ """ + visited = set() + current_lvl = set(roots) + + next_lvl = set() + + while current_lvl: + + for v in current_lvl: + visited.add(v) + + next_lvl = _get_neighbors(adj, current_lvl) + next_lvl -= visited # set difference + + yield next_lvl + + current_lvl = set.union(next_lvl) + + +class RDFReader(object): + __graph = None + __freq = {} + + def __init__(self, file): + + self.__graph = rdf.Graph() + + if file.endswith('nt.gz'): + with gzip.open(file, 'rb') as f: + self.__graph.parse(file=f, format='nt') + else: + self.__graph.parse(file, format=rdf.util.guess_format(file)) + + # See http://rdflib.readthedocs.io for the rdflib documentation + + self.__freq = Counter(self.__graph.predicates()) + + print("Graph loaded, frequencies counted.") + + def triples(self, relation=None): + for s, p, o in self.__graph.triples((None, relation, None)): + yield s, p, o + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.__graph.destroy("store") + self.__graph.close(True) + + def subjectSet(self): + return set(self.__graph.subjects()) + + def objectSet(self): + return set(self.__graph.objects()) + + def relationList(self): + """ + Returns a list of relations, ordered descending by frequency + :return: + """ + res = list(set(self.__graph.predicates())) + res.sort(key=lambda rel: - self.freq(rel)) + return res + + def __len__(self): + return len(self.__graph) + + def freq(self, rel): + if rel not in self.__freq: + return 0 + return self.__freq[rel] + + +def _load_sparse_csr(filename): + loader = np.load(filename) + return sp.csr_matrix((loader['data'], loader['indices'], loader['indptr']), + shape=loader['shape'], dtype=np.float32) + + +def _save_sparse_csr(filename, array): + np.savez(filename, data=array.data, indices=array.indices, + indptr=array.indptr, shape=array.shape) + + +def _load_data(dataset_str='aifb', dataset_path=None): + """ + + :param dataset_str: + :param rel_layers: + :param limit: If > 0, will only load this many adj. matrices + All adjacencies are preloaded and saved to disk, + but only a limited a then restored to memory. 
+ :return: + """ + + print('Loading dataset', dataset_str) + + graph_file = os.path.join( + dataset_path, '{}_stripped.nt.gz'.format(dataset_str)) + task_file = os.path.join(dataset_path, 'completeDataset.tsv') + train_file = os.path.join(dataset_path, 'trainingSet.tsv') + test_file = os.path.join(dataset_path, 'testSet.tsv') + if dataset_str == 'am': + label_header = 'label_cateogory' + nodes_header = 'proxy' + elif dataset_str == 'aifb': + label_header = 'label_affiliation' + nodes_header = 'person' + elif dataset_str == 'mutag': + label_header = 'label_mutagenic' + nodes_header = 'bond' + elif dataset_str == 'bgs': + label_header = 'label_lithogenesis' + nodes_header = 'rock' + else: + raise NameError('Dataset name not recognized: ' + dataset_str) + + edge_file = os.path.join(dataset_path, 'edges.npz') + labels_file = os.path.join(dataset_path, 'labels.npz') + train_idx_file = os.path.join(dataset_path, 'train_idx.npy') + test_idx_file = os.path.join(dataset_path, 'test_idx.npy') + # train_names_file = os.path.join(dataset_path, 'train_names.npy') + # test_names_file = os.path.join(dataset_path, 'test_names.npy') + # rel_dict_file = os.path.join(dataset_path, 'rel_dict.pkl') + # nodes_file = os.path.join(dataset_path, 'nodes.pkl') + + if os.path.isfile(edge_file) and os.path.isfile(labels_file) and \ + os.path.isfile(train_idx_file) and os.path.isfile(test_idx_file): + + # load precomputed adjacency matrix and labels + all_edges = np.load(edge_file) + num_node = all_edges['n'].item() + edge_list = all_edges['edges'] + num_rel = all_edges['nrel'].item() + + print('Number of nodes: ', num_node) + print('Number of edges: ', len(edge_list)) + print('Number of relations: ', num_rel) + + labels = _load_sparse_csr(labels_file) + labeled_nodes_idx = list(labels.nonzero()[0]) + + print('Number of classes: ', labels.shape[1]) + + train_idx = np.load(train_idx_file) + test_idx = np.load(test_idx_file) + + # train_names = np.load(train_names_file) + # test_names = np.load(test_names_file) + # relations_dict = pkl.load(open(rel_dict_file, 'rb')) + + else: + + # loading labels of nodes + labels_df = pd.read_csv(task_file, sep='\t', encoding='utf-8') + labels_train_df = pd.read_csv(train_file, sep='\t', encoding='utf8') + labels_test_df = pd.read_csv(test_file, sep='\t', encoding='utf8') + + with RDFReader(graph_file) as reader: + + relations = reader.relationList() + subjects = reader.subjectSet() + objects = reader.objectSet() + + nodes = list(subjects.union(objects)) + num_node = len(nodes) + num_rel = len(relations) + num_rel = 2 * num_rel + 1 # +1 is for self-relation + + assert num_node < np.iinfo(np.int32).max + print('Number of nodes: ', num_node) + print('Number of relations: ', num_rel) + + relations_dict = {rel: i for i, rel in enumerate(list(relations))} + nodes_dict = {node: i for i, node in enumerate(nodes)} + + edge_list = [] + # self relation + for i in range(num_node): + edge_list.append((i, i, 0)) + + for i, (s, p, o) in enumerate(reader.triples()): + src = nodes_dict[s] + dst = nodes_dict[o] + assert src < num_node and dst < num_node + rel = relations_dict[p] + # relation id 0 is self-relation, so others should start with 1 + edge_list.append((src, dst, 2 * rel + 1)) + # reverse relation + edge_list.append((dst, src, 2 * rel + 2)) + + # sort indices by destination + edge_list = sorted(edge_list, key=lambda x: (x[1], x[0], x[2])) + edge_list = np.asarray(edge_list, dtype=np.int) + print('Number of edges: ', len(edge_list)) + + np.savez(edge_file, edges=edge_list, n=np.asarray( + 
num_node), nrel=np.asarray(num_rel)) + + nodes_u_dict = {np.unicode(to_unicode(key)): val for key, val in + nodes_dict.items()} + + labels_set = set(labels_df[label_header].values.tolist()) + labels_dict = {lab: i for i, lab in enumerate(list(labels_set))} + + print('{} classes: {}'.format(len(labels_set), labels_set)) + + labels = sp.lil_matrix((num_node, len(labels_set))) + labeled_nodes_idx = [] + + print('Loading training set') + + train_idx = [] + train_names = [] + for nod, lab in zip(labels_train_df[nodes_header].values, + labels_train_df[label_header].values): + nod = np.unicode(to_unicode(nod)) # type: unicode + if nod in nodes_u_dict: + labeled_nodes_idx.append(nodes_u_dict[nod]) + label_idx = labels_dict[lab] + labels[labeled_nodes_idx[-1], label_idx] = 1 + train_idx.append(nodes_u_dict[nod]) + train_names.append(nod) + else: + print(u'Node not in dictionary, skipped: ', + nod.encode('utf-8', errors='replace')) + + print('Loading test set') + + test_idx = [] + test_names = [] + for nod, lab in zip(labels_test_df[nodes_header].values, + labels_test_df[label_header].values): + nod = np.unicode(to_unicode(nod)) + if nod in nodes_u_dict: + labeled_nodes_idx.append(nodes_u_dict[nod]) + label_idx = labels_dict[lab] + labels[labeled_nodes_idx[-1], label_idx] = 1 + test_idx.append(nodes_u_dict[nod]) + test_names.append(nod) + else: + print(u'Node not in dictionary, skipped: ', + nod.encode('utf-8', errors='replace')) + + labeled_nodes_idx = sorted(labeled_nodes_idx) + labels = labels.tocsr() + print('Number of classes: ', labels.shape[1]) + + _save_sparse_csr(labels_file, labels) + + np.save(train_idx_file, train_idx) + np.save(test_idx_file, test_idx) + + # np.save(train_names_file, train_names) + # np.save(test_names_file, test_names) + + # pkl.dump(relations_dict, open(rel_dict_file, 'wb')) + + # end if + + return num_node, edge_list, num_rel, labels, labeled_nodes_idx, train_idx, test_idx + + +def to_unicode(input): + # FIXME (lingfan): not sure about python 2 and 3 str compatibility + return str(input) + """ lingfan: comment out for now + if isinstance(input, unicode): + return input + elif isinstance(input, str): + return input.decode('utf-8', errors='replace') + return str(input).decode('utf-8', errors='replace') + """ + + +def _read_dictionary(filename): + d = {} + with open(filename, 'r+') as f: + for line in f: + line = line.strip().split('\t') + d[line[1]] = int(line[0]) + return d + + +def _read_triplets(filename): + with open(filename, 'r+') as f: + for line in f: + processed_line = line.strip().split('\t') + yield processed_line + + +def _read_triplets_as_list(filename, entity_dict, relation_dict): + l = [] + for triplet in _read_triplets(filename): + s = entity_dict[triplet[0]] + r = relation_dict[triplet[1]] + o = entity_dict[triplet[2]] + l.append([s, r, o]) + return l diff --git a/openhgnn/utils/wgcn_evaluation_dgl.py b/openhgnn/utils/wgcn_evaluation_dgl.py new file mode 100644 index 00000000..6abefdd7 --- /dev/null +++ b/openhgnn/utils/wgcn_evaluation_dgl.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +""" +code from github.com/JD-AI-Research-Silicon-Valley/SACN +""" + +import torch +import numpy as np +import time + + +def ranking_and_hits(g, v, model, dev_rank_batcher, name, entity_id, device, logger): + print('') + print(name) + print('') + hits_left = [] + hits_right = [] + hits = [] + ranks = [] + ranks_left = [] + ranks_right = [] + for i in range(10): + hits_left.append([]) + hits_right.append([]) + hits.append([]) + # with open('output_model2.txt', 'w') as file: 
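+ # Filtered ranking: score all entities for each (head, rel) and (tail, rel_reverse)
+ # query, zero out every other known true answer, restore the target's own score,
+ # then rank the target among the remaining candidates to compute MR/MRR/Hits@k.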
+ for i, batch_tuple in enumerate(dev_rank_batcher): + # print("evaluation batch {}".format(i)) + e1, e2, rel, rel_reverse, e2_multi1, e2_multi2 = batch_tuple + e1 = e1.to(device) + e2 = e2.to(device) + rel = rel.to(device) + rel_reverse = rel_reverse.to(device) + + pred1 = model.forward(g, v, e1, rel, entity_id) + pred2 = model.forward(g, v, e2, rel_reverse, entity_id) + pred1, pred2 = pred1.data, pred2.data + e1, e2 = e1.data, e2.data + # e2_multi1, e2_multi2 = e2_multi1.data, e2_multi2.data + + batch_score_start_time = time.time() + for i in range(len(e2_multi1)): + # these filters contain ALL labels + filter1 = e2_multi1[i] + filter2 = e2_multi2[i] + + # save the prediction that is relevant + target_value1 = pred1[i, e2.cpu().numpy()[i, 0].item()].item() + target_value2 = pred2[i, e1.cpu().numpy()[i, 0].item()].item() + # zero all known cases (this are not interesting) + # this corresponds to the filtered setting + pred1[i][filter1] = 0.0 + pred2[i][filter2] = 0.0 + # write base the saved values + pred1[i][e2[i]] = target_value1 + pred2[i][e1[i]] = target_value2 + + batch_sort_and_rank_start_time = time.time() + # sort and rank + max_values, argsort1 = torch.sort(pred1, 1, descending=True) + max_values, argsort2 = torch.sort(pred2, 1, descending=True) + + argsort1 = argsort1.cpu().numpy() + argsort2 = argsort2.cpu().numpy() + for i in range(len(e2_multi1)): + # find the rank of the target entities + rank1 = np.where(argsort1[i] == e2.cpu().numpy()[i, 0])[0][0] + rank2 = np.where(argsort2[i] == e1.cpu().numpy()[i, 0])[0][0] + # rank+1, since the lowest rank is rank 1 not rank 0 + ranks.append(rank1 + 1) + ranks_left.append(rank1 + 1) + ranks.append(rank2 + 1) + ranks_right.append(rank2 + 1) + + # this could be done more elegantly, but here you go + for hits_level in range(10): + if rank1 <= hits_level: + hits[hits_level].append(1.0) + hits_left[hits_level].append(1.0) + else: + hits[hits_level].append(0.0) + hits_left[hits_level].append(0.0) + + if rank2 <= hits_level: + hits[hits_level].append(1.0) + hits_right[hits_level].append(1.0) + else: + hits[hits_level].append(0.0) + hits_right[hits_level].append(0.0) + logger.info('MRR: {0}, MRR left: {1}, MRR right: {2}'.format( + np.mean(1. / np.array(ranks)), np.mean(1. / np.array(ranks_left)), np.mean(1. / np.array(ranks_right)))) + logger.info('MR: {0}, MR left: {1}, MR right: {2}'.format( + np.mean(ranks), np.mean(ranks_left), np.mean(ranks_right))) + for i in [0, 2, 9]: + logger.info('Hits @{0}: {1}, Hits left @{0}: {2}, Hits right @{0}: {3}'.format( + i + 1, np.mean(hits[i]), np.mean(hits_left[i]), np.mean(hits_right[i]))) + + print('-' * 50) + return np.mean(1. 
/ np.array(ranks))
diff --git a/openhgnn/utils/wgcn_utils.py b/openhgnn/utils/wgcn_utils.py
new file mode 100644
index 00000000..409db288
--- /dev/null
+++ b/openhgnn/utils/wgcn_utils.py
@@ -0,0 +1,34 @@
+from os import path
+import torch
+
+
+class EarlyStopping:
+    def __init__(self, save_path, model_name, patience=10):
+        self.patience = patience
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+        self.save_path = save_path
+        self.model_name = model_name
+
+    def step(self, acc, model):
+        score = acc
+        if self.best_score is None:
+            self.best_score = score
+            self.save_checkpoint(model)
+        elif score < self.best_score:
+            self.counter += 1
+            print(
+                f'EarlyStopping counter: {self.counter} out of {self.patience}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.save_checkpoint(model)
+            self.counter = 0
+        return self.early_stop
+
+    def save_checkpoint(self, model):
+        '''Saves the model whenever the validation score improves.'''
+        torch.save(model.state_dict(), path.join(
+            self.save_path, self.model_name+'.pt'))
diff --git a/requirements.txt b/requirements.txt
index 47b11913..96de453d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,10 @@ ogb>=1.3.1
 optuna
 rdflib
 colorama
+igraph
 torch>=1.9.0
 -f https://data.dgl.ai/wheels/repo.html
 dgl>=0.8.0
-TensorBoard>=2.0.0
\ No newline at end of file
+TensorBoard>=2.0.0
+lmdb
+ordered_set
\ No newline at end of file
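For reference, a minimal usage sketch of the pure-PyTorch scatter helpers added to openhgnn/utils/utils.py above; the tensor values are illustrative only and the import path is assumed from the file location:

    import torch
    from openhgnn.utils.utils import scatter

    src = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])    # three rows of features
    index = torch.tensor([0, 1, 0])                       # rows 0 and 2 share bucket 0
    out_sum = scatter(src, index, dim=0, reduce='sum')    # tensor([[6., 8.], [3., 4.]])
    out_mean = scatter(src, index, dim=0, reduce='mean')  # tensor([[3., 4.], [3., 4.]])

Note that the 'sum' and 'mean' reductions run on plain torch.scatter_add_, while 'mul', 'min' and 'max' still dispatch to torch.ops.torch_scatter and therefore require the optional torch_scatter package.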