Commit 6041aaf

…to main

mromanello committed Jun 3, 2022
2 parents 985be21 + 3c73c87
Showing 2 changed files with 65 additions and 42 deletions.
90 changes: 53 additions & 37 deletions hipe_commons/helpers/tsv.py
@@ -7,7 +7,6 @@
# ======================================================================================================================
# VARIABLES
# ======================================================================================================================

COL_LABELS = [
"TOKEN",
"NE-COARSE-LIT",
@@ -536,7 +535,48 @@ def write_tsv(documents: List[List[TSVLine]], output_path: str) -> None:
f.write(csv_content)
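
A minimal call sketch for `write_tsv`, assuming `docs` already holds parsed documents (the variable name and the output path are placeholders):

# `docs` is assumed to be a List[List[TSVLine]] produced by this module's parsing helpers.
write_tsv(docs, 'rewritten.tsv')  # serialize the documents back to a HIPE-compliant tsv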


def tsv_to_dataframe(path: Optional[str] = None, url: Optional[str] = None) -> pd.DataFrame:
def tsv_to_dict(path: Optional[str] = None, url: Optional[str] = None, keep_comments: bool = False) -> Dict[str, List[str]]:
"""The simplest and most straightforward way to get tsv-data into a python structure. This function is used as the
basis for other converters"""

data = get_tsv_data(path, url).split('\n')
header = data[0].split('\t')

if not keep_comments:
dict_ = {k: [] for k in ['n'] + header}

for i, line in enumerate(data[1:]): # As data[0] is the header
if line and not line.startswith('#'):
line = line.split('\t')
dict_['n'].append(i+1) # as we are starting with data[1:]
for j, k in enumerate(header):
dict_[k].append(line[j])

else:
comments, dict_ = {}, None

for i, line in enumerate(data[1:]): # As data[0] is the header

if line:
parsed_line = parse_tsv_line(line, i+1)

                if isinstance(parsed_line, TSVComment):  # if it is a comment, store its field and value
comments[parsed_line.field] = parsed_line.value

                else:  # otherwise, append annotation and comment values to `dict_`
dict_ = {k: [] for k in ['n'] + header + list(comments.keys())} if not dict_ else dict_
                    for k in dict_.keys():
                        formatted_k = k.lower().replace('-', '_')
                        dict_[k].append(
                            getattr(parsed_line, formatted_k)
                            if formatted_k in parsed_line._fields else comments[k])


return dict_
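
A minimal usage sketch for `tsv_to_dict` (the path is a placeholder; the column keys follow the file's header, i.e. COL_LABELS):

# Load a HIPE-compliant tsv into a plain dict of columns.
cols = tsv_to_dict(path='sample.tsv')
print(cols['TOKEN'][:5])  # first five tokens
print(cols['n'][:5])      # their line numbers in the data block

# With keep_comments=True, commented metadata fields become additional keys.
cols_meta = tsv_to_dict(path='sample.tsv', keep_comments=True)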

def tsv_to_dataframe(path: Optional[str] = None, url: Optional[str] = None, keep_comments: bool = False) -> pd.DataFrame:
"""Converts a HIPE-compliant tsv to a `pd.DataFrame`, keeping comment fields as columns.
Each row corresponds to an annotation row of the tsv (i.e. a token). Commented fields (e.g. `'document_id`) are
@@ -548,40 +588,13 @@ def tsv_to_dataframe(path: Optional[str] = None, url: Optional[str] = None) -> pd.DataFrame:
    :param str path: Path to a HIPE-compliant tsv file
    :param str url: URL to a HIPE-compliant tsv file
    :param bool keep_comments: If True, comment fields (e.g. `document_id`) are kept as columns
    """
data = get_tsv_data(path=path, url=url)

comments = {}
header = None
dict_ = None

    for i, line in enumerate(data.split('\n')):

        if not line:  # Skips empty lines
            continue

        if not header:  # Skips lines until valid header
            header = COL_LABELS if line.split('\t') == COL_LABELS else None

        else:
            parsed_line = parse_tsv_line(line, i)

            if isinstance(parsed_line, TSVComment):  # If comment, stock comment's field and value
                comments[parsed_line.field] = parsed_line.value

            else:  # else, appends annotations and comments values to `dict_`
                dict_ = {k: [] for k in ['n'] + header + list(comments.keys())} if not dict_ else dict_
                for k in dict_.keys():
                    formated_k = k.lower().replace('-', '_')
                    dict_[k].append(
                        getattr(parsed_line, formated_k)
                        if formated_k in parsed_line._fields else comments[k])
    return pd.DataFrame(dict_)

    return pd.DataFrame(tsv_to_dict(path=path, url=url, keep_comments=keep_comments))
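
After the refactor the DataFrame conversion is a thin wrapper over `tsv_to_dict`; a usage sketch (placeholder path):

# One row per annotated token; comment fields become columns when keep_comments=True.
df = tsv_to_dataframe(path='sample.tsv', keep_comments=True)
print(df.shape)
print(df.columns.tolist())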


def tsv_to_lists(labels: List[str],
path: Optional[str] = None,
url: Optional[str] = None,
segmentation_flag: Union[str, int] = 'EndOf') -> Dict[str, List[List[str]]]:
def tsv_to_segmented_lists(labels: List[str],
path: Optional[str] = None,
url: Optional[str] = None,
segmentation_flag: Union[str, int] = 'EndOf') -> Dict[str, List[List[str]]]:
"""Converts a HIPE-compliant tsv to lists of examples containing lists of tokens,
with their aligned labels and doc_ids.
@@ -602,7 +615,7 @@ def tsv_to_lists(labels: List[str],
:returns: Dict, see above
"""

df = tsv_to_dataframe(path=path, url=url)
df = tsv_to_dataframe(path=path, url=url, keep_comments=True)
d = {k: [] for k in ['texts', 'doc_ids'] + labels}

doc_id_col = [col for col in df.columns if 'document_id' in col][0]
@@ -649,7 +662,7 @@ def tsv_to_huggingface_dataset(
    to the datasets pyarrow data structure."""

from datasets import Dataset
data = tsv_to_lists(labels=labels, path=path, url=url, segmentation_flag=segmentation_flag)
data = tsv_to_segmented_lists(labels=labels, path=path, url=url, segmentation_flag=segmentation_flag)
return Dataset.from_dict(data)
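
A usage sketch for the renamed `tsv_to_segmented_lists` and its `datasets` wrapper (placeholder path; 'NE-COARSE-LIT' is one of the COL_LABELS columns; the wrapper needs the optional `datasets` dependency):

# Tokens and labels come back as aligned lists of lists, one inner list per segment.
data = tsv_to_segmented_lists(labels=['NE-COARSE-LIT'], path='sample.tsv')
assert len(data['texts'][0]) == len(data['NE-COARSE-LIT'][0])

# The same data wrapped as a datasets.Dataset:
ds = tsv_to_huggingface_dataset(labels=['NE-COARSE-LIT'], path='sample.tsv')
print(ds[0]['texts'])  # tokens of the first segment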


Expand Down Expand Up @@ -752,7 +765,7 @@ def __getitem__(self, idx):
def __len__(self):
return len(self.labels)

data = tsv_to_lists(labels=[label_type], path=path, url=url, segmentation_flag=segmentation_flag)
data = tsv_to_segmented_lists(labels=[label_type], path=path, url=url, segmentation_flag=segmentation_flag)

tokenized_texts = tokenizer(data['texts'], is_split_into_words=True, **tokenizer_kwargs)

@@ -777,3 +790,6 @@ def get_unique_labels(path: Optional[str] = None, url: Optional[str] = None, lab
labels.append('I-' + label)

return labels
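
A sketch mirroring how the tests build a label-to-id mapping from `get_unique_labels` (the label values are illustrative):

unique_labels = get_unique_labels(label_list=['O', 'B-pers', 'I-pers', 'B-loc'])
labels_to_ids = {l: i for i, l in enumerate(unique_labels)}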



17 changes: 12 additions & 5 deletions tests/test_helpers_tsv.py
@@ -4,8 +4,8 @@
import pip
import pytest

from hipe_commons.helpers.tsv import parse_tsv, HipeDocument, tsv_to_dataframe, tsv_to_lists, tsv_to_torch_dataset, \
get_unique_labels, tsv_to_huggingface_dataset
from hipe_commons.helpers.tsv import parse_tsv, HipeDocument, tsv_to_dataframe, tsv_to_segmented_lists, tsv_to_torch_dataset, \
get_unique_labels, tsv_to_huggingface_dataset, tsv_to_dict


def test_parse_tsv_from_file(sample_tsv_path):
@@ -30,7 +30,7 @@ def test_tsv_to_dataframe(sample_tsv_url, sample_tsv_string):


def test_tsv_to_lists(sample_tsv_url, sample_tsv_string, sample_label):
d = tsv_to_lists([sample_label], url=sample_tsv_url)
d = tsv_to_segmented_lists([sample_label], url=sample_tsv_url)
segmentation_flag_count = sum([1 for l in sample_tsv_string.split('\n') if 'EndOf' in l])
assert len(d['texts']) in [segmentation_flag_count,
segmentation_flag_count + 1] # In case the file doesn't end with flag.
@@ -42,7 +42,7 @@ def test_tsv_to_torch_dataset(sample_tsv_url, sample_label):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
data_lists = tsv_to_lists([sample_label], url=sample_tsv_url)
data_lists = tsv_to_segmented_lists([sample_label], url=sample_tsv_url)
unique_labels = get_unique_labels(label_list=[l for l_list in data_lists[sample_label] for l in l_list])
labels_to_ids = {l: i for i, l in enumerate(unique_labels)}

@@ -54,6 +54,13 @@ def test_tsv_to_torch_dataset(sample_tsv_url, sample_label):
@pytest.mark.skipif(pip.main(['show', 'datasets']) != 0,
reason="""`datasets` not installed, skipping test.""")
def test_tsv_to_huggingface_dataset(sample_tsv_url, sample_label):
data_lists = tsv_to_lists([sample_label], url=sample_tsv_url)
data_lists = tsv_to_segmented_lists([sample_label], url=sample_tsv_url)
dataset = tsv_to_huggingface_dataset([sample_label], url=sample_tsv_url)
assert all([a == b['texts'] for a, b in zip(data_lists['texts'], dataset)])


def test_tsv_to_dict(sample_tsv_url, sample_tsv_string):
dict_ = tsv_to_dict(url=sample_tsv_url)
    # Make sure there are as many annotation rows in the file as values per key in the dict (the +1 accounts for the header line)
file_lines = len([1 for line in sample_tsv_string.split('\n') if (not line.startswith('#')) and line.strip('\n')])
assert all([len(dict_[k])+1 == file_lines for k in dict_.keys()])
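
The new test can be run in isolation with a standard pytest invocation, sketched here programmatically:

import pytest
pytest.main(['tests/test_helpers_tsv.py', '-k', 'tsv_to_dict'])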
