Skip to content

Commit

Permalink
[skip ci] Code formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
actions-user committed Nov 14, 2023
1 parent b7f15ef commit 03cc369
Show file tree
Hide file tree
Showing 25 changed files with 94 additions and 80 deletions.
3 changes: 1 addition & 2 deletions scripts/WikiExtractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3266,7 +3266,6 @@ def reduce_process(


def main():

parser = argparse.ArgumentParser(
prog=os.path.basename(sys.argv[0]),
formatter_class=argparse.RawDescriptionHelpFormatter,
Expand Down Expand Up @@ -3405,7 +3404,7 @@ def main():

try:
power = "kmg".find(args.bytes[-1].lower()) + 1
file_size = int(args.bytes[:-1]) * 1024 ** power
file_size = int(args.bytes[:-1]) * 1024**power
if file_size < minFileSize:
raise ValueError()
except ValueError:
Expand Down
10 changes: 9 additions & 1 deletion scripts/comparison_BLINK/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,15 @@ def generate_response(self, text, spans):
"context_right": m["context"][1].lower(),
}
data_to_link.append(temp)
_, _, _, _, _, predictions, scores, = main_dense.run(
(
_,
_,
_,
_,
_,
predictions,
scores,
) = main_dense.run(
self.argss, self.logger, *self.model, test_data=data_to_link
)

Expand Down
7 changes: 4 additions & 3 deletions scripts/efficiency_test.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import os

import numpy as np
import requests
import os

from REL.training_datasets import TrainingEvaluationDatasets

np.random.seed(seed=42)

base_url = os.environ.get("REL_BASE_URL")
wiki_version = "wiki_2019"
host = 'localhost'
port = '5555'
host = "localhost"
port = "5555"
datasets = TrainingEvaluationDatasets(base_url, wiki_version).load()["aida_testB"]

# random_docs = np.random.choice(list(datasets.keys()), 50)
Expand Down
1 change: 1 addition & 0 deletions scripts/test_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

import requests

# Script for testing the implementation of the conversational entity linking API
Expand Down
1 change: 0 additions & 1 deletion scripts/truecase/relq.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@


def main(filename):

fi = open(filename, "r")
document = {
"text": "",
Expand Down
15 changes: 7 additions & 8 deletions scripts/update_db_pem.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sqlite3
import argparse
import sqlite3


def convert(a):
Expand All @@ -11,7 +11,9 @@ def convert_db(db):
cur = con.cursor()
cur2 = con.cursor()

cur.execute("CREATE TABLE IF NOT EXISTS wiki2(word TEXT PRIMARY KEY, p_e_m BLOB, lower TEXT, freq INTEGER)")
cur.execute(
"CREATE TABLE IF NOT EXISTS wiki2(word TEXT PRIMARY KEY, p_e_m BLOB, lower TEXT, freq INTEGER)"
)
cur.execute("SELECT word, p_e_m, lower, freq FROM wiki")

cur2.execute("BEGIN TRANSACTION")
Expand All @@ -33,10 +35,7 @@ def convert_db(db):
cur.execute("COMMIT")


if __name__ == "__main__":
    # Command-line entry point: takes one positional argument, the database
    # path, and passes it to convert_db.
    cli = argparse.ArgumentParser()
    cli.add_argument("db", help="Path to database to convert")
    options = cli.parse_args()
    convert_db(options.db)
2 changes: 1 addition & 1 deletion src/REL/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.0'
__version__ = "0.1.0"
36 changes: 21 additions & 15 deletions src/REL/crel/conv_el.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@
import sys
from pathlib import Path

from REL.response_handler import ResponseHandler

from .bert_md import BERT_MD
from .s2e_pe import pe_data
from .s2e_pe.pe import EEMD, PEMD
from REL.response_handler import ResponseHandler


class ConvEL:
def __init__(
self,
base_url=".",
wiki_version="wiki_2019",
ed_model=None,
user_config=None,
threshold=0,
ner_model="bert_conv-td",
):
self,
base_url=".",
wiki_version="wiki_2019",
ed_model=None,
user_config=None,
threshold=0,
ner_model="bert_conv-td",
):
self.threshold = threshold

self.wiki_version = wiki_version
Expand All @@ -29,8 +30,10 @@ def __init__(
if not ed_model:
ed_model = self._default_ed_model()

self.response_handler = ResponseHandler(self.base_url, self.wiki_version, model=ed_model)

self.response_handler = ResponseHandler(
self.base_url, self.wiki_version, model=ed_model
)

self.eemd = EEMD(s2e_pe_model=str(Path(base_url) / "s2e_ast_onto"))
self.pemd = PEMD()

Expand All @@ -45,10 +48,15 @@ def __init__(

def _default_ed_model(self):
    """Build the entity-disambiguation model used when none is supplied.

    The model is created for ``self.base_url`` / ``self.wiki_version`` in
    ``"eval"`` mode, with ``model_path`` pointing at
    ``<base_url>/<wiki_version>/generated/model``.

    Returns:
        A new ``EntityDisambiguation`` instance.
    """
    # Imported at call time rather than at module import.
    from REL.entity_disambiguation import EntityDisambiguation

    model_location = f"{self.base_url}/{self.wiki_version}/generated/model"
    eval_config = {
        "mode": "eval",
        "model_path": model_location,
    }
    return EntityDisambiguation(
        self.base_url, self.wiki_version, user_config=eval_config
    )

def _error_check(self, conv):
assert type(conv) == list
Expand Down Expand Up @@ -163,5 +171,3 @@ def ed(self, text, spans):
"""Change tuple to list to match the output format of REL API."""
response = self.response_handler.generate_response(text=text, spans=spans)
return [list(ent) for ent in response]


1 change: 0 additions & 1 deletion src/REL/crel/s2e_pe/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ def _prune_topk_mentions(self, mention_logits, attention_mask):
def _ce_prune_pem_eem(
self, mention_logits, pem_eem_subtokenspan
): # attention_mask, subtoken_map, pem_eem_subtokenspan):

batch_size, seq_length, _ = mention_logits.size()
assert batch_size == 1 # HJ: currently, only batch_size==1 is supported

Expand Down
1 change: 0 additions & 1 deletion src/REL/crel/s2e_pe/pe_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def _error_check(m_spans, t_span):

# Main
for m_span in m_spans:

# if token span is out of mention span (i.e., does not have any overlaps), then go to next
t_out_m = (t_span[1] <= m_span[0]) or (m_span[1] <= t_span[0])
if t_out_m:
Expand Down
1 change: 0 additions & 1 deletion src/REL/crel/s2e_pe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from time import time

import numpy as np

import torch

from .consts import NULL_ID_FOR_COREF
Expand Down
8 changes: 5 additions & 3 deletions src/REL/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from array import array
from functools import lru_cache
from os import makedirs, path
import numpy as np

import numpy as np
import requests


Expand Down Expand Up @@ -144,13 +144,15 @@ def lookup_list(self, w, table_name, column="emb"):
res = [e if e is None else np.frombuffer(e[0], dtype=np.float32)]
else:
ret = self.lookup_many(column, table_name, w)
mapping = {key: np.frombuffer(value, dtype=np.float32) for key, value in ret}
mapping = {
key: np.frombuffer(value, dtype=np.float32) for key, value in ret
}
res = [mapping.get(word) for word in w]

return res

def lookup_many(self, column, table_name, w):
qmarks = ','.join(('?',)*len(w))
qmarks = ",".join(("?",) * len(w))
return self.cursor.execute(
f"select word,{column} from {table_name} where word in ({qmarks})",
w,
Expand Down
1 change: 0 additions & 1 deletion src/REL/db/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def load_word2emb(self, file_name, batch_size=5000, limit=np.inf, reset=False):

def load_wiki(self, p_e_m_index, mention_total_freq, batch_size=5000, reset=False):
if reset:

self.clear()

batch = []
Expand Down
19 changes: 8 additions & 11 deletions src/REL/entity_disambiguation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from REL.training_datasets import TrainingEvaluationDatasets
from REL.vocabulary import Vocabulary


wiki_prefix = "en.wikipedia.org/wiki/"


Expand All @@ -32,6 +31,7 @@ class EntityDisambiguation:
Parent Entity Disambiguation class that directs the various subclasses used
for the ED step.
"""

def __init__(self, base_url, wiki_version, user_config, reset_embeddings=False):
self.base_url = base_url
self.wiki_version = wiki_version
Expand Down Expand Up @@ -135,25 +135,23 @@ def __get_config(self, user_config):
# extract the files in the archive to that directory
# assign config[model_path] accordingly
with tarfile.open(model_path) as f:

def is_within_directory(directory, target):
    """Return True when *target* is *directory* itself or a path beneath it.

    Both paths are made absolute and normalised with ``os.path.abspath``
    before comparison, so ``..`` segments in *target* are collapsed first.
    The check is purely lexical: symbolic links are not resolved.

    Args:
        directory: Directory that must contain *target*.
        target: Path to test.

    Returns:
        bool: True if *target* lies inside *directory*, False otherwise.
    """
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)

    # os.path.commonprefix works character by character, so a sibling such as
    # "/x/cache_evil" would wrongly be accepted as inside "/x/cache".
    # os.path.commonpath compares whole path components instead.
    try:
        common = os.path.commonpath([abs_directory, abs_target])
    except ValueError:
        # No common ancestor can be computed (e.g. different drives on
        # Windows), so the target cannot be inside the directory.
        return False

    return common == abs_directory


def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
    """Extract *tar* into *path* after rejecting members that escape it.

    Every member name in the archive is joined onto *path* and checked with
    ``is_within_directory`` before anything is written. Only member names are
    inspected; link targets inside the archive are not examined.

    Args:
        tar: Open ``tarfile.TarFile`` to extract.
        path: Destination directory.
        members: Optional subset of members, passed through to
            ``tar.extractall``.
        numeric_owner: Passed through to ``tar.extractall``.

    Raises:
        Exception: If any member would be written outside *path*.
    """
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        if not is_within_directory(path, member_path):
            raise Exception("Attempted Path Traversal in Tar File")

    tar.extractall(path, members, numeric_owner=numeric_owner)

safe_extract(f, Path("~/.rel_cache").expanduser())
# NOTE: use double stem to deal with e.g. *.tar.gz
# this also handles *.tar correctly
Expand Down Expand Up @@ -528,7 +526,6 @@ def __predict(self, data, include_timing=False, eval_raw=False):
timing = []

for batch in data: # each document is a minibatch

start = time.time()

token_ids = [
Expand Down
2 changes: 1 addition & 1 deletion src/REL/generate_train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from REL.utils import modify_uppercase_phrase, split_in_words_mention



class GenTrainingTest(MentionDetectionBase):
"""
Class responsible for formatting WNED and AIDA datasets that are required for ED local evaluation and training.
Inherits overlapping functions from the Mention Detection class.
"""

def __init__(self, base_url, wiki_version, wikipedia):
self.wned_path = os.path.join(base_url, "generic/test_datasets/wned-datasets/")
self.aida_path = os.path.join(base_url, "generic/test_datasets/AIDA/")
Expand Down
3 changes: 1 addition & 2 deletions src/REL/mention_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from REL.mention_detection_base import MentionDetectionBase



class MentionDetection(MentionDetectionBase):
"""
Class responsible for mention detection.
"""

def __init__(self, base_url, wiki_version):
self.cnt_exact = 0
self.cnt_partial = 0
Expand Down Expand Up @@ -134,7 +134,6 @@ def find_mentions(self, dataset, tagger=None):
for (idx_sent, (sentence, ground_truth_sentence)), snt in zip(
contents.items(), sentences
):

# Only include offset if using Flair.
if is_flair:
offset = raw_text.find(sentence, cum_sent_length)
Expand Down
2 changes: 1 addition & 1 deletion src/REL/mulrel_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from torch.autograd import Variable



class PreRank(torch.nn.Module):
"""
PreRank class is used for preranking entities for a given mention by multiplying entity vectors with
word vectors
"""

def __init__(self, config, embeddings=None):
super(PreRank, self).__init__()
self.config = config
Expand Down
21 changes: 12 additions & 9 deletions src/REL/response_handler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from REL.entity_disambiguation import EntityDisambiguation
from REL.ner import load_flair_ner
from flair.models import SequenceTagger

from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection
from REL.ner import load_flair_ner
from REL.utils import process_results

MD_MODELS = {}


def _get_mention_detection_model(base_url, wiki_version):
"""Return instance of previously generated model for the same wiki version."""
try:
Expand All @@ -30,11 +32,12 @@ def __init__(self, base_url, wiki_version, model, tagger_ner=None):
self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
self.mention_detection = _get_mention_detection_model(base_url, wiki_version)

def generate_response(self,
*,
text: list,
spans: list,
):
def generate_response(
self,
*,
text: list,
spans: list,
):
"""
Generates response for API. Can be either ED only or EL, meaning end-to-end.
Expand All @@ -43,7 +46,7 @@ def generate_response(self,

if len(text) == 0:
return []

processed = {self.API_DOC: [text, spans]}

if len(spans) > 0:
Expand Down Expand Up @@ -74,4 +77,4 @@ def generate_response(self,
if len(result) > 0:
return [*result.values()][0]

return []
return []
Loading

0 comments on commit 03cc369

Please sign in to comment.