Skip to content

Commit

Permalink
Merge pull request #53 from chakki-works/enhancement/refactorClassifi…
Browse files Browse the repository at this point in the history
…cationReport

Enhancement/refactor classification report
  • Loading branch information
Hironsan committed Oct 8, 2020
2 parents a48a9d1 + 78bda02 commit a0c562a
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 26 deletions.
35 changes: 9 additions & 26 deletions seqeval/metrics/sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import numpy as np

from seqeval.reporters import DictReporter, StringReporter


def get_entities(seq, suffix=False):
"""Gets entities from sequence.
Expand Down Expand Up @@ -343,16 +345,11 @@ def classification_report(y_true, y_pred, digits=2, suffix=False, output_dict=Fa
avg_types = ['micro avg', 'macro avg', 'weighted avg']

if output_dict:
report_dict = dict()
reporter = DictReporter()
else:
avg_width = max([len(x) for x in avg_types])
width = max(name_width, avg_width, digits)
headers = ["precision", "recall", "f1-score", "support"]
head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
report = head_fmt.format(u'', *headers, width=width)
report += u'\n\n'

row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'
reporter = StringReporter(width=width, digits=digits)

ps, rs, f1s, s = [], [], [], []
for type_name in sorted(d1.keys()):
Expand All @@ -366,47 +363,33 @@ def classification_report(y_true, y_pred, digits=2, suffix=False, output_dict=Fa
r = nb_correct / nb_true if nb_true > 0 else 0
f1 = 2 * p * r / (p + r) if p + r > 0 else 0

if output_dict:
report_dict[type_name] = {'precision': p, 'recall': r, 'f1-score': f1, 'support': nb_true}
else:
report += row_fmt.format(*[type_name, p, r, f1, nb_true], width=width, digits=digits)
reporter.write(type_name, p, r, f1, nb_true)

ps.append(p)
rs.append(r)
f1s.append(f1)
s.append(nb_true)

if not output_dict:
report += u'\n'
reporter.write_blank()

# compute averages
nb_true = np.sum(s)

for avg_type in avg_types:
if avg_type == 'micro avg':
# micro average
p = precision_score(y_true, y_pred, suffix=suffix)
r = recall_score(y_true, y_pred, suffix=suffix)
f1 = f1_score(y_true, y_pred, suffix=suffix)
elif avg_type == 'macro avg':
# macro average
p = np.average(ps)
r = np.average(rs)
f1 = np.average(f1s)
elif avg_type == 'weighted avg':
# weighted average
p = np.average(ps, weights=s)
r = np.average(rs, weights=s)
f1 = np.average(f1s, weights=s)
else:
assert False, "unexpected average: {}".format(avg_type)
reporter.write(avg_type, p, r, f1, nb_true)
reporter.write_blank()

if output_dict:
report_dict[avg_type] = {'precision': p, 'recall': r, 'f1-score': f1, 'support': nb_true}
else:
report += row_fmt.format(*[avg_type, p, r, f1, nb_true], width=width, digits=digits)

if output_dict:
return report_dict
else:
return report
return reporter.report()
69 changes: 69 additions & 0 deletions seqeval/reporters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import abc


class Reporter(abc.ABC):
    """Abstract interface for objects that collect per-row metric results.

    Subclasses must implement :meth:`write` (record one row) and
    :meth:`report` (return everything recorded so far).  :meth:`write_blank`
    is part of the interface as well, with a do-nothing default, so that code
    driving a reporter can call it on any implementation.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments; subclasses define their own options."""

    @abc.abstractmethod
    def report(self):
        """Return the accumulated report in the subclass's output format."""
        raise NotImplementedError

    @abc.abstractmethod
    def write(self, row_name: str, precision: float, recall: float, f1: float, support: int):
        """Record one row of metrics under ``row_name``."""
        raise NotImplementedError

    def write_blank(self):
        """Record a separator between groups of rows.

        The default does nothing.  It is deliberately not abstract so that
        existing subclasses implementing only ``report`` and ``write`` keep
        working; formats with a visual layout can override it.
        """


class DictReporter(Reporter):
    """Reporter that stores each written row in a dictionary keyed by row name."""

    # Keys of the per-row mapping, in the order the values are passed to write().
    _FIELDS = ('precision', 'recall', 'f1-score', 'support')

    def __init__(self, *args, **kwargs):
        """Create an empty report; any arguments given are ignored."""
        super().__init__()
        self.report_dict = {}

    def write(self, row_name: str, precision: float, recall: float, f1: float, support: int):
        """Store the metrics for ``row_name``, replacing any earlier entry."""
        values = (precision, recall, f1, support)
        self.report_dict[row_name] = dict(zip(self._FIELDS, values))

    def write_blank(self):
        """Do nothing: a dictionary report has no separator rows."""
        return None

    def report(self):
        """Return the dictionary of rows written so far (the live object, not a copy)."""
        return self.report_dict


class StringReporter(Reporter):
    """Reporter that renders rows as a right-aligned, fixed-width text table."""

    def __init__(self, *args, **kwargs):
        """Set up an empty row buffer and the formatting options.

        Only the keyword arguments ``width`` (row-name column width,
        default 10) and ``digits`` (decimal places for the three float
        columns, default 4) are read; everything else is ignored.
        """
        super().__init__()
        options = {'width': 10, 'digits': 4}
        options.update(kwargs)
        self.width = options['width']
        self.digits = options['digits']
        self.row_fmt = ''.join([
            '{:>{width}s} ',
            ' {:>9.{digits}f}' * 3,
            ' {:>9}',
        ])
        self.buffer = []

    def write_header(self):
        """Build and return the header line followed by one empty line.

        Despite the name, nothing is appended to the buffer; the text is
        returned to the caller.
        """
        columns = ['precision', 'recall', 'f1-score', 'support']
        template = '{:>{width}s} ' + ' {:>9}' * len(columns)
        header_line = template.format('', *columns, width=self.width)
        return header_line + '\n\n'

    def write(self, row_name: str, precision: float, recall: float, f1: float, support: int):
        """Format one row of metrics and append it to the buffer."""
        formatted = self.row_fmt.format(
            row_name, precision, recall, f1, support,
            width=self.width,
            digits=self.digits,
        )
        self.buffer.append(formatted)

    def write_blank(self):
        """Append an empty entry, which becomes an empty line when rows are joined."""
        self.buffer.append('')

    def report(self):
        """Return the header followed by all buffered rows joined with newlines."""
        body = '\n'.join(self.buffer)
        return self.write_header() + body
26 changes: 26 additions & 0 deletions tests/test_reporters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
from seqeval.reporters import DictReporter


@pytest.mark.parametrize(
    'rows, expected',
    [
        # No rows written: the report is an empty dict.
        ([], {}),
        # A single row becomes one entry keyed by the row name.
        (
            [['PERSON', 0.82, 0.79, 0.81, 24]],
            {
                'PERSON': {
                    'precision': 0.82,
                    'recall': 0.79,
                    'f1-score': 0.81,
                    'support': 24,
                },
            },
        ),
    ],
)
def test_dict_reporter_output(rows, expected):
    """DictReporter.report() returns exactly the rows passed to write()."""
    reporter = DictReporter()
    for name, precision, recall, f1, support in rows:
        reporter.write(name, precision, recall, f1, support)
    actual = reporter.report()
    assert actual == expected

0 comments on commit a0c562a

Please sign in to comment.