
Commit

Merge pull request #59 from chakki-works/enhancement/checkConsistentLength

Add length check to classification_report v1
Hironsan committed Oct 13, 2020
2 parents 8305f13 + 76e5248 commit 8dd9f67
Showing 2 changed files with 69 additions and 3 deletions.
9 changes: 7 additions & 2 deletions seqeval/metrics/v1.py
@@ -92,8 +92,11 @@ def check_consistent_length(y_true: List[List[str]], y_pred: List[List[str]]):
     """
     len_true = list(map(len, y_true))
     len_pred = list(map(len, y_pred))
-    is_list = set(map(type, y_true + y_pred))
-    if len(y_true) != len(y_pred) or len_true != len_pred or not is_list == {list}:
+    is_list = set(map(type, y_true)) | set(map(type, y_pred))
+    if not is_list == {list}:
+        raise TypeError('Found input variables without list of list.')
+
+    if len(y_true) != len(y_pred) or len_true != len_pred:
         message = 'Found input variables with inconsistent numbers of samples:\n{}\n{}'.format(len_true, len_pred)
         raise ValueError(message)
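
The validation is now split in two stages: inputs that are not lists of lists raise TypeError before any lengths are compared, and shape mismatches raise ValueError. A minimal sketch of the new behavior (the labels are illustrative):

    from seqeval.metrics.v1 import check_consistent_length

    # Lists of lists with matching shapes pass silently.
    check_consistent_length([['B', 'I']], [['B', 'O']])

    # Non-list sequences (here tuples) now fail the type check first.
    try:
        check_consistent_length([('B', 'I')], [('B', 'O')])
    except TypeError as e:
        print(e)  # Found input variables without list of list.

    # Matching types but mismatched per-sequence lengths hit the length check.
    try:
        check_consistent_length([['B'], ['I', 'O']], [['B'], ['I']])
    except ValueError as e:
        print(e)  # ...inconsistent numbers of samples, then [1, 2] and [1, 1]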

@@ -338,6 +341,8 @@ def classification_report(y_true: List[List[str]],
         weighted avg       0.50      0.50      0.50         2
     <BLANKLINE>
     """
+    check_consistent_length(y_true, y_pred)
+
     if scheme is None or not issubclass(scheme, Token):
         scheme = auto_detect(y_true, suffix)
     target_names = unique_labels(y_true, y_pred, scheme, suffix)
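
With check_consistent_length called at the top of classification_report, malformed inputs now fail fast instead of silently producing scores from misaligned sequences. A hedged usage sketch (tag names are illustrative):

    from seqeval.metrics.v1 import classification_report

    y_true = [['B-PER', 'I-PER', 'O']]
    y_pred = [['B-PER', 'I-PER']]  # one token short

    try:
        classification_report(y_true, y_pred)
    except ValueError as e:
        print(e)  # ...inconsistent numbers of samples, then [3] and [2]
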
63 changes: 62 additions & 1 deletion tests/test_v1.py
@@ -4,7 +4,7 @@
 from sklearn.utils._testing import (assert_array_almost_equal,
                                     assert_array_equal)

-from seqeval.metrics.v1 import (classification_report,
+from seqeval.metrics.v1 import (check_consistent_length, classification_report,
                                 precision_recall_fscore_support, unique_labels)
 from seqeval.scheme import IOB2
@@ -26,6 +26,43 @@ def test_unique_labels(y_true, y_pred, expected):
     assert labels == expected


+class TestCheckConsistentLength:
+
+    @pytest.mark.parametrize(
+        'y_true, y_pred',
+        [
+            ([[]], [[]]),
+            ([['B']], [['B']])
+        ]
+    )
+    def test_check_valid_list(self, y_true, y_pred):
+        check_consistent_length(y_true, y_pred)
+
+    @pytest.mark.parametrize(
+        'y_true, y_pred',
+        [
+            ([()], [()]),
+            (np.array([[]]), np.array([[]])),
+            (np.array([[]]), [[]])
+        ]
+    )
+    def test_check_invalid_type(self, y_true, y_pred):
+        with pytest.raises(TypeError):
+            check_consistent_length(y_true, y_pred)
+
+    @pytest.mark.parametrize(
+        'y_true, y_pred',
+        [
+            ([[]], [['B']]),
+            ([['B'], []], [['B']]),
+            ([['B'], []], [['B'], ['I']])
+        ]
+    )
+    def test_invalid_length(self, y_true, y_pred):
+        with pytest.raises(ValueError):
+            check_consistent_length(y_true, y_pred)
+
+
 class TestPrecisionRecallFscoreSupport:

     def test_bad_beta(self):
@@ -101,6 +138,30 @@ def test_average_scores_beta_inf(self, average, expected):

 class TestClassificationReport:

+    @pytest.mark.parametrize(
+        'y_true, y_pred',
+        [
+            ([()], [()]),
+            (np.array([[]]), np.array([[]])),
+            (np.array([[]]), [[]])
+        ]
+    )
+    def test_check_invalid_type(self, y_true, y_pred):
+        with pytest.raises(TypeError):
+            classification_report(y_true, y_pred)
+
+    @pytest.mark.parametrize(
+        'y_true, y_pred',
+        [
+            ([[]], [['B']]),
+            ([['B'], []], [['B']]),
+            ([['B'], []], [['B'], ['I']])
+        ]
+    )
+    def test_invalid_length(self, y_true, y_pred):
+        with pytest.raises(ValueError):
+            classification_report(y_true, y_pred)
+
     def test_output_dict(self):
         y_true = [['B-A', 'B-B', 'O', 'B-A']]
         y_pred = [['O', 'B-B', 'B-C', 'B-A']]
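
For anyone verifying the change locally, the new cases can be selected with pytest's keyword filter (path as in the diff header; this -k expression is one way to match them):

    pytest tests/test_v1.py -k "CheckConsistentLength or invalid_type or invalid_length"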
