From b2ee970fe7bc7c13eedf1c85a90e020b621e37cb Mon Sep 17 00:00:00 2001 From: Hironsan Date: Thu, 8 Oct 2020 10:09:23 +0900 Subject: [PATCH 1/2] Add two reporters --- seqeval/reporters.py | 69 +++++++++++++++++++++++++++++++++++++++++ tests/test_reporters.py | 26 ++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 seqeval/reporters.py create mode 100644 tests/test_reporters.py diff --git a/seqeval/reporters.py b/seqeval/reporters.py new file mode 100644 index 0000000..c5cabb0 --- /dev/null +++ b/seqeval/reporters.py @@ -0,0 +1,69 @@ +import abc + + +class Reporter(abc.ABC): + + def __init__(self, *args, **kwargs): + pass + + @abc.abstractmethod + def report(self): + raise NotImplementedError + + @abc.abstractmethod + def write(self, row_name: str, precision: float, recall: float, f1: float, support: int): + raise NotImplementedError + + +class DictReporter(Reporter): + + def __init__(self, *args, **kwargs): + super().__init__() + self.report_dict = {} + + def report(self): + return self.report_dict + + def write(self, row_name: str, precision: float, recall: float, f1: float, support: int): + self.report_dict[row_name] = { + 'precision': precision, + 'recall': recall, + 'f1-score': f1, + 'support': support + } + + def write_blank(self): + pass + + +class StringReporter(Reporter): + + def __init__(self, *args, **kwargs): + super().__init__() + self.buffer = [] + self.row_fmt = '{:>{width}s} ' + ' {:>9.{digits}f}' * 3 + ' {:>9}' + self.width = kwargs.get('width', 10) + self.digits = kwargs.get('digits', 4) + + def report(self): + report = self.write_header() + report += '\n'.join(self.buffer) + return report + + def write(self, row_name: str, precision: float, recall: float, f1: float, support: int): + row = self.row_fmt.format( + *[row_name, precision, recall, f1, support], + width=self.width, + digits=self.digits + ) + self.buffer.append(row) + + def write_header(self): + headers = ['precision', 'recall', 'f1-score', 'support'] + head_fmt = '{:>{width}s} ' + ' {:>9}' * len(headers) + report = head_fmt.format('', *headers, width=self.width) + report += '\n\n' + return report + + def write_blank(self): + self.buffer.append('') diff --git a/tests/test_reporters.py b/tests/test_reporters.py new file mode 100644 index 0000000..1e6d3f4 --- /dev/null +++ b/tests/test_reporters.py @@ -0,0 +1,26 @@ +import pytest +from seqeval.reporters import DictReporter, StringReporter + + +@pytest.mark.parametrize( + 'rows, expected', + [ + ([], {}), + ( + [['PERSON', 0.82, 0.79, 0.81, 24]], + { + 'PERSON': { + 'precision': 0.82, + 'recall': 0.79, + 'f1-score': 0.81, + 'support': 24 + } + } + ) + ] +) +def test_dict_reporter_output(rows, expected): + reporter = DictReporter() + for row in rows: + reporter.write(*row) + assert reporter.report() == expected From 78bda02b7c10d0ee806586c50d37ec5906a09d7a Mon Sep 17 00:00:00 2001 From: Hironsan Date: Fri, 9 Oct 2020 08:11:06 +0900 Subject: [PATCH 2/2] Update classification_report to use reporters --- seqeval/metrics/sequence_labeling.py | 35 +++++++--------------------- tests/test_reporters.py | 2 +- 2 files changed, 10 insertions(+), 27 deletions(-) diff --git a/seqeval/metrics/sequence_labeling.py b/seqeval/metrics/sequence_labeling.py index e29d4d6..dc6d78e 100644 --- a/seqeval/metrics/sequence_labeling.py +++ b/seqeval/metrics/sequence_labeling.py @@ -12,6 +12,8 @@ import numpy as np +from seqeval.reporters import DictReporter, StringReporter + def get_entities(seq, suffix=False): """Gets entities from sequence. @@ -343,16 +345,11 @@ def classification_report(y_true, y_pred, digits=2, suffix=False, output_dict=Fa avg_types = ['micro avg', 'macro avg', 'weighted avg'] if output_dict: - report_dict = dict() + reporter = DictReporter() else: avg_width = max([len(x) for x in avg_types]) width = max(name_width, avg_width, digits) - headers = ["precision", "recall", "f1-score", "support"] - head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers) - report = head_fmt.format(u'', *headers, width=width) - report += u'\n\n' - - row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n' + reporter = StringReporter(width=width, digits=digits) ps, rs, f1s, s = [], [], [], [] for type_name in sorted(d1.keys()): @@ -366,47 +363,33 @@ def classification_report(y_true, y_pred, digits=2, suffix=False, output_dict=Fa r = nb_correct / nb_true if nb_true > 0 else 0 f1 = 2 * p * r / (p + r) if p + r > 0 else 0 - if output_dict: - report_dict[type_name] = {'precision': p, 'recall': r, 'f1-score': f1, 'support': nb_true} - else: - report += row_fmt.format(*[type_name, p, r, f1, nb_true], width=width, digits=digits) + reporter.write(type_name, p, r, f1, nb_true) ps.append(p) rs.append(r) f1s.append(f1) s.append(nb_true) - if not output_dict: - report += u'\n' + reporter.write_blank() - # compute averages nb_true = np.sum(s) for avg_type in avg_types: if avg_type == 'micro avg': - # micro average p = precision_score(y_true, y_pred, suffix=suffix) r = recall_score(y_true, y_pred, suffix=suffix) f1 = f1_score(y_true, y_pred, suffix=suffix) elif avg_type == 'macro avg': - # macro average p = np.average(ps) r = np.average(rs) f1 = np.average(f1s) elif avg_type == 'weighted avg': - # weighted average p = np.average(ps, weights=s) r = np.average(rs, weights=s) f1 = np.average(f1s, weights=s) else: assert False, "unexpected average: {}".format(avg_type) + reporter.write(avg_type, p, r, f1, nb_true) + reporter.write_blank() - if output_dict: - report_dict[avg_type] = {'precision': p, 'recall': r, 'f1-score': f1, 'support': nb_true} - else: - report += row_fmt.format(*[avg_type, p, r, f1, nb_true], width=width, digits=digits) - - if output_dict: - return report_dict - else: - return report + return reporter.report() diff --git a/tests/test_reporters.py b/tests/test_reporters.py index 1e6d3f4..4a85be5 100644 --- a/tests/test_reporters.py +++ b/tests/test_reporters.py @@ -1,5 +1,5 @@ import pytest -from seqeval.reporters import DictReporter, StringReporter +from seqeval.reporters import DictReporter @pytest.mark.parametrize(