From 008789cc80c51a88f53b7d53b947d00667b38408 Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Tue, 5 Sep 2023 21:59:07 +0100 Subject: [PATCH 1/5] Add type hints to CleverCSV --- Makefile | 8 +- clevercsv/__version__.py | 6 +- clevercsv/_optional.py | 7 +- clevercsv/_types.py | 50 ++ clevercsv/break_ties.py | 69 ++- clevercsv/cabstraction.pyi | 11 + clevercsv/consistency.py | 15 +- clevercsv/console/commands/detect.py | 5 +- clevercsv/console/commands/view.py | 31 +- clevercsv/cparser.pyi | 54 ++ clevercsv/cparser_util.py | 91 +++- clevercsv/cparser_util.pyi | 69 +++ clevercsv/detect.py | 14 +- clevercsv/detect_pattern.py | 22 +- clevercsv/detect_type.py | 43 +- clevercsv/dialect.py | 48 +- clevercsv/dict_read_write.py | 96 ++-- clevercsv/encoding.py | 7 +- clevercsv/escape.py | 14 +- clevercsv/py.typed | 0 clevercsv/read.py | 49 +- clevercsv/wrappers.py | 120 +++-- clevercsv/write.py | 34 +- pyproject.toml | 7 + setup.py | 3 +- stubs/pandas/__init__.pyi | 119 +++++ stubs/pythonfuzz/__init__.pyi | 0 stubs/pythonfuzz/main.pyi | 6 + stubs/regex/__init__.pyi | 61 +++ stubs/regex/_regex.pyi | 13 + stubs/regex/_regex_core.pyi | 503 ++++++++++++++++++ stubs/regex/regex.pyi | 189 +++++++ stubs/tabview/__init__.pyi | 1 + stubs/tabview/tabview.pyi | 163 ++++++ stubs/termcolor/__init__.pyi | 22 + stubs/wilderness/__init__.pyi | 168 ++++++ .../test_dialect_detection.py | 3 +- tests/test_unit/test_console.py | 60 ++- tests/test_unit/test_detect_type.py | 33 +- tests/test_unit/test_dict.py | 18 +- tests/test_unit/test_wrappers.py | 9 +- tests/test_unit/test_write.py | 7 - 42 files changed, 1988 insertions(+), 260 deletions(-) create mode 100644 clevercsv/_types.py create mode 100644 clevercsv/cabstraction.pyi create mode 100644 clevercsv/cparser.pyi create mode 100644 clevercsv/cparser_util.pyi create mode 100644 clevercsv/py.typed create mode 100644 stubs/pandas/__init__.pyi create mode 100644 stubs/pythonfuzz/__init__.pyi create mode 100644 stubs/pythonfuzz/main.pyi create mode 100644 stubs/regex/__init__.pyi create mode 100644 stubs/regex/_regex.pyi create mode 100644 stubs/regex/_regex_core.pyi create mode 100644 stubs/regex/regex.pyi create mode 100644 stubs/tabview/__init__.pyi create mode 100644 stubs/tabview/tabview.pyi create mode 100644 stubs/termcolor/__init__.pyi create mode 100644 stubs/wilderness/__init__.pyi diff --git a/Makefile b/Makefile index 3cba1509..4317c169 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ MAKEFLAGS += --no-builtin-rules PACKAGE=clevercsv DOC_DIR=./docs/ -VENV_DIR=/tmp/clevercsv_venv/ +VENV_DIR=/tmp/clevercsv_venv PYTHON ?= python .PHONY: help @@ -51,7 +51,7 @@ dist: man ## Make Python source distribution .PHONY: test integration integration_partial -test: green pytest +test: mypy green pytest green: venv ## Run unit tests source $(VENV_DIR)/bin/activate && green -a -vv ./tests/test_unit @@ -59,6 +59,10 @@ green: venv ## Run unit tests pytest: venv ## Run unit tests with PyTest source $(VENV_DIR)/bin/activate && pytest -ra -m 'not network' +mypy: venv ## Run type checks + source $(VENV_DIR)/bin/activate && \ + mypy --check-untyped-defs ./stubs $(PACKAGE) ./tests + integration: venv ## Run integration tests source $(VENV_DIR)/bin/activate && python ./tests/test_integration/test_dialect_detection.py -v diff --git a/clevercsv/__version__.py b/clevercsv/__version__.py index 2b9f8268..e054b220 100644 --- a/clevercsv/__version__.py +++ b/clevercsv/__version__.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- -VERSION = (0, 8, 0) +from typing import Tuple -__version__ = ".".join(map(str, VERSION)) +VERSION: Tuple[int, int, int] = (0, 8, 0) + +__version__: str = ".".join(map(str, VERSION)) diff --git a/clevercsv/_optional.py b/clevercsv/_optional.py index 922f44b1..77588e84 100644 --- a/clevercsv/_optional.py +++ b/clevercsv/_optional.py @@ -13,9 +13,12 @@ import importlib +from types import ModuleType + from typing import Dict from typing import List from typing import NamedTuple +from typing import Optional from packaging.version import Version @@ -35,7 +38,9 @@ class OptionalDependency(NamedTuple): ] -def import_optional_dependency(name, raise_on_missing=True): +def import_optional_dependency( + name: str, raise_on_missing: bool = True +) -> Optional[ModuleType]: """ Import an optional dependency. diff --git a/clevercsv/_types.py b/clevercsv/_types.py new file mode 100644 index 00000000..e14cc006 --- /dev/null +++ b/clevercsv/_types.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import csv +import os +import sys + +from typing import TYPE_CHECKING +from typing import Any +from typing import Mapping +from typing import Type +from typing import Union + +from clevercsv.dialect import SimpleDialect + +AnyPath = Union[str, bytes, os.PathLike[str], os.PathLike[bytes]] +_OpenFile = Union[AnyPath, int] +_DictRow = Mapping[str, Any] +_DialectLike = Union[str, csv.Dialect, Type[csv.Dialect], SimpleDialect] + +if sys.version_info >= (3, 8): + from typing import Dict as _DictReadMapping +else: + from collections import OrderedDict as _DictReadMapping + + +if TYPE_CHECKING: + from _typeshed import FileDescriptorOrPath # NOQA + from _typeshed import SupportsIter # NOQA + from _typeshed import SupportsWrite # NOQA + + __all__ = [ + "SupportsWrite", + "SupportsIter", + "FileDescriptorOrPath", + "AnyPath", + "_OpenFile", + "_DictRow", + "_DialectLike", + "_DictReadMapping", + ] +else: + __all__ = [ + "AnyPath", + "_OpenFile", + "_DictRow", + "_DialectLike", + "_DictReadMapping", + ] diff --git a/clevercsv/break_ties.py b/clevercsv/break_ties.py index 017fc631..490adfd7 100644 --- a/clevercsv/break_ties.py +++ b/clevercsv/break_ties.py @@ -7,12 +7,17 @@ """ +from typing import List +from typing import Optional + from .cparser_util import parse_string from .dialect import SimpleDialect from .utils import pairwise -def tie_breaker(data, dialects): +def tie_breaker( + data: str, dialects: List[SimpleDialect] +) -> Optional[SimpleDialect]: """ Break ties between dialects. @@ -42,7 +47,9 @@ def tie_breaker(data, dialects): return None -def reduce_pairwise(data, dialects): +def reduce_pairwise( + data: str, dialects: List[SimpleDialect] +) -> Optional[List[SimpleDialect]]: """Reduce the set of dialects by breaking pairwise ties Parameters @@ -62,7 +69,7 @@ def reduce_pairwise(data, dialects): """ equal_delim = len(set([d.delimiter for d in dialects])) == 1 if not equal_delim: - return None + return None # TODO: This might be wrong, it can just return the input! # First, identify dialects that result in the same parsing result. equal_dialects = [] @@ -99,7 +106,9 @@ def _dialects_only_differ_in_field( ) -def break_ties_two(data, A, B): +def break_ties_two( + data: str, A: SimpleDialect, B: SimpleDialect +) -> Optional[SimpleDialect]: """Break ties between two dialects. This function breaks ties between two dialects that give the same score. We @@ -152,7 +161,7 @@ def break_ties_two(data, A, B): # quotechar has an effect return d_yes elif _dialects_only_differ_in_field(A, B, "delimiter"): - if sorted([A.delimiter, B.delimiter]) == sorted([",", " "]): + if set([A.delimiter, B.delimiter]) == set([",", " "]): # Artifact due to type detection (comma as radix point) if A.delimiter == ",": return A @@ -175,14 +184,14 @@ def break_ties_two(data, A, B): # we can't break this tie (for now) if len(X) != len(Y): return None - for x, y in zip(X, Y): - if len(x) != len(y): + for row_X, row_Y in zip(X, Y): + if len(row_X) != len(row_Y): return None cells_escaped = [] cells_unescaped = [] - for x, y in zip(X, Y): - for u, v in zip(x, y): + for row_X, row_Y in zip(X, Y): + for u, v in zip(row_X, row_Y): if u != v: cells_unescaped.append(u) cells_escaped.append(v) @@ -221,16 +230,18 @@ def break_ties_two(data, A, B): if len(X) != len(Y): return None - for x, y in zip(X, Y): - if len(x) != len(y): + for row_X, row_Y in zip(X, Y): + if len(row_X) != len(row_Y): return None # if we're here, then there is no effect on structure. # we test if the only cells that differ are those that have an # escapechar+quotechar combination. + assert isinstance(d_yes.escapechar, str) + assert isinstance(d_yes.quotechar, str) eq = d_yes.escapechar + d_yes.quotechar - for rX, rY in zip(X, Y): - for x, y in zip(rX, rY): + for row_X, row_Y in zip(X, Y): + for x, y in zip(row_X, row_Y): if x != y: if eq not in x: return None @@ -243,7 +254,9 @@ def break_ties_two(data, A, B): return None -def break_ties_three(data, A, B, C): +def break_ties_three( + data: str, A: SimpleDialect, B: SimpleDialect, C: SimpleDialect +) -> Optional[SimpleDialect]: """Break ties between three dialects. If the delimiters and the escape characters are all equal, then we look for @@ -273,7 +286,7 @@ def break_ties_three(data, A, B, C): Returns ------- - dialect: SimpleDialect + dialect: Optional[SimpleDialect] The chosen dialect if the tie can be broken, None otherwise. Notes @@ -307,6 +320,7 @@ def break_ties_three(data, A, B, C): ) if p_none is None: return None + assert d_none is not None rem = [ (p, d) for p, d in zip([pA, pB, pC], dialects) if not p == p_none @@ -318,6 +332,8 @@ def break_ties_three(data, A, B, C): # the CSV paper. When fixing the delimiter to Tab, rem = []. # Try to reduce pairwise new_dialects = reduce_pairwise(data, dialects) + if new_dialects is None: + return None if len(new_dialects) == 1: return new_dialects[0] return None @@ -347,7 +363,9 @@ def break_ties_three(data, A, B, C): return None -def break_ties_four(data, dialects): +def break_ties_four( + data: str, dialects: List[SimpleDialect] +) -> Optional[SimpleDialect]: """Break ties between four dialects. This function works by breaking the ties between pairs of dialects that @@ -368,7 +386,7 @@ def break_ties_four(data, dialects): Returns ------- - dialect: SimpleDialect + dialect: Optional[SimpleDialect] The chosen dialect if the tie can be broken, None otherwise. Notes @@ -378,19 +396,22 @@ def break_ties_four(data, dialects): examples are found. """ + # TODO: Check for length 4, handle more than 4 too? equal_delim = len(set([d.delimiter for d in dialects])) == 1 if not equal_delim: return None - dialects = reduce_pairwise(data, dialects) + reduced_dialects = reduce_pairwise(data, dialects) + if reduced_dialects is None: + return None # Defer to other functions if the number of dialects was reduced - if len(dialects) == 1: - return dialects[0] - elif len(dialects) == 2: - return break_ties_two(data, *dialects) - elif len(dialects) == 3: - return break_ties_three(data, *dialects) + if len(reduced_dialects) == 1: + return reduced_dialects[0] + elif len(reduced_dialects) == 2: + return break_ties_two(data, *reduced_dialects) + elif len(reduced_dialects) == 3: + return break_ties_three(data, *reduced_dialects) return None diff --git a/clevercsv/cabstraction.pyi b/clevercsv/cabstraction.pyi new file mode 100644 index 00000000..df3228e7 --- /dev/null +++ b/clevercsv/cabstraction.pyi @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +def base_abstraction( + data: str, + delimiter: Optional[str], + quotechar: Optional[str], + escapechar: Optional[str], +) -> str: ... +def c_merge_with_quotechar(data: str) -> str: ... diff --git a/clevercsv/consistency.py b/clevercsv/consistency.py index 6a80d711..71d2afdb 100644 --- a/clevercsv/consistency.py +++ b/clevercsv/consistency.py @@ -89,7 +89,7 @@ def cached_is_known_type(cell: str, is_quoted: bool) -> bool: def detect( self, data: str, delimiters: Optional[Iterable[str]] = None - ) -> None: + ) -> Optional[SimpleDialect]: """Detect the dialect using the consistency measure Parameters @@ -184,8 +184,11 @@ def get_best_dialects( ) -> List[SimpleDialect]: """Identify the dialects with the highest consistency score""" Qscores = [score.Q for score in scores.values()] - Qscores = list(filter(lambda q: q is not None, Qscores)) - Qmax = max(Qscores) + Qmax = -float("inf") + for q in Qscores: + if q is None: + continue + Qmax = max(Qmax, q) return [d for d, score in scores.items() if score.Q == Qmax] def compute_type_score( @@ -194,6 +197,7 @@ def compute_type_score( """Compute the type score""" total = known = 0 for row in parse_string(data, dialect, return_quoted=True): + assert all(isinstance(cell, tuple) for cell in row) for cell, is_quoted in row: total += 1 known += self._cached_is_known_type(cell, is_quoted=is_quoted) @@ -203,7 +207,10 @@ def compute_type_score( def detect_dialect_consistency( - data, delimiters=None, skip=True, verbose=False + data: str, + delimiters: Optional[Iterable[str]] = None, + skip: bool = True, + verbose: bool = False, ): """Helper function that wraps ConsistencyDetector""" # Mostly kept for backwards compatibility diff --git a/clevercsv/console/commands/detect.py b/clevercsv/console/commands/detect.py index 590338a4..74775d90 100644 --- a/clevercsv/console/commands/detect.py +++ b/clevercsv/console/commands/detect.py @@ -4,6 +4,9 @@ import sys import time +from typing import Any +from typing import Dict + from wilderness import Command from clevercsv.wrappers import detect_dialect @@ -125,7 +128,7 @@ def handle(self): if self.args.add_runtime: print(f"runtime = {runtime}") elif self.args.json: - dialect_dict = dialect.to_dict() + dialect_dict: Dict[str, Any] = dialect.to_dict() if self.args.add_runtime: dialect_dict["runtime"] = runtime print(json.dumps(dialect_dict)) diff --git a/clevercsv/console/commands/view.py b/clevercsv/console/commands/view.py index dbce0916..df09b509 100644 --- a/clevercsv/console/commands/view.py +++ b/clevercsv/console/commands/view.py @@ -2,19 +2,9 @@ import sys -try: - import tabview -except ImportError: - - class TabView: - def view(*args, **kwargs): - print( - "Error: unfortunately Tabview is not available on Windows.", - file=sys.stderr, - ) - - tabview = TabView() - +from typing import List +from typing import Optional +from typing import Sequence from wilderness import Command @@ -61,6 +51,17 @@ def register(self): help="Transpose the columns of the input file before viewing", ) + def _tabview(self, rows) -> None: + try: + from tabview import view + except ImportError: + print( + "Error: unfortunately Tabview is not available on Windows.", + file=sys.stderr, + ) + return + view(rows) + def handle(self) -> int: verbose = self.args.verbose num_chars = parse_int(self.args.num_chars, "num-chars") @@ -77,7 +78,7 @@ def handle(self) -> int: if self.args.transpose: max_row_length = max(map(len, rows)) - fixed_rows = [] + fixed_rows: List[Sequence[Optional[str]]] = [] for row in rows: if len(row) == max_row_length: fixed_rows.append(row) @@ -86,5 +87,5 @@ def handle(self) -> int: row + [None] * (max_row_length - len(row)) ) rows = list(map(list, zip(*fixed_rows))) - tabview.view(rows) + self._tabview(rows) return 0 diff --git a/clevercsv/cparser.pyi b/clevercsv/cparser.pyi new file mode 100644 index 00000000..53e3a064 --- /dev/null +++ b/clevercsv/cparser.pyi @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Final +from typing import Generic +from typing import Iterable +from typing import List +from typing import Literal +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import overload + +_T = TypeVar("_T") + +class Parser(Generic[_T]): + _return_quoted: Final[bool] + + @overload + def __init__( + self: Parser[List[Tuple[str, bool]]], + delimiter: Optional[str] = "", + quotechar: Optional[str] = "", + escapechar: Optional[str] = "", + field_limit: Optional[int] = 128 * 1024, + strict: Optional[bool] = False, + return_quoted: Literal[True] = ..., + ) -> None: ... + @overload + def __init__( + self: Parser[List[str]], + delimiter: Optional[str] = "", + quotechar: Optional[str] = "", + escapechar: Optional[str] = "", + field_limit: Optional[int] = 128 * 1024, + strict: Optional[bool] = False, + return_quoted: Literal[False] = ..., + ) -> None: ... + @overload + def __init__( + self, + data: Iterable[str], + delimiter: Optional[str] = "", + quotechar: Optional[str] = "", + escapechar: Optional[str] = "", + field_limit: Optional[int] = 128 * 1024, + strict: Optional[bool] = False, + return_quoted: bool = ..., + ) -> None: ... + def __iter__(self) -> "Parser": ... + def __next__(self) -> _T: ... + +class Error(Exception): ... diff --git a/clevercsv/cparser_util.py b/clevercsv/cparser_util.py index 4ac528ea..484c2573 100644 --- a/clevercsv/cparser_util.py +++ b/clevercsv/cparser_util.py @@ -7,15 +7,23 @@ import io +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + from .cparser import Error as ParserError from .cparser import Parser from .dialect import SimpleDialect from .exceptions import Error -_FIELD_SIZE_LIMIT = 128 * 1024 +_FIELD_SIZE_LIMIT: int = 128 * 1024 -def field_size_limit(*args, **kwargs): +def field_size_limit(*args: Any, **kwargs: Any) -> int: """Get/Set the limit to the field size. This function is adapted from the one in the Python CSV module. See the @@ -23,29 +31,54 @@ def field_size_limit(*args, **kwargs): """ global _FIELD_SIZE_LIMIT old_limit = _FIELD_SIZE_LIMIT - args = list(args) + list(kwargs.values()) - if not 0 <= len(args) <= 1: + all_args = list(args) + list(kwargs.values()) + if not 0 <= len(all_args) <= 1: raise TypeError( - "field_size_limit expected at most 1 arguments, got %i" % len(args) + "field_size_limit expected at most 1 arguments, got %i" + % len(all_args) ) - if len(args) == 0: + if len(all_args) == 0: return old_limit - limit = args[0] + limit = all_args[0] if not isinstance(limit, int): raise TypeError("limit must be an integer") _FIELD_SIZE_LIMIT = int(limit) return old_limit +def _parse_data( + data: Iterable[str], + delimiter: str, + quotechar: str, + escapechar: str, + strict: bool, + return_quoted: bool = False, +) -> Iterator[Union[List[str], List[Tuple[str, bool]]]]: + parser = Parser( + data, + delimiter=delimiter, + quotechar=quotechar, + escapechar=escapechar, + field_limit=field_size_limit(), + strict=strict, + return_quoted=return_quoted, + ) + try: + for row in parser: + yield row + except ParserError as e: + raise Error(str(e)) + + def parse_data( - data, - dialect=None, - delimiter=None, - quotechar=None, - escapechar=None, - strict=None, - return_quoted=False, -): + data: Iterable[str], + dialect: Optional[SimpleDialect] = None, + delimiter: Optional[str] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + strict: Optional[bool] = None, + return_quoted: bool = False, +) -> Iterator[Union[List[str], List[Tuple[str, bool]]]]: """Parse the data given a dialect using the C parser Parameters @@ -96,22 +129,24 @@ def parse_data( escapechar_ = escapechar if escapechar is not None else dialect.escapechar strict_ = strict if strict is not None else dialect.strict - parser = Parser( + yield from _parse_data( data, - delimiter=delimiter_, - quotechar=quotechar_, - escapechar=escapechar_, - field_limit=field_size_limit(), - strict=strict_, + delimiter_, + quotechar_, + escapechar_, + strict_, return_quoted=return_quoted, ) - try: - for row in parser: - yield row - except ParserError as e: - raise Error(str(e)) -def parse_string(data, *args, **kwargs): +def parse_string( + data: str, + dialect: SimpleDialect, + return_quoted: bool = False, +) -> Iterator[Union[List[str], List[Tuple[str, bool]]]]: """Utility for when the CSV file is encoded as a single string""" - return parse_data(io.StringIO(data, newline=""), *args, **kwargs) + return parse_data( + iter(io.StringIO(data, newline="")), + dialect=dialect, + return_quoted=return_quoted, + ) diff --git a/clevercsv/cparser_util.pyi b/clevercsv/cparser_util.pyi new file mode 100644 index 00000000..e78c21cd --- /dev/null +++ b/clevercsv/cparser_util.pyi @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- + +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Literal +from typing import Optional +from typing import Tuple +from typing import Union +from typing import overload + +from .dialect import SimpleDialect + +def field_size_limit(*args: Any, **kwargs: Any) -> int: ... +@overload +def _parse_data( + data: Iterable[str], + delimiter: str, + quotechar: str, + escapechar: str, + strict: bool, + return_quoted: Literal[False] = ..., +) -> Iterator[List[str]]: ... +@overload +def _parse_data( + data: Iterable[str], + delimiter: str, + quotechar: str, + escapechar: str, + strict: bool, + return_quoted: Literal[True], +) -> Iterator[List[Tuple[str, bool]]]: ... +@overload +def _parse_data( + data: Iterable[str], + delimiter: str, + quotechar: str, + escapechar: str, + strict: bool, + return_quoted: bool = ..., +) -> Iterator[Union[List[str], List[Tuple[str, bool]]]]: ... +def parse_data( + data: Iterable[str], + dialect: Optional[SimpleDialect] = None, + delimiter: Optional[str] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + strict: Optional[bool] = None, + return_quoted: bool = False, +) -> Iterator[Union[List[str], List[Tuple[str, bool]]]]: ... +@overload +def parse_string( + data: str, + dialect: SimpleDialect, + return_quoted: Literal[False] = ..., +) -> Iterator[List[str]]: ... +@overload +def parse_string( + data: str, + dialect: SimpleDialect, + return_quoted: Literal[True], +) -> Iterator[List[Tuple[str, bool]]]: ... +@overload +def parse_string( + data: str, + dialect: SimpleDialect, + return_quoted: bool = ..., +) -> Iterator[Union[List[str], List[Tuple[str, bool]]]]: ... diff --git a/clevercsv/detect.py b/clevercsv/detect.py index 8c5735ff..57c1e619 100644 --- a/clevercsv/detect.py +++ b/clevercsv/detect.py @@ -9,6 +9,10 @@ from io import StringIO +from typing import Dict +from typing import Optional +from typing import Union + from .consistency import ConsistencyDetector from .normal_form import detect_dialect_normal from .read import reader @@ -28,9 +32,6 @@ class Detector: """ - def __init__(self): - pass - def sniff(self, sample, delimiters=None, verbose=False): # Compatibility method for Python return self.detect(sample, delimiters=delimiters, verbose=verbose) @@ -126,10 +127,11 @@ def has_header(self, sample): header = next(rdr) # assume first row is header columns = len(header) - columnTypes = {} + columnTypes: Dict[int, Optional[Union[int, type]]] = {} for i in range(columns): columnTypes[i] = None + thisType: Union[int, type] checked = 0 for row in rdr: # arbitrary number of rows to check, to keep it sane @@ -169,6 +171,10 @@ def has_header(self, sample): else: hasHeader -= 1 else: # attempt typecast + if colType is None: + hasHeader += 1 + continue + try: colType(header[col]) except (ValueError, TypeError): diff --git a/clevercsv/detect_pattern.py b/clevercsv/detect_pattern.py index 1b26bbea..dc8b49bd 100644 --- a/clevercsv/detect_pattern.py +++ b/clevercsv/detect_pattern.py @@ -10,24 +10,28 @@ import collections import re +from typing import Optional from typing import Pattern from .cabstraction import base_abstraction from .cabstraction import c_merge_with_quotechar +from .dialect import SimpleDialect -DEFAULT_EPS_PAT = 1e-3 +DEFAULT_EPS_PAT: float = 1e-3 RE_MULTI_C: Pattern = re.compile(r"C{2,}") -def pattern_score(data, dialect, eps=DEFAULT_EPS_PAT): +def pattern_score( + data: str, dialect: SimpleDialect, eps: float = DEFAULT_EPS_PAT +) -> float: """ Compute the pattern score for given data and a dialect. Parameters ---------- - data : string + data : str The data of the file as a raw character string dialect: dialect.Dialect @@ -41,7 +45,7 @@ def pattern_score(data, dialect, eps=DEFAULT_EPS_PAT): """ A = make_abstraction(data, dialect) row_patterns = collections.Counter(A.split("R")) - P = 0 + P = 0.0 for pat_k, Nk in row_patterns.items(): Lk = len(pat_k.split("D")) P += Nk * (max(eps, Lk - 1) / Lk) @@ -49,7 +53,7 @@ def pattern_score(data, dialect, eps=DEFAULT_EPS_PAT): return P -def make_abstraction(data, dialect): +def make_abstraction(data: str, dialect: SimpleDialect) -> str: """Create an abstract representation of the CSV file based on the dialect. This function constructs the basic abstraction used to compute the row @@ -78,7 +82,9 @@ def make_abstraction(data, dialect): return A -def merge_with_quotechar(S, dialect=None): +def merge_with_quotechar( + S: str, dialect: Optional[SimpleDialect] = None +) -> str: """Merge quoted blocks in the abstraction This function takes the abstract representation and merges quoted blocks @@ -103,7 +109,7 @@ def merge_with_quotechar(S, dialect=None): return c_merge_with_quotechar(S) -def fill_empties(abstract): +def fill_empties(abstract: str) -> str: """Fill empty cells in the abstraction The way the row patterns are constructed assumes that empty cells are @@ -143,7 +149,7 @@ def fill_empties(abstract): return abstract -def strip_trailing(abstract): +def strip_trailing(abstract: str) -> str: """Strip trailing row separator from abstraction.""" while abstract.endswith("R"): abstract = abstract[:-1] diff --git a/clevercsv/detect_type.py b/clevercsv/detect_type.py index abc7f69d..97fe116e 100644 --- a/clevercsv/detect_type.py +++ b/clevercsv/detect_type.py @@ -10,6 +10,7 @@ import json from typing import Dict +from typing import List from typing import Optional from typing import Pattern @@ -19,7 +20,7 @@ DEFAULT_EPS_TYPE = 1e-10 -class TypeDetector(object): +class TypeDetector: def __init__( self, patterns: Optional[Dict[str, Pattern]] = None, @@ -48,26 +49,27 @@ def _register_type_tests(self): ("json", self.is_json_obj), ] - def list_known_types(self): + def list_known_types(self) -> List[str]: return [tt[0] for tt in self._type_tests] - def is_known_type(self, cell, is_quoted=False): + def is_known_type(self, cell: str, is_quoted: bool = False) -> bool: return self.detect_type(cell, is_quoted=is_quoted) is not None - def detect_type(self, cell, is_quoted=False): + def detect_type(self, cell: str, is_quoted: bool = False): cell = cell.strip() if self.strip_whitespace else cell for name, func in self._type_tests: if func(cell, is_quoted=is_quoted): return name return None - def _run_regex(self, cell, patname): + def _run_regex(self, cell: str, patname: str) -> bool: cell = cell.strip() if self.strip_whitespace else cell pat = self.patterns.get(patname, None) + assert pat is not None match = pat.fullmatch(cell) return match is not None - def is_number(self, cell, **kwargs): + def is_number(self, cell: str, is_quoted: bool = False) -> bool: if cell == "": return False if self._run_regex(cell, "number_1"): @@ -78,21 +80,21 @@ def is_number(self, cell, **kwargs): return True return False - def is_ipv4(self, cell, **kwargs): + def is_ipv4(self, cell: str, is_quoted: bool = False) -> bool: return self._run_regex(cell, "ipv4") - def is_url(self, cell, **kwargs): + def is_url(self, cell: str, is_quoted: bool = False) -> bool: return self._run_regex(cell, "url") - def is_email(self, cell, **kwargs): + def is_email(self, cell: str, is_quoted: bool = False) -> bool: return self._run_regex(cell, "email") - def is_unicode_alphanum(self, cell, is_quoted=False, **kwargs): + def is_unicode_alphanum(self, cell: str, is_quoted: bool = False) -> bool: if is_quoted: return self._run_regex(cell, "unicode_alphanum_quoted") return self._run_regex(cell, "unicode_alphanum") - def is_date(self, cell, **kwargs): + def is_date(self, cell: str, is_quoted: bool = False) -> bool: # This function assumes the cell is not a number. cell = cell.strip() if self.strip_whitespace else cell if not cell: @@ -101,7 +103,7 @@ def is_date(self, cell, **kwargs): return False return self._run_regex(cell, "date") - def is_time(self, cell, **kwargs): + def is_time(self, cell: str, is_quoted: bool = False) -> bool: cell = cell.strip() if self.strip_whitespace else cell if not cell: return False @@ -114,14 +116,15 @@ def is_time(self, cell, **kwargs): or self._run_regex(cell, "time_hhmmsszz") ) - def is_empty(self, cell, **kwargs): + def is_empty(self, cell: str, is_quoted: bool = False) -> bool: return cell == "" - def is_percentage(self, cell, **kwargs): + def is_percentage(self, cell: str, is_quoted: bool = False) -> bool: return cell.endswith("%") and self.is_number(cell.rstrip("%")) - def is_currency(self, cell, **kwargs): + def is_currency(self, cell: str, is_quoted: bool = False) -> bool: pat = self.patterns.get("currency", None) + assert pat is not None m = pat.fullmatch(cell) if m is None: return False @@ -130,7 +133,7 @@ def is_currency(self, cell, **kwargs): return False return True - def is_datetime(self, cell, **kwargs): + def is_datetime(self, cell: str, is_quoted: bool = False) -> bool: # Takes care of cells with '[date] [time]' and '[date]T[time]' (iso) if not cell: return False @@ -182,18 +185,18 @@ def is_datetime(self, cell, **kwargs): return True return False - def is_nan(self, cell, **kwargs): + def is_nan(self, cell: str, is_quoted: bool = False) -> bool: if cell.lower() in ["n/a", "na", "nan"]: return True return False - def is_unix_path(self, cell, **kwargs): + def is_unix_path(self, cell: str, is_quoted: bool = False) -> bool: return self._run_regex(cell, "unix_path") - def is_bytearray(self, cell: str, **kwargs) -> bool: + def is_bytearray(self, cell: str, is_quoted: bool = False) -> bool: return cell.startswith("bytearray(b") and cell.endswith(")") - def is_json_obj(self, cell: str, **kwargs) -> bool: + def is_json_obj(self, cell: str, is_quoted: bool = False) -> bool: if not (cell.startswith("{") and cell.endswith("}")): return False try: diff --git a/clevercsv/dialect.py b/clevercsv/dialect.py index e01879ed..023622a3 100644 --- a/clevercsv/dialect.py +++ b/clevercsv/dialect.py @@ -12,13 +12,19 @@ import functools import json +from typing import Any +from typing import Dict +from typing import Optional +from typing import Type +from typing import Union + excel = csv.excel excel_tab = csv.excel_tab unix_dialect = csv.unix_dialect @functools.total_ordering -class SimpleDialect(object): +class SimpleDialect: """ The simplified dialect object. @@ -42,13 +48,19 @@ class SimpleDialect(object): """ - def __init__(self, delimiter, quotechar, escapechar, strict=False): + def __init__( + self, + delimiter: Optional[str], + quotechar: Optional[str], + escapechar: Optional[str], + strict: bool = False, + ): self.delimiter = delimiter self.quotechar = quotechar self.escapechar = escapechar self.strict = strict - def validate(self): + def validate(self) -> None: if self.delimiter is None or len(self.delimiter) > 1: raise ValueError( "Delimiter should be zero or one characters, got: %r" @@ -70,21 +82,26 @@ def validate(self): ) @classmethod - def from_dict(cls, d): - d = cls( + def from_dict( + cls: Type["SimpleDialect"], d: Dict[str, Any] + ) -> "SimpleDialect": + dialect = cls( d["delimiter"], d["quotechar"], d["escapechar"], strict=d["strict"] ) - return d + return dialect @classmethod - def from_csv_dialect(cls, d): + def from_csv_dialect( + cls: Type["SimpleDialect"], d: csv.Dialect + ) -> "SimpleDialect": delimiter = "" if d.delimiter is None else d.delimiter quotechar = "" if d.quoting == csv.QUOTE_NONE else d.quotechar escapechar = "" if d.escapechar is None else d.escapechar return cls(delimiter, quotechar, escapechar, strict=d.strict) - def to_csv_dialect(self): + def to_csv_dialect(self) -> csv.Dialect: class dialect(csv.Dialect): + assert self.delimiter is not None delimiter = self.delimiter quotechar = '"' if self.quotechar == "" else self.quotechar escapechar = None if self.escapechar == "" else self.escapechar @@ -93,10 +110,13 @@ class dialect(csv.Dialect): csv.QUOTE_NONE if self.quotechar == "" else csv.QUOTE_MINIMAL ) skipinitialspace = False + # TODO: We need to set this because it can't be None anymore in + # recent versions of Python + lineterminator = "\n" - return dialect + return dialect() - def to_dict(self): + def to_dict(self) -> Dict[str, Union[str, bool, None]]: self.validate() d = dict( delimiter=self.delimiter, @@ -106,16 +126,16 @@ def to_dict(self): ) return d - def serialize(self): + def serialize(self) -> str: """Serialize dialect to a JSON object""" return json.dumps(self.to_dict()) @classmethod - def deserialize(cls, obj): + def deserialize(cls: Type["SimpleDialect"], obj: str) -> "SimpleDialect": """Deserialize dialect from a JSON object""" return cls.from_dict(json.loads(obj)) - def __repr__(self): + def __repr__(self) -> str: return "SimpleDialect(%r, %r, %r)" % ( self.delimiter, self.quotechar, @@ -125,7 +145,7 @@ def __repr__(self): def __key(self): return (self.delimiter, self.quotechar, self.escapechar, self.strict) - def __hash__(self): + def __hash__(self) -> int: return hash(self.__key()) def __eq__(self, other): diff --git a/clevercsv/dict_read_write.py b/clevercsv/dict_read_write.py index 97a1e7bb..f9cf8adb 100644 --- a/clevercsv/dict_read_write.py +++ b/clevercsv/dict_read_write.py @@ -3,50 +3,75 @@ """ DictReader and DictWriter. -This code is entirely copied from the Python csv module. The only exception is +This code is entirely copied from the Python csv module. The only exception is that it uses the `reader` and `writer` classes from our package. Author: Gertjan van den Burg """ - import warnings from collections import OrderedDict - -from .read import reader -from .write import writer - - -class DictReader(object): +from collections.abc import Collection + +from typing import TYPE_CHECKING +from typing import Any +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import Literal +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import TypeVar +from typing import Union +from typing import cast + +from clevercsv.read import reader +from clevercsv.write import writer + +if TYPE_CHECKING: + from clevercsv._types import SupportsWrite + from clevercsv._types import _DialectLike + from clevercsv._types import _DictReadMapping + +_T = TypeVar("_T") + + +class DictReader( + Generic[_T], Iterator["_DictReadMapping[Union[_T, Any], Union[str, Any]]"] +): def __init__( self, - f, - fieldnames=None, - restkey=None, - restval=None, - dialect="excel", - *args, - **kwds - ): + f: Iterable[str], + fieldnames: Optional[Sequence[_T]] = None, + restkey: Optional[str] = None, + restval: Optional[str] = None, + dialect: "_DialectLike" = "excel", + *args: Any, + **kwds: Any, + ) -> None: self._fieldnames = fieldnames self.restkey = restkey self.restval = restval - self.reader = reader(f, dialect, *args, **kwds) + self.reader: reader = reader(f, dialect, *args, **kwds) self.dialect = dialect self.line_num = 0 - def __iter__(self): + def __iter__(self) -> "DictReader": return self @property - def fieldnames(self): + def fieldnames(self) -> Sequence[_T]: if self._fieldnames is None: try: - self._fieldnames = next(self.reader) + fieldnames = next(self.reader) + self._fieldnames = [cast(_T, f) for f in fieldnames] except StopIteration: pass + assert self._fieldnames is not None + # Note: this was added because I don't think it's expected that Python # simply drops information if there are duplicate headers. There is # discussion on this issue in the Python bug tracker here: @@ -62,10 +87,10 @@ def fieldnames(self): return self._fieldnames @fieldnames.setter - def fieldnames(self, value): + def fieldnames(self, value: Sequence[_T]) -> None: self._fieldnames = value - def __next__(self): + def __next__(self) -> "_DictReadMapping[Union[_T, Any], Union[str, Any]]": if self.line_num == 0: self.fieldnames row = next(self.reader) @@ -73,7 +98,8 @@ def __next__(self): while row == []: row = next(self.reader) - d = OrderedDict(zip(self.fieldnames, row)) + + d: _DictReadMapping = OrderedDict(zip(self.fieldnames, row)) lf = len(self.fieldnames) lr = len(row) if lf < lr: @@ -84,16 +110,16 @@ def __next__(self): return d -class DictWriter(object): +class DictWriter(Generic[_T]): def __init__( self, - f, - fieldnames, - restval="", - extrasaction="raise", - dialect="excel", - *args, - **kwds + f: "SupportsWrite[str]", + fieldnames: Collection[_T], + restval: Optional[Any] = "", + extrasaction: Literal["raise", "ignore"] = "raise", + dialect: "_DialectLike" = "excel", + *args: Any, + **kwds: Any, ): self.fieldnames = fieldnames self.restval = restval @@ -104,11 +130,11 @@ def __init__( self.extrasaction = extrasaction self.writer = writer(f, dialect, *args, **kwds) - def writeheader(self): + def writeheader(self) -> Any: header = dict(zip(self.fieldnames, self.fieldnames)) return self.writerow(header) - def _dict_to_list(self, rowdict): + def _dict_to_list(self, rowdict: Mapping[_T, Any]) -> Iterator[Any]: if self.extrasaction == "raise": wrong_fields = rowdict.keys() - self.fieldnames if wrong_fields: @@ -118,8 +144,8 @@ def _dict_to_list(self, rowdict): ) return (rowdict.get(key, self.restval) for key in self.fieldnames) - def writerow(self, rowdict): + def writerow(self, rowdict: Mapping[_T, Any]) -> Any: return self.writer.writerow(self._dict_to_list(rowdict)) - def writerows(self, rowdicts): + def writerows(self, rowdicts: Iterable[Mapping[_T, Any]]) -> None: return self.writer.writerows(map(self._dict_to_list, rowdicts)) diff --git a/clevercsv/encoding.py b/clevercsv/encoding.py index e1010d0b..68bd259c 100644 --- a/clevercsv/encoding.py +++ b/clevercsv/encoding.py @@ -9,12 +9,17 @@ """ +from typing import Optional + import chardet from ._optional import import_optional_dependency +from ._types import _OpenFile -def get_encoding(filename, try_cchardet=True): +def get_encoding( + filename: _OpenFile, try_cchardet: bool = True +) -> Optional[str]: """Get the encoding of the file This function uses the chardet package for detecting the encoding of a diff --git a/clevercsv/escape.py b/clevercsv/escape.py index d72b0b4d..3ba574b5 100644 --- a/clevercsv/escape.py +++ b/clevercsv/escape.py @@ -11,8 +11,12 @@ import sys import unicodedata +from typing import Iterable +from typing import Optional +from typing import Set + #: Set of default characters to *never* consider as escape character -DEFAULT_BLOCK_CHARS = set( +DEFAULT_BLOCK_CHARS: Set[str] = set( [ "!", "?", @@ -30,7 +34,7 @@ ) #: Set of characters in the Unicode "Po" category -UNICODE_PO_CHARS = set( +UNICODE_PO_CHARS: Set[str] = set( [ c for c in map(chr, range(sys.maxunicode + 1)) @@ -39,7 +43,9 @@ ) -def is_potential_escapechar(char, encoding, block_char=None): +def is_potential_escapechar( + char: str, encoding: str, block_char: Optional[Iterable[str]] = None +) -> bool: """Check if a character is a potential escape character. A character is considered a potential escape character if it is in the @@ -54,7 +60,7 @@ def is_potential_escapechar(char, encoding, block_char=None): encoding : str The encoding of the character - block_char : iterable + block_char : Optional[Iterable[str]] Characters that are in the Punctuation Other category but that should not be considered as escape character. If None, the default set is used, which is defined in :py:data:`DEFAULT_BLOCK_CHARS`. diff --git a/clevercsv/py.typed b/clevercsv/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/clevercsv/read.py b/clevercsv/read.py index 365115bd..90ce3867 100644 --- a/clevercsv/read.py +++ b/clevercsv/read.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ -Drop-in replacement for the Python csv reader class. This is a wrapper for the +Drop-in replacement for the Python csv reader class. This is a wrapper for the Parser class, defined in :mod:`cparser`. Author: Gertjan van den Burg @@ -10,22 +10,40 @@ import csv +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional + from . import field_size_limit +from ._types import _DialectLike from .cparser import Error as ParserError from .cparser import Parser from .dialect import SimpleDialect from .exceptions import Error -class reader(object): - def __init__(self, csvfile, dialect="excel", **fmtparams): +class reader: + def __init__( + self, + csvfile: Iterable[str], + dialect: _DialectLike = "excel", + **fmtparams: Any, + ): self.csvfile = csvfile self.original_dialect = dialect - self.dialect = self._make_simple_dialect(dialect, **fmtparams) - self.line_num = 0 - self.parser_gen = None + self._dialect = self._make_simple_dialect(dialect, **fmtparams) + self.line_num: int = 0 + self.parser_gen: Optional[Parser] = None + + @property + def dialect(self) -> csv.Dialect: + return self._dialect.to_csv_dialect() - def _make_simple_dialect(self, dialect, **fmtparams): + def _make_simple_dialect( + self, dialect: _DialectLike, **fmtparams: Any + ) -> SimpleDialect: if isinstance(dialect, str): sd = SimpleDialect.from_csv_dialect(csv.get_dialect(dialect)) elif isinstance(dialect, csv.Dialect): @@ -40,27 +58,24 @@ def _make_simple_dialect(self, dialect, **fmtparams): sd.validate() return sd - def __iter__(self): + def __iter__(self) -> Iterator[List[str]]: self.parser_gen = Parser( self.csvfile, - delimiter=self.dialect.delimiter, - quotechar=self.dialect.quotechar, - escapechar=self.dialect.escapechar, + delimiter=self._dialect.delimiter, + quotechar=self._dialect.quotechar, + escapechar=self._dialect.escapechar, field_limit=field_size_limit(), - strict=self.dialect.strict, + strict=self._dialect.strict, ) return self - def __next__(self): + def __next__(self) -> List[str]: if self.parser_gen is None: self.__iter__() + assert self.parser_gen is not None try: row = next(self.parser_gen) except ParserError as e: raise Error(str(e)) self.line_num += 1 return row - - def next(self): - # for python 2 - return self.__next__() diff --git a/clevercsv/wrappers.py b/clevercsv/wrappers.py index 486a2bd1..807c33f9 100644 --- a/clevercsv/wrappers.py +++ b/clevercsv/wrappers.py @@ -6,12 +6,23 @@ Author: Gertjan van den Burg """ +from __future__ import annotations import os import warnings +from typing import TYPE_CHECKING +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import TypeVar + from ._optional import import_optional_dependency from .detect import Detector +from .dialect import SimpleDialect from .dict_read_write import DictReader from .dict_read_write import DictWriter from .encoding import get_encoding @@ -19,10 +30,23 @@ from .read import reader from .write import writer +if TYPE_CHECKING: + import pandas as pd + + from ._types import FileDescriptorOrPath + from ._types import _DialectLike + from ._types import _DictReadMapping + +_T = TypeVar("_T") + def stream_dicts( - filename, dialect=None, encoding=None, num_chars=None, verbose=False -): + filename: FileDescriptorOrPath, + dialect: Optional[_DialectLike] = None, + encoding: Optional[str] = None, + num_chars: Optional[int] = None, + verbose: bool = False, +) -> Iterator["_DictReadMapping"]: """Read a CSV file as a generator over dictionaries This function streams the rows of the CSV file as dictionaries. The keys of @@ -71,14 +95,18 @@ def stream_dicts( data = fid.read(num_chars) if num_chars else fid.read() dialect = Detector().detect(data, verbose=verbose) fid.seek(0) - r = DictReader(fid, dialect=dialect) - for row in r: + reader: DictReader = DictReader(fid, dialect=dialect) + for row in reader: yield row def read_dicts( - filename, dialect=None, encoding=None, num_chars=None, verbose=False -): + filename: "FileDescriptorOrPath", + dialect: Optional["_DialectLike"] = None, + encoding: Optional[str] = None, + num_chars: Optional[int] = None, + verbose: bool = False, +) -> List["_DictReadMapping"]: """Read a CSV file as a list of dictionaries This function returns the rows of the CSV file as a list of dictionaries. @@ -132,12 +160,12 @@ def read_dicts( def read_table( - filename, - dialect=None, - encoding=None, - num_chars=None, - verbose=False, -): + filename: "FileDescriptorOrPath", + dialect: Optional["_DialectLike"] = None, + encoding: Optional[str] = None, + num_chars: Optional[int] = None, + verbose: bool = False, +) -> List[List[str]]: """Read a CSV file as a table (a list of lists) This is a convenience function that reads a CSV file and returns the data @@ -191,12 +219,12 @@ def read_table( def stream_table( - filename, - dialect=None, - encoding=None, - num_chars=None, - verbose=False, -): + filename: "FileDescriptorOrPath", + dialect: Optional["_DialectLike"] = None, + encoding: Optional[str] = None, + num_chars: Optional[int] = None, + verbose: bool = False, +) -> Iterator[List[str]]: """Read a CSV file as a generator over rows of a table This is a convenience function that reads a CSV file and returns the data @@ -251,7 +279,12 @@ def stream_table( yield from r -def read_dataframe(filename, *args, num_chars=None, **kwargs): +def read_dataframe( + filename: "FileDescriptorOrPath", + *args: Any, + num_chars: Optional[int] = None, + **kwargs: Any, +) -> pd.DataFrame: """Read a CSV file to a Pandas dataframe This function uses CleverCSV to detect the dialect, and then passes this to @@ -284,6 +317,7 @@ def read_dataframe(filename, *args, num_chars=None, **kwargs): if not (os.path.exists(filename) and os.path.isfile(filename)): raise ValueError("Filename must be a regular file") pd = import_optional_dependency("pandas") + assert pd is not None # Use provided encoding or detect it, and record it for pandas enc = kwargs.get("encoding") or get_encoding(filename) @@ -306,13 +340,13 @@ def read_dataframe(filename, *args, num_chars=None, **kwargs): def detect_dialect( - filename, - num_chars=None, - encoding=None, - verbose=False, - method="auto", - skip=True, -): + filename: "FileDescriptorOrPath", + num_chars: Optional[int] = None, + encoding: Optional[str] = None, + verbose: bool = False, + method: str = "auto", + skip: bool = True, +) -> SimpleDialect: """Detect the dialect of a CSV file This is a utility function that simply returns the detected dialect of a @@ -360,8 +394,12 @@ def detect_dialect( def write_table( - table, filename, dialect="excel", transpose=False, encoding=None -): + table: Iterable[Iterable[Any]], + filename: "FileDescriptorOrPath", + dialect: "_DialectLike" = "excel", + transpose: bool = False, + encoding: Optional[str] = None, +) -> None: """Write a table (a list of lists) to a file This is a convenience function for writing a table to a CSV file. If the @@ -400,17 +438,24 @@ def write_table( return if transpose: - table = list(map(list, zip(*table))) + list_table = list(map(list, zip(*table))) + else: + list_table = list(map(list, table)) - if len(set(map(len, table))) > 1: + if len(set(map(len, list_table))) > 1: raise ValueError("Table doesn't have constant row length.") with open(filename, "w", newline="", encoding=encoding) as fp: w = writer(fp, dialect=dialect) - w.writerows(table) + w.writerows(list_table) -def write_dicts(items, filename, dialect="excel", encoding=None): +def write_dicts( + items: Iterable[Mapping[_T, Any]], + filename: "FileDescriptorOrPath", + dialect: "_DialectLike" = "excel", + encoding: Optional[str] = None, +) -> None: """Write a list of dicts to a file This is a convenience function to write dicts to a file. The header is @@ -440,8 +485,15 @@ def write_dicts(items, filename, dialect="excel", encoding=None): if not items: return - fieldnames = list(items[0].keys()) + iterator = iter(items) + try: + first = next(iterator) + except StopIteration: + return + + fieldnames = list(first.keys()) with open(filename, "w", newline="", encoding=encoding) as fp: w = DictWriter(fp, fieldnames=fieldnames, dialect=dialect) w.writeheader() - w.writerows(items) + w.writerow(first) + w.writerows(iterator) diff --git a/clevercsv/write.py b/clevercsv/write.py index 01472bd1..a0e3403b 100644 --- a/clevercsv/write.py +++ b/clevercsv/write.py @@ -8,8 +8,20 @@ """ +from __future__ import annotations + import csv +from typing import TYPE_CHECKING +from typing import Any +from typing import Iterable +from typing import Type + +if TYPE_CHECKING: + from clevercsv._types import SupportsWrite + +from clevercsv._types import _DialectLike + from .dialect import SimpleDialect from .exceptions import Error @@ -25,13 +37,23 @@ ] -class writer(object): - def __init__(self, csvfile, dialect="excel", **fmtparams): +class writer: + def __init__( + self, + csvfile: SupportsWrite, + dialect: _DialectLike = "excel", + **fmtparams, + ): self.original_dialect = dialect - self.dialect = self._make_python_dialect(dialect, **fmtparams) + self.dialect: Type[csv.Dialect] = self._make_python_dialect( + dialect, **fmtparams + ) self._writer = csv.writer(csvfile, dialect=self.dialect) - def _make_python_dialect(self, dialect, **fmtparams): + def _make_python_dialect( + self, dialect: _DialectLike, **fmtparams + ) -> Type[csv.Dialect]: + d: _DialectLike = "" if isinstance(dialect, str): d = csv.get_dialect(dialect) elif isinstance(dialect, csv.Dialect): @@ -56,13 +78,13 @@ def _make_python_dialect(self, dialect, **fmtparams): newdialect = type("dialect", (csv.Dialect,), props) return newdialect - def writerow(self, row): + def writerow(self, row: Iterable[Any]) -> Any: try: return self._writer.writerow(row) except csv.Error as e: raise Error(str(e)) - def writerows(self, rows): + def writerows(self, rows: Iterable[Iterable[Any]]) -> Any: try: return self._writer.writerows(rows) except csv.Error as e: diff --git a/pyproject.toml b/pyproject.toml index 23130b7a..a834bd6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,3 +11,10 @@ lines_between_types=1 [tool.ruff] # Exclude stubs directory for now exclude = ["stubs"] + +[tool.mypy] +python_version = 3.8 +warn_unused_configs = true + +# [[tool.mypy.overrides]] +# packages = ["stubs", "clevercsv", "tests"] diff --git a/setup.py b/setup.py index 96b20b9d..5adab650 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ LICENSE = "MIT" LICENSE_TROVE = "License :: OSI Approved :: MIT License" NAME = "clevercsv" -REQUIRES_PYTHON = ">=3.6.0" +REQUIRES_PYTHON = ">=3.8.0" URL = "https://github.com/alan-turing-institute/CleverCSV" VERSION = None @@ -113,6 +113,7 @@ def run(self): install_requires=REQUIRED, extras_require=EXTRAS, include_package_data=True, + package_data={"clevercsv": ["py.typed"]}, license=LICENSE, ext_modules=[ Extension("clevercsv.cparser", sources=["src/cparser.c"]), diff --git a/stubs/pandas/__init__.pyi b/stubs/pandas/__init__.pyi new file mode 100644 index 00000000..e84ca9dd --- /dev/null +++ b/stubs/pandas/__init__.pyi @@ -0,0 +1,119 @@ +from typing import Any + +from pandas._config import describe_option as describe_option +from pandas._config import get_option as get_option +from pandas._config import option_context as option_context +from pandas._config import options as options +from pandas._config import reset_option as reset_option +from pandas._config import set_option as set_option +from pandas.core.api import NA as NA +from pandas.core.api import BooleanDtype as BooleanDtype +from pandas.core.api import Categorical as Categorical +from pandas.core.api import CategoricalDtype as CategoricalDtype +from pandas.core.api import CategoricalIndex as CategoricalIndex +from pandas.core.api import DataFrame as DataFrame +from pandas.core.api import DateOffset as DateOffset +from pandas.core.api import DatetimeIndex as DatetimeIndex +from pandas.core.api import DatetimeTZDtype as DatetimeTZDtype +from pandas.core.api import Flags as Flags +from pandas.core.api import Float32Dtype as Float32Dtype +from pandas.core.api import Float64Dtype as Float64Dtype +from pandas.core.api import Float64Index as Float64Index +from pandas.core.api import Grouper as Grouper +from pandas.core.api import Index as Index +from pandas.core.api import IndexSlice as IndexSlice +from pandas.core.api import Int8Dtype as Int8Dtype +from pandas.core.api import Int16Dtype as Int16Dtype +from pandas.core.api import Int32Dtype as Int32Dtype +from pandas.core.api import Int64Dtype as Int64Dtype +from pandas.core.api import Int64Index as Int64Index +from pandas.core.api import Interval as Interval +from pandas.core.api import IntervalDtype as IntervalDtype +from pandas.core.api import IntervalIndex as IntervalIndex +from pandas.core.api import MultiIndex as MultiIndex +from pandas.core.api import NamedAgg as NamedAgg +from pandas.core.api import NaT as NaT +from pandas.core.api import Period as Period +from pandas.core.api import PeriodDtype as PeriodDtype +from pandas.core.api import PeriodIndex as PeriodIndex +from pandas.core.api import RangeIndex as RangeIndex +from pandas.core.api import Series as Series +from pandas.core.api import StringDtype as StringDtype +from pandas.core.api import Timedelta as Timedelta +from pandas.core.api import TimedeltaIndex as TimedeltaIndex +from pandas.core.api import Timestamp as Timestamp +from pandas.core.api import UInt8Dtype as UInt8Dtype +from pandas.core.api import UInt16Dtype as UInt16Dtype +from pandas.core.api import UInt32Dtype as UInt32Dtype +from pandas.core.api import UInt64Dtype as UInt64Dtype +from pandas.core.api import UInt64Index as UInt64Index +from pandas.core.api import array as array +from pandas.core.api import bdate_range as bdate_range +from pandas.core.api import date_range as date_range +from pandas.core.api import factorize as factorize +from pandas.core.api import interval_range as interval_range +from pandas.core.api import isna as isna +from pandas.core.api import isnull as isnull +from pandas.core.api import notna as notna +from pandas.core.api import notnull as notnull +from pandas.core.api import period_range as period_range +from pandas.core.api import set_eng_float_format as set_eng_float_format +from pandas.core.api import timedelta_range as timedelta_range +from pandas.core.api import to_datetime as to_datetime +from pandas.core.api import to_numeric as to_numeric +from pandas.core.api import to_timedelta as to_timedelta +from pandas.core.api import unique as unique +from pandas.core.api import value_counts as value_counts +from pandas.core.arrays.sparse import SparseDtype as SparseDtype +from pandas.core.computation.api import eval as eval +from pandas.core.reshape.api import concat as concat +from pandas.core.reshape.api import crosstab as crosstab +from pandas.core.reshape.api import cut as cut +from pandas.core.reshape.api import get_dummies as get_dummies +from pandas.core.reshape.api import lreshape as lreshape +from pandas.core.reshape.api import melt as melt +from pandas.core.reshape.api import merge as merge +from pandas.core.reshape.api import merge_asof as merge_asof +from pandas.core.reshape.api import merge_ordered as merge_ordered +from pandas.core.reshape.api import pivot as pivot +from pandas.core.reshape.api import pivot_table as pivot_table +from pandas.core.reshape.api import qcut as qcut +from pandas.core.reshape.api import wide_to_long as wide_to_long +from pandas.io.api import ExcelFile as ExcelFile +from pandas.io.api import ExcelWriter as ExcelWriter +from pandas.io.api import HDFStore as HDFStore +from pandas.io.api import read_clipboard as read_clipboard +from pandas.io.api import read_csv as read_csv +from pandas.io.api import read_excel as read_excel +from pandas.io.api import read_feather as read_feather +from pandas.io.api import read_fwf as read_fwf +from pandas.io.api import read_gbq as read_gbq +from pandas.io.api import read_hdf as read_hdf +from pandas.io.api import read_html as read_html +from pandas.io.api import read_json as read_json +from pandas.io.api import read_orc as read_orc +from pandas.io.api import read_parquet as read_parquet +from pandas.io.api import read_pickle as read_pickle +from pandas.io.api import read_sas as read_sas +from pandas.io.api import read_spss as read_spss +from pandas.io.api import read_sql as read_sql +from pandas.io.api import read_sql_query as read_sql_query +from pandas.io.api import read_sql_table as read_sql_table +from pandas.io.api import read_stata as read_stata +from pandas.io.api import read_table as read_table +from pandas.io.api import to_pickle as to_pickle +from pandas.tseries import offsets as offsets +from pandas.tseries.api import infer_freq as infer_freq +from pandas.util._print_versions import show_versions as show_versions +from pandas.util._tester import test as test + +__docformat__: str +hard_dependencies: Any +missing_dependencies: Any +module: Any +v: Any +__git_version__: Any + +def __getattr__(name: Any): ... + +# __doc__: str diff --git a/stubs/pythonfuzz/__init__.pyi b/stubs/pythonfuzz/__init__.pyi new file mode 100644 index 00000000..e69de29b diff --git a/stubs/pythonfuzz/main.pyi b/stubs/pythonfuzz/main.pyi new file mode 100644 index 00000000..f13dbe8b --- /dev/null +++ b/stubs/pythonfuzz/main.pyi @@ -0,0 +1,6 @@ +from typing import Any +from typing import Callable + +class PythonFuzz: + def __init__(self, func: Callable) -> None: ... + def __call__(self, *args: Any, **kwargs: Any) -> None: ... diff --git a/stubs/regex/__init__.pyi b/stubs/regex/__init__.pyi new file mode 100644 index 00000000..b47bdfb5 --- /dev/null +++ b/stubs/regex/__init__.pyi @@ -0,0 +1,61 @@ +from .regex import * + +# Names in __all__ with no definition: +# A +# ASCII +# B +# BESTMATCH +# D +# DEBUG +# DEFAULT_VERSION +# DOTALL +# E +# ENHANCEMATCH +# F +# FULLCASE +# I +# IGNORECASE +# L +# LOCALE +# M +# MULTILINE +# Match +# P +# POSIX +# Pattern +# R +# REVERSE +# Regex +# S +# Scanner +# T +# TEMPLATE +# U +# UNICODE +# V0 +# V1 +# VERBOSE +# VERSION0 +# VERSION1 +# W +# WORD +# X +# __doc__ +# __version__ +# cache_all +# compile +# error +# escape +# findall +# finditer +# fullmatch +# match +# purge +# search +# split +# splititer +# sub +# subf +# subfn +# subn +# template diff --git a/stubs/regex/_regex.pyi b/stubs/regex/_regex.pyi new file mode 100644 index 00000000..6171611d --- /dev/null +++ b/stubs/regex/_regex.pyi @@ -0,0 +1,13 @@ +from typing import Any + +CODE_SIZE: int +MAGIC: int +copyright: str + +def compile(*args, **kwargs) -> Any: ... +def fold_case(*args, **kwargs) -> Any: ... +def get_all_cases(*args, **kwargs) -> Any: ... +def get_code_size(*args, **kwargs) -> Any: ... +def get_expand_on_folding(*args, **kwargs) -> Any: ... +def get_properties(*args, **kwargs) -> Any: ... +def has_property_value(*args, **kwargs) -> Any: ... diff --git a/stubs/regex/_regex_core.pyi b/stubs/regex/_regex_core.pyi new file mode 100644 index 00000000..66cb0e6c --- /dev/null +++ b/stubs/regex/_regex_core.pyi @@ -0,0 +1,503 @@ +from typing import Any as _Any + +class error(Exception): + msg: _Any + pattern: _Any + pos: _Any + lineno: _Any + colno: _Any + def __init__( + self, message, pattern: _Any | None = ..., pos: _Any | None = ... + ) -> None: ... + +class _UnscopedFlagSet(Exception): ... +class ParseError(Exception): ... +class _FirstSetError(Exception): ... + +A: int + +ASCII: int +B: int +BESTMATCH: int +D: int +DEBUG: int +E: int +ENHANCEMATCH: int +F: int +FULLCASE: int +I: int +IGNORECASE: int +L: int +LOCALE: int +M: int +MULTILINE: int +P: int +POSIX: int +R: int +REVERSE: int +S: int +DOTALL: int +U: int +UNICODE: int +V0: int +VERSION0: int +V1: int +VERSION1: int +W: int +WORD: int +X: int +VERBOSE: int +T: int +TEMPLATE: int +DEFAULT_VERSION = VERSION1 + +class Namespace: ... + +class RegexBase: + def __init__(self) -> None: ... + def with_flags( + self, + positive: _Any | None = ..., + case_flags: _Any | None = ..., + zerowidth: _Any | None = ..., + ): ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse) -> None: ... + def has_simple_start(self): ... + def compile(self, reverse: bool = ..., fuzzy: bool = ...): ... + def is_empty(self): ... + def __hash__(self): ... + def __eq__(self, other): ... + def __ne__(self, other): ... + def get_required_string(self, reverse): ... + +class ZeroWidthBase(RegexBase): + positive: _Any + def __init__(self, positive: bool = ...) -> None: ... + def get_firstset(self, reverse): ... + def dump(self, indent, reverse) -> None: ... + def max_width(self): ... + +class Any(RegexBase): + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def max_width(self): ... + +class AnyAll(Any): ... +class AnyU(Any): ... + +class Atomic(RegexBase): + subpattern: _Any + def __init__(self, subpattern) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + def get_required_string(self, reverse): ... + +class Boundary(ZeroWidthBase): ... + +class Branch(RegexBase): + branches: _Any + def __init__(self, branches) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + +class CallGroup(RegexBase): + info: _Any + group: _Any + position: _Any + def __init__(self, info, group, position) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def remove_captures(self) -> None: ... + def dump(self, indent, reverse) -> None: ... + def __eq__(self, other): ... + def max_width(self): ... + def __del__(self) -> None: ... + +class CallRef(RegexBase): + ref: _Any + parsed: _Any + def __init__(self, ref, parsed) -> None: ... + +class Character(RegexBase): + value: _Any + positive: _Any + case_flags: _Any + zerowidth: _Any + folded: _Any + def __init__( + self, + value, + positive: bool = ..., + case_flags=..., + zerowidth: bool = ..., + ) -> None: ... + def rebuild(self, positive, case_flags, zerowidth): ... + def optimise(self, info, reverse, in_set: bool = ...): ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def matches(self, ch): ... + def max_width(self): ... + folded_characters: _Any + def get_required_string(self, reverse): ... + +class Conditional(RegexBase): + info: _Any + group: _Any + yes_item: _Any + no_item: _Any + position: _Any + def __init__(self, info, group, yes_item, no_item, position) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self) -> None: ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + def __del__(self) -> None: ... + +class DefaultBoundary(ZeroWidthBase): ... +class DefaultEndOfWord(ZeroWidthBase): ... +class DefaultStartOfWord(ZeroWidthBase): ... +class EndOfLine(ZeroWidthBase): ... +class EndOfLineU(EndOfLine): ... +class EndOfString(ZeroWidthBase): ... +class EndOfStringLine(ZeroWidthBase): ... +class EndOfStringLineU(EndOfStringLine): ... +class EndOfWord(ZeroWidthBase): ... +class Failure(ZeroWidthBase): ... + +class Fuzzy(RegexBase): + subpattern: _Any + constraints: _Any + def __init__(self, subpattern, constraints: _Any | None = ...) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def contains_group(self): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + +class Grapheme(RegexBase): + def dump(self, indent, reverse) -> None: ... + def max_width(self): ... + +class GraphemeBoundary: + def compile(self, reverse, fuzzy): ... + +class GreedyRepeat(RegexBase): + subpattern: _Any + min_count: _Any + max_count: _Any + def __init__(self, subpattern, min_count, max_count) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + def get_required_string(self, reverse): ... + +class PossessiveRepeat(GreedyRepeat): + def is_atomic(self): ... + def dump(self, indent, reverse) -> None: ... + +class Group(RegexBase): + info: _Any + group: _Any + subpattern: _Any + call_ref: _Any + def __init__(self, info, group, subpattern) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def __eq__(self, other): ... + def max_width(self): ... + def get_required_string(self, reverse): ... + def __del__(self) -> None: ... + +class Keep(ZeroWidthBase): ... +class LazyRepeat(GreedyRepeat): ... + +class LookAround(RegexBase): + behind: _Any + positive: _Any + subpattern: _Any + def __init__(self, behind, positive, subpattern) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + +class LookAroundConditional(RegexBase): + behind: _Any + positive: _Any + subpattern: _Any + yes_item: _Any + no_item: _Any + def __init__( + self, behind, positive, subpattern, yes_item, no_item + ) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self) -> None: ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + def get_required_string(self, reverse): ... + +class PrecompiledCode(RegexBase): + code: _Any + def __init__(self, code) -> None: ... + +class Property(RegexBase): + value: _Any + positive: _Any + case_flags: _Any + zerowidth: _Any + def __init__( + self, + value, + positive: bool = ..., + case_flags=..., + zerowidth: bool = ..., + ) -> None: ... + def rebuild(self, positive, case_flags, zerowidth): ... + def optimise(self, info, reverse, in_set: bool = ...): ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def matches(self, ch): ... + def max_width(self): ... + +class Prune(ZeroWidthBase): ... + +class Range(RegexBase): + lower: _Any + upper: _Any + positive: _Any + case_flags: _Any + zerowidth: _Any + def __init__( + self, + lower, + upper, + positive: bool = ..., + case_flags=..., + zerowidth: bool = ..., + ) -> None: ... + def rebuild(self, positive, case_flags, zerowidth): ... + def optimise(self, info, reverse, in_set: bool = ...): ... + def dump(self, indent, reverse) -> None: ... + def matches(self, ch): ... + def max_width(self): ... + +class RefGroup(RegexBase): + info: _Any + group: _Any + position: _Any + case_flags: _Any + def __init__(self, info, group, position, case_flags=...) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def remove_captures(self) -> None: ... + def dump(self, indent, reverse) -> None: ... + def max_width(self): ... + def __del__(self) -> None: ... + +class SearchAnchor(ZeroWidthBase): ... + +class Sequence(RegexBase): + items: _Any + def __init__(self, items: _Any | None = ...) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + def get_required_string(self, reverse): ... + +class SetBase(RegexBase): + info: _Any + items: _Any + positive: _Any + case_flags: _Any + zerowidth: _Any + char_width: int + def __init__( + self, + info, + items, + positive: bool = ..., + case_flags=..., + zerowidth: bool = ..., + ) -> None: ... + def rebuild(self, positive, case_flags, zerowidth): ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def max_width(self): ... + def __del__(self) -> None: ... + +class SetDiff(SetBase): + items: _Any + def optimise(self, info, reverse, in_set: bool = ...): ... + def matches(self, ch): ... + +class SetInter(SetBase): + items: _Any + def optimise(self, info, reverse, in_set: bool = ...): ... + def matches(self, ch): ... + +class SetSymDiff(SetBase): + items: _Any + def optimise(self, info, reverse, in_set: bool = ...): ... + def matches(self, ch): ... + +class SetUnion(SetBase): + items: _Any + def optimise(self, info, reverse, in_set: bool = ...): ... + def matches(self, ch): ... + +class Skip(ZeroWidthBase): ... +class StartOfLine(ZeroWidthBase): ... +class StartOfLineU(StartOfLine): ... +class StartOfString(ZeroWidthBase): ... +class StartOfWord(ZeroWidthBase): ... + +class String(RegexBase): + characters: _Any + case_flags: _Any + folded_characters: _Any + required: bool + def __init__(self, characters, case_flags=...) -> None: ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def max_width(self): ... + def get_required_string(self, reverse): ... + +class Literal(String): + def dump(self, indent, reverse) -> None: ... + +class StringSet(Branch): + info: _Any + name: _Any + case_flags: _Any + set_key: _Any + branches: _Any + def __init__(self, info, name, case_flags=...) -> None: ... + def dump(self, indent, reverse) -> None: ... + def __del__(self) -> None: ... + +class Source: + string: _Any + char_type: _Any + pos: int + ignore_space: bool + sep: _Any + def __init__(self, string): ... + def get(self, override_ignore: bool = ...): ... + def get_many(self, count: int = ...): ... + def get_while(self, test_set, include: bool = ...): ... + def skip_while(self, test_set, include: bool = ...) -> None: ... + def match(self, substring): ... + def expect(self, substring) -> None: ... + def at_end(self): ... + +class Info: + flags: _Any + global_flags: _Any + inline_locale: bool + kwargs: _Any + group_count: int + group_index: _Any + group_name: _Any + char_type: _Any + named_lists_used: _Any + open_groups: _Any + open_group_count: _Any + defined_groups: _Any + group_calls: _Any + private_groups: _Any + def __init__( + self, flags: int = ..., char_type: _Any | None = ..., kwargs=... + ) -> None: ... + def open_group(self, name: _Any | None = ...): ... + def close_group(self) -> None: ... + def is_open_group(self, name): ... + +class Scanner: + lexicon: _Any + scanner: _Any + def __init__(self, lexicon, flags: int = ...) -> None: ... + match: _Any + def scan(self, string): ... diff --git a/stubs/regex/regex.pyi b/stubs/regex/regex.pyi new file mode 100644 index 00000000..a6819c4f --- /dev/null +++ b/stubs/regex/regex.pyi @@ -0,0 +1,189 @@ +from typing import Any + +from regex._regex_core import VERSION0 + +def match( + pattern, + string, + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + partial: bool = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def fullmatch( + pattern, + string, + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + partial: bool = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def search( + pattern, + string, + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + partial: bool = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def sub( + pattern, + repl, + string, + count: int = ..., + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def subf( + pattern, + format, + string, + count: int = ..., + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def subn( + pattern, + repl, + string, + count: int = ..., + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def subfn( + pattern, + format, + string, + count: int = ..., + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def split( + pattern, + string, + maxsplit: int = ..., + flags: int = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def splititer( + pattern, + string, + maxsplit: int = ..., + flags: int = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def findall( + pattern, + string, + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + overlapped: bool = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def finditer( + pattern, + string, + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + overlapped: bool = ..., + partial: bool = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def compile( + pattern, flags: int = ..., ignore_unused: bool = ..., **kwargs +): ... +def purge() -> None: ... +def cache_all(value: bool = ...): ... +def template(pattern, flags: int = ...): ... +def escape(pattern, special_only: bool = ..., literal_spaces: bool = ...): ... + +DEFAULT_VERSION = VERSION0 +Pattern: Any +Match: Any +Regex = compile + +# Names in __all__ with no definition: +# A +# ASCII +# B +# BESTMATCH +# D +# DEBUG +# DOTALL +# E +# ENHANCEMATCH +# F +# FULLCASE +# I +# IGNORECASE +# L +# LOCALE +# M +# MULTILINE +# P +# POSIX +# R +# REVERSE +# S +# Scanner +# T +# TEMPLATE +# U +# UNICODE +# V0 +# V1 +# VERBOSE +# VERSION0 +# VERSION1 +# W +# WORD +# X +# __doc__ +# __version__ +# error diff --git a/stubs/tabview/__init__.pyi b/stubs/tabview/__init__.pyi new file mode 100644 index 00000000..db597540 --- /dev/null +++ b/stubs/tabview/__init__.pyi @@ -0,0 +1 @@ +from .tabview import view as view diff --git a/stubs/tabview/tabview.pyi b/stubs/tabview/tabview.pyi new file mode 100644 index 00000000..4685cd48 --- /dev/null +++ b/stubs/tabview/tabview.pyi @@ -0,0 +1,163 @@ +import io + +from typing import Any + +basestring = str +file = io.FileIO + +def KEY_CTRL(key): ... +def addstr(*args): ... +def insstr(*args): ... + +class ReloadException(Exception): + start_pos: Any + column_width_mode: Any + column_gap: Any + column_widths: Any + search_str: Any + def __init__( + self, start_pos, column_width, column_gap, column_widths, search_str + ) -> None: ... + +class QuitException(Exception): ... + +class Viewer: + scr: Any + data: Any + info: Any + header_offset_orig: int + header: Any + header_offset: Any + num_data_columns: Any + column_width_mode: Any + column_gap: Any + trunc_char: Any + num_columns: int + vis_columns: int + init_search: Any + modifier: Any + def __init__(self, *args, **kwargs) -> None: ... + def column_xw(self, x): ... + def quit(self) -> None: ... + def reload(self) -> None: ... + def consume_modifier(self, default: int = ...): ... + def down(self) -> None: ... + def up(self) -> None: ... + def left(self) -> None: ... + def right(self) -> None: ... + y: Any + win_y: Any + def page_down(self) -> None: ... + def page_up(self) -> None: ... + x: Any + win_x: Any + def page_right(self) -> None: ... + def page_left(self) -> None: ... + def mark(self) -> None: ... + def goto_mark(self) -> None: ... + def home(self) -> None: ... + def goto_y(self, y) -> None: ... + def goto_row(self) -> None: ... + def goto_x(self, x) -> None: ... + def goto_col(self) -> None: ... + def goto_yx(self, y, x) -> None: ... + def line_home(self) -> None: ... + def line_end(self) -> None: ... + def show_cell(self) -> None: ... + def show_info(self): ... + textpad: Any + search_str: Any + def search(self) -> None: ... + def search_results( + self, rev: bool = ..., look_in_cur: bool = ... + ) -> None: ... + def search_results_prev( + self, rev: bool = ..., look_in_cur: bool = ... + ) -> None: ... + def help(self) -> None: ... + def toggle_header(self) -> None: ... + def column_gap_down(self) -> None: ... + def column_gap_up(self) -> None: ... + column_width: Any + def column_width_all_down(self) -> None: ... + def column_width_all_up(self) -> None: ... + def column_width_down(self) -> None: ... + def column_width_up(self) -> None: ... + def sort_by_column_numeric(self): ... + def sort_by_column_numeric_reverse(self): ... + def sort_by_column(self) -> None: ... + def sort_by_column_reverse(self) -> None: ... + def sort_by_column_natural(self) -> None: ... + def sort_by_column_natural_reverse(self) -> None: ... + def sorted_nicely(self, ls, key, rev: bool = ...): ... + def float_string_key(self, value): ... + def toggle_column_width(self) -> None: ... + def set_current_column_width(self) -> None: ... + def yank_cell(self) -> None: ... + keys: Any + def define_keys(self) -> None: ... + def run(self) -> None: ... + def handle_keys(self) -> None: ... + def handle_modifier(self, mod) -> None: ... + def resize(self) -> None: ... + def num_columns_fwd(self, x): ... + def num_columns_rev(self, x): ... + def recalculate_layout(self) -> None: ... + def location_string(self, yp, xp): ... + def display(self) -> None: ... + def strpad(self, s, width): ... + def hdrstr(self, x, width): ... + def cellstr(self, y, x, width): ... + def skip_to_row_change(self) -> None: ... + def skip_to_row_change_reverse(self) -> None: ... + def skip_to_col_change(self) -> None: ... + def skip_to_col_change_reverse(self) -> None: ... + +class TextBox: + scr: Any + data: Any + title: Any + tdata: Any + hid_rows: int + def __init__(self, scr, data: str = ..., title: str = ...) -> None: ... + def __call__(self) -> None: ... + handlers: Any + def setup_handlers(self) -> None: ... + def run(self) -> None: ... + def handle_key(self, key) -> None: ... + def close(self) -> None: ... + def scroll_down(self) -> None: ... + def scroll_up(self) -> None: ... + def display(self) -> None: ... + +def csv_sniff(data, enc): ... +def fix_newlines(data): ... +def adjust_space_delim(data, enc): ... +def process_data( + data, + enc: Any | None = ..., + delim: Any | None = ..., + quoting: Any | None = ..., + quote_char=..., +): ... +def data_list_or_file(data): ... +def pad_data(d): ... +def readme(): ... +def detect_encoding(data: Any | None = ...): ... +def main(stdscr, *args, **kwargs) -> None: ... +def view( + data, + enc: Any | None = ..., + start_pos=..., + column_width: int = ..., + column_gap: int = ..., + trunc_char: str = ..., + column_widths: Any | None = ..., + search_str: Any | None = ..., + double_width: bool = ..., + delimiter: Any | None = ..., + quoting: Any | None = ..., + info: Any | None = ..., + quote_char=..., +): ... +def parse_path(path): ... diff --git a/stubs/termcolor/__init__.pyi b/stubs/termcolor/__init__.pyi new file mode 100644 index 00000000..9c937267 --- /dev/null +++ b/stubs/termcolor/__init__.pyi @@ -0,0 +1,22 @@ +from typing import Any + +__ALL__: Any +VERSION: Any +ATTRIBUTES: Any +HIGHLIGHTS: Any +COLORS: Any +RESET: str + +def colored( + text, + color: Any | None = ..., + on_color: Any | None = ..., + attrs: Any | None = ..., +): ... +def cprint( + text, + color: Any | None = ..., + on_color: Any | None = ..., + attrs: Any | None = ..., + **kwargs +) -> None: ... diff --git a/stubs/wilderness/__init__.pyi b/stubs/wilderness/__init__.pyi new file mode 100644 index 00000000..47e061a2 --- /dev/null +++ b/stubs/wilderness/__init__.pyi @@ -0,0 +1,168 @@ +import abc +import argparse + +from typing import Dict +from typing import List +from typing import Optional +from typing import TextIO + +class DocumentableMixin(metaclass=abc.ABCMeta): + def __init__( + self, + description: Optional[str] = None, + extra_sections: Optional[Dict[str, str]] = None, + options_prolog: Optional[str] = None, + options_epilog: Optional[str] = None, + ) -> None: ... + @property + def description(self) -> Optional[str]: ... + @property + def parser(self) -> argparse.ArgumentParser: ... + @parser.setter + def parser(self, parser: argparse.ArgumentParser): ... + @property + def args(self) -> argparse.Namespace: ... + @args.setter + def args(self, args: argparse.Namespace): ... + @property + def argument_help(self) -> Dict[str, Optional[str]]: ... + +class Application(DocumentableMixin): + def __init__( + self, + name: str, + version: str, + author: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + default_command: Optional[str] = None, + add_help: bool = True, + extra_sections: Optional[Dict[str, str]] = None, + prolog: Optional[str] = None, + epilog: Optional[str] = None, + options_prolog: Optional[str] = None, + options_epilog: Optional[str] = None, + add_commands_section: bool = False, + ) -> None: ... + @property + def name(self) -> str: ... + @property + def author(self) -> str: ... + @property + def version(self) -> str: ... + @property + def commands(self) -> List[Command]: ... + @property + def groups(self) -> List[Group]: ... + def add_argument(self, *args, **kwargs) -> argparse.Action: ... + def add(self, command: Command): ... + def add_group(self, title: str) -> Group: ... + def register(self): ... + def handle(self) -> int: ... + def run( + self, + args: Optional[List[str]] = None, + namespace: Optional[argparse.Namespace] = None, + exit_on_error: bool = True, + ) -> int: ... + def run_command(self, command: Command) -> int: ... + def get_command(self, command_name: str) -> Command: ... + def set_prolog(self, prolog: str) -> None: ... + def set_epilog(self, epilog: str) -> None: ... + def get_commands_text(self) -> str: ... + def create_manpage(self) -> ManPage: ... + def format_help(self) -> str: ... + def print_help(self, file: Optional[TextIO] = None) -> None: ... + +class Group: + def __init__( + self, title: Optional[str] = None, is_root: bool = False + ) -> None: ... + @property + def application(self) -> Optional[Application]: ... + @property + def title(self) -> Optional[str]: ... + @property + def commands(self) -> List[Command]: ... + @property + def is_root(self) -> bool: ... + def commands_as_actions(self) -> List[argparse.Action]: ... + def set_app(self, app: Application) -> None: ... + def add(self, command: Command) -> None: ... + def __len__(self) -> int: ... + +class Command(DocumentableMixin, metaclass=abc.ABCMeta): + def __init__( + self, + name: str, + title: Optional[str] = None, + description: Optional[str] = None, + add_help: bool = True, + extra_sections: Optional[Dict[str, str]] = None, + options_prolog: Optional[str] = None, + options_epilog: Optional[str] = None, + ) -> None: ... + @property + def application(self) -> Optional[Application]: ... + @property + def name(self) -> str: ... + @property + def title(self) -> Optional[str]: ... + def add_argument(self, *args, **kwargs) -> None: ... + def add_argument_group(self, *args, **kwargs) -> ArgumentGroup: ... + def add_mutually_exclusive_group( + self, *args, **kwargs + ) -> MutuallyExclusiveGroup: ... + def register(self) -> None: ... + @abc.abstractmethod + def handle(self) -> int: ... + def create_manpage(self) -> ManPage: ... + +class ManPage: + def __init__( + self, + application_name: str, + author: Optional[str] = "", + command_name: Optional[str] = None, + date: Optional[str] = None, + title: Optional[str] = None, + version: Optional[str] = "", + ) -> None: ... + @property + def name(self) -> str: ... + def metadata(self) -> List[str]: ... + def preamble(self) -> List[str]: ... + def header(self) -> str: ... + def section_name(self) -> str: ... + def add_section_synopsis(self, synopsis: str) -> None: ... + def add_section(self, label: str, text: str) -> None: ... + def groffify(self, text: str) -> str: ... + def groffify_line(self, line: str) -> str: ... + def export(self, output_dir: str) -> str: ... + +class ArgumentGroup: + def __init__(self, group: argparse._ArgumentGroup) -> None: ... + @property + def command(self) -> Optional[Command]: ... + @command.setter + def command(self, command: Command): ... + def add_argument(self, *args, **kwargs): ... + +class MutuallyExclusiveGroup: + def __init__(self, meg: argparse._MutuallyExclusiveGroup) -> None: ... + @property + def command(self) -> Optional[Command]: ... + @command.setter + def command(self, command: Command): ... + def add_argument(self, *args, **kwargs): ... + +class Tester: + def __init__(self, app: Application) -> None: ... + @property + def application(self) -> Application: ... + def clear(self) -> None: ... + def get_return_code(self) -> Optional[int]: ... + def get_stdout(self) -> Optional[str]: ... + def get_stderr(self) -> Optional[str]: ... + def test_command(self, cmd_name: str, args: List[str]) -> None: ... + def test_application(self, args: Optional[List[str]] = None) -> None: ... diff --git a/tests/test_integration/test_dialect_detection.py b/tests/test_integration/test_dialect_detection.py index 917cff47..7308ea73 100644 --- a/tests/test_integration/test_dialect_detection.py +++ b/tests/test_integration/test_dialect_detection.py @@ -47,7 +47,8 @@ def log_result(name, kind, verbose, partial): "success": (LOG_SUCCESS, LOG_SUCCESS_PARTIAL, "green"), "failure": (LOG_FAILED, LOG_FAILED_PARTIAL, "red"), } - outfull, outpartial, color = table.get(kind) + assert kind in table + outfull, outpartial, color = table[kind] fname = outpartial if partial else outfull with open(fname, "a") as fp: diff --git a/tests/test_unit/test_console.py b/tests/test_unit/test_console.py index 5ea1a65a..29595575 100644 --- a/tests/test_unit/test_console.py +++ b/tests/test_unit/test_console.py @@ -11,6 +11,9 @@ import tempfile import unittest +from typing import List +from typing import Union + from wilderness import Tester from clevercsv import __version__ @@ -20,7 +23,7 @@ class ConsoleTestCase(unittest.TestCase): - def _build_file(self, table, dialect, encoding=None, newline=None): + def _build_file(self, table, dialect, encoding=None, newline=None) -> str: tmpfd, tmpfname = tempfile.mkstemp( prefix="ccsv_", suffix=".csv", @@ -40,7 +43,10 @@ def _detect_test_wrap(self, table, dialect): tester.test_command("detect", [tmpfname]) try: - output = tester.get_stdout().strip() + stdout = tester.get_stdout() + self.assertIsNotNone(stdout) + assert stdout is not None + output = stdout.strip() self.assertEqual(exp, output) finally: os.unlink(tmpfname) @@ -79,7 +85,10 @@ def test_detect_opts_1(self): exp = "Detected: " + str(dialect) try: - output = tester.get_stdout().strip() + stdout = tester.get_stdout() + self.assertIsNotNone(stdout) + assert stdout is not None + output = stdout.strip() self.assertEqual(exp, output) finally: os.unlink(tmpfname) @@ -96,7 +105,10 @@ def test_detect_opts_2(self): exp = "Detected: " + str(dialect) try: - output = tester.get_stdout().strip() + stdout = tester.get_stdout() + self.assertIsNotNone(stdout) + assert stdout is not None + output = stdout.strip() self.assertEqual(exp, output) finally: os.unlink(tmpfname) @@ -115,7 +127,10 @@ def test_detect_opts_3(self): quotechar = escapechar =""" try: - output = tester.get_stdout().strip() + stdout = tester.get_stdout() + self.assertIsNotNone(stdout) + assert stdout is not None + output = stdout.strip() self.assertEqual(exp, output) finally: os.unlink(tmpfname) @@ -130,7 +145,10 @@ def test_detect_opts_4(self): tester.test_command("detect", ["--json", "--add-runtime", tmpfname]) try: - output = tester.get_stdout().strip() + stdout = tester.get_stdout() + self.assertIsNotNone(stdout) + assert stdout is not None + output = stdout.strip() data = json.loads(output) self.assertEqual(data["delimiter"], ";") self.assertEqual(data["quotechar"], "") @@ -442,7 +460,11 @@ def test_standardize_in_place_noop(self): os.unlink(tmpfname) def test_standardize_multi(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + table: List[List[Union[str, int]]] = [ + ["A", "B", "C"], + [1, 2, 3], + [4, 5, 6], + ] dialects = ["excel", "unix", "excel-tab"] tmpfnames = [self._build_file(table, D, newline="") for D in dialects] @@ -476,7 +498,11 @@ def test_standardize_multi(self): any(map(os.unlink, tmpoutnames)) def test_standardize_multi_errors(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + table: List[List[Union[str, int]]] = [ + ["A", "B", "C"], + [1, 2, 3], + [4, 5, 6], + ] dialects = ["excel", "unix", "excel-tab"] tmpfnames = [self._build_file(table, D, newline="") for D in dialects] @@ -507,7 +533,11 @@ def test_standardize_multi_errors(self): any(map(os.unlink, tmpoutnames)) def test_standardize_multi_encoding(self): - table = [["Å", "B", "C"], [1, 2, 3], [4, 5, 6]] + table: List[List[Union[str, int]]] = [ + ["Å", "B", "C"], + [1, 2, 3], + [4, 5, 6], + ] dialects = ["excel", "unix", "excel-tab"] encoding = "ISO-8859-1" tmpfnames = [ @@ -547,7 +577,11 @@ def test_standardize_multi_encoding(self): any(map(os.unlink, tmpoutnames)) def test_standardize_in_place_multi(self): - table = [["Å", "B", "C"], [1, 2, 3], [4, 5, 6]] + table: List[List[Union[str, int]]] = [ + ["Å", "B", "C"], + [1, 2, 3], + [4, 5, 6], + ] dialects = ["excel", "unix", "excel-tab"] encoding = "ISO-8859-1" tmpfnames = [ @@ -572,7 +606,11 @@ def test_standardize_in_place_multi(self): any(map(os.unlink, tmpfnames)) def test_standardize_in_place_multi_noop(self): - table = [["Å", "B", "C"], [1, 2, 3], [4, 5, 6]] + table: List[List[Union[str, int]]] = [ + ["Å", "B", "C"], + [1, 2, 3], + [4, 5, 6], + ] dialects = ["excel", "excel", "excel"] tmpfnames = [self._build_file(table, D, newline="") for D in dialects] diff --git a/tests/test_unit/test_detect_type.py b/tests/test_unit/test_detect_type.py index c4b8065b..83fcb42e 100644 --- a/tests/test_unit/test_detect_type.py +++ b/tests/test_unit/test_detect_type.py @@ -9,6 +9,8 @@ import unittest +from typing import List + from clevercsv.detect_type import TypeDetector from clevercsv.detect_type import type_score from clevercsv.dialect import SimpleDialect @@ -21,7 +23,7 @@ def setUp(self): # NUMBERS def test_number(self): - yes_number = [ + yes_number: List[str] = [ "1", "2", "34", @@ -87,7 +89,7 @@ def test_number(self): for num in yes_number: with self.subTest(num=num): self.assertTrue(self.td.is_number(num)) - no_number = [ + no_number: List[str] = [ "0000.213654", "123.465.798", "0.5e0.5", @@ -111,7 +113,7 @@ def test_number(self): # DATES def test_date(self): - yes_date = [ + yes_date: List[str] = [ "031219", "03122019", "03-12-19", @@ -162,7 +164,7 @@ def test_date(self): for date in yes_date: with self.subTest(date=date): self.assertTrue(self.td.is_date(date)) - no_date = [ + no_date: List[str] = [ "2018|01|02", "30/07-88", "12.01-99", @@ -177,11 +179,14 @@ def test_date(self): # DATETIME def test_datetime(self): - yes_dt = ["2019-01-12T04:01:23Z", "2021-09-26T12:13:31+01:00"] + yes_dt: List[str] = [ + "2019-01-12T04:01:23Z", + "2021-09-26T12:13:31+01:00", + ] for dt in yes_dt: with self.subTest(dt=dt): self.assertTrue(self.td.is_datetime(dt)) - no_date = [] + no_date: List[str] = [] for date in no_date: with self.subTest(date=date): self.assertFalse(self.td.is_datetime(dt)) @@ -190,7 +195,7 @@ def test_datetime(self): def test_url(self): # Some cases copied from https://mathiasbynens.be/demo/url-regex - yes_url = [ + yes_url: List[str] = [ "Cocoal.icio.us", "Websquash.com", "bbc.co.uk", @@ -262,7 +267,7 @@ def test_url(self): with self.subTest(url=url): self.assertTrue(self.td.is_url(url)) - no_url = [ + no_url: List[str] = [ "//", "///", "///a", @@ -305,7 +310,7 @@ def test_unicode_alphanum(self): # These tests are by no means inclusive and ought to be extended in the # future. - yes_alphanum = ["this is a cell", "1231 pounds"] + yes_alphanum: List[str] = ["this is a cell", "1231 pounds"] for unicode_alphanum in yes_alphanum: with self.subTest(unicode_alphanum=unicode_alphanum): self.assertTrue(self.td.is_unicode_alphanum(unicode_alphanum)) @@ -315,7 +320,7 @@ def test_unicode_alphanum(self): ) ) - no_alphanum = ["https://www.gertjan.dev"] + no_alphanum: List[str] = ["https://www.gertjan.dev"] for unicode_alpanum in no_alphanum: with self.subTest(unicode_alpanum=unicode_alpanum): self.assertFalse(self.td.is_unicode_alphanum(unicode_alpanum)) @@ -325,7 +330,7 @@ def test_unicode_alphanum(self): ) ) - only_quoted = ["this string, with a comma"] + only_quoted: List[str] = ["this string, with a comma"] for unicode_alpanum in only_quoted: with self.subTest(unicode_alpanum=unicode_alpanum): self.assertFalse( @@ -340,12 +345,12 @@ def test_unicode_alphanum(self): ) def test_bytearray(self): - yes_bytearray = [ + yes_bytearray: List[str] = [ "bytearray(b'')", "bytearray(b'abc,*&@\"')", "bytearray(b'bytearray(b'')')", ] - no_bytearray = [ + no_bytearray: List[str] = [ "bytearray(b'abc", "bytearray(b'abc'", "bytearray('abc')", @@ -363,7 +368,7 @@ def test_bytearray(self): # Unix path def test_unix_path(self): - yes_path = [ + yes_path: List[str] = [ "/Users/person/abc/def-ghi/blabla.csv.test", "/home/username/share/a/_b/c_d/e.py", "/home/username/share", diff --git a/tests/test_unit/test_dict.py b/tests/test_unit/test_dict.py index 8a941da8..43d45c5a 100644 --- a/tests/test_unit/test_dict.py +++ b/tests/test_unit/test_dict.py @@ -12,8 +12,13 @@ import tempfile import unittest +from typing import Any +from typing import Dict + import clevercsv +from clevercsv.dict_read_write import DictReader + class DictTestCase(unittest.TestCase): ############################ @@ -57,8 +62,9 @@ def test_write_fields_not_in_fieldnames(self): with tempfile.TemporaryFile("w+", newline="") as fp: writer = clevercsv.DictWriter(fp, fieldnames=["f1", "f2", "f3"]) # Of special note is the non-string key (CPython issue 19449) + content: Dict[Any, Any] = {"f4": 10, "f2": "spam", 1: "abc"} with self.assertRaises(ValueError) as cx: - writer.writerow({"f4": 10, "f2": "spam", 1: "abc"}) + writer.writerow(content) exception = str(cx.exception) self.assertIn("fieldnames", exception) self.assertIn("'f4'", exception) @@ -101,7 +107,7 @@ def test_read_dict_no_fieldnames(self): with tempfile.TemporaryFile("w+") as fp: fp.write("f1,f2,f3\r\n1,2,abc\r\n") fp.seek(0) - reader = clevercsv.DictReader(fp) + reader: DictReader = clevercsv.DictReader(fp) self.assertEqual(next(reader), {"f1": "1", "f2": "2", "f3": "abc"}) self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) @@ -123,7 +129,7 @@ def test_read_dict_fieldnames_chain(self): with tempfile.TemporaryFile("w+") as fp: fp.write("f1,f2,f3\r\n1,2,abc\r\n") fp.seek(0) - reader = clevercsv.DictReader(fp) + reader: DictReader = clevercsv.DictReader(fp) first = next(reader) for row in itertools.chain([first], reader): self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) @@ -155,7 +161,7 @@ def test_read_long_with_rest_no_fieldnames(self): with tempfile.TemporaryFile("w+") as fp: fp.write("f1,f2\r\n1,2,abc,4,5,6\r\n") fp.seek(0) - reader = clevercsv.DictReader(fp, restkey="_rest") + reader: DictReader = clevercsv.DictReader(fp, restkey="_rest") self.assertEqual(reader.fieldnames, ["f1", "f2"]) self.assertEqual( next(reader), @@ -238,7 +244,9 @@ def test_read_semi_sep(self): # Start tests added for CleverCSV # def test_read_duplicate_fieldnames(self): - reader = clevercsv.DictReader(["f1,f2,f1\r\n", "a", "b", "c"]) + reader: DictReader = clevercsv.DictReader( + ["f1,f2,f1\r\n", "a", "b", "c"] + ) with self.assertWarns(UserWarning): reader.fieldnames diff --git a/tests/test_unit/test_wrappers.py b/tests/test_unit/test_wrappers.py index a2becc6a..e054f002 100644 --- a/tests/test_unit/test_wrappers.py +++ b/tests/test_unit/test_wrappers.py @@ -12,6 +12,9 @@ import types import unittest +from typing import List +from typing import Union + import pandas as pd from clevercsv import wrappers @@ -204,7 +207,11 @@ def _write_test_table(self, table, expected, **kwargs): os.unlink(tmpfname) def test_write_table(self): - table = [["A", "B,C", "D"], [1, 2, 3], [4, 5, 6]] + table: List[List[Union[str, int]]] = [ + ["A", "B,C", "D"], + [1, 2, 3], + [4, 5, 6], + ] exp = 'A,"B,C",D\r\n1,2,3\r\n4,5,6\r\n' with self.subTest(name="default"): self._write_test_table(table, exp) diff --git a/tests/test_unit/test_write.py b/tests/test_unit/test_write.py index 01f495e4..2c8367c9 100644 --- a/tests/test_unit/test_write.py +++ b/tests/test_unit/test_write.py @@ -18,13 +18,6 @@ class WriterTestCase(unittest.TestCase): - def writerAssertEqual(self, input, expected_result): - with tempfile.TemporaryFile("w+", newline="", prefix="ccsv_") as fp: - writer = clevercsv.writer(fp, dialect=self.dialect) - writer.writerows(input) - fp.seek(0) - self.assertEqual(fp.read(), expected_result) - def _write_test(self, fields, expect, **kwargs): with tempfile.TemporaryFile("w+", newline="", prefix="ccsv_") as fp: writer = clevercsv.writer(fp, **kwargs) From 94609460689a3f98481ffdaf898d9a57f34545d2 Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Tue, 5 Sep 2023 22:26:35 +0100 Subject: [PATCH 2/5] Bump minimal Python version to 3.8 --- .github/workflows/build.yml | 2 +- .github/workflows/deploy.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 04088745..312ba269 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -42,7 +42,7 @@ jobs: strategy: matrix: os: [ 'ubuntu-latest', 'macos-latest', 'windows-latest' ] - py: [ '3.7', '3.11' ] # minimal and latest + py: [ '3.8', '3.11' ] # minimal and latest steps: - name: Install Python ${{ matrix.py }} uses: actions/setup-python@v2 diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 2e85a055..ff85a6c2 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -31,7 +31,7 @@ jobs: env: CIBW_TEST_COMMAND: "python -VV && python -m unittest discover -f -s {project}/tests/test_unit/" CIBW_TEST_EXTRAS: "full" - CIBW_SKIP: "pp* cp27-* cp33-* cp34-* cp35-* cp36-* cp310-win32 *-musllinux_* *-manylinux_i686" + CIBW_SKIP: "pp* cp27-* cp33-* cp34-* cp35-* cp36-* cp37-* cp310-win32 *-musllinux_* *-manylinux_i686" CIBW_ARCHS_MACOS: x86_64 arm64 universal2 CIBW_ARCHS_LINUX: auto aarch64 From 2cf87e0e3abf5e465405e400d2ffc3b61c75dcab Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Tue, 5 Sep 2023 22:34:20 +0100 Subject: [PATCH 3/5] Make type hint work at runtime --- clevercsv/_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clevercsv/_types.py b/clevercsv/_types.py index e14cc006..cc684bfd 100644 --- a/clevercsv/_types.py +++ b/clevercsv/_types.py @@ -14,7 +14,7 @@ from clevercsv.dialect import SimpleDialect -AnyPath = Union[str, bytes, os.PathLike[str], os.PathLike[bytes]] +AnyPath = Union[str, bytes, "os.PathLike[str]", "os.PathLike[bytes]"] _OpenFile = Union[AnyPath, int] _DictRow = Mapping[str, Any] _DialectLike = Union[str, csv.Dialect, Type[csv.Dialect], SimpleDialect] From dff58004aed80bbed7a62dc4f1e26f9d0944fa44 Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Tue, 5 Sep 2023 22:48:37 +0100 Subject: [PATCH 4/5] Minor fixes for type annotations --- clevercsv/dict_read_write.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/clevercsv/dict_read_write.py b/clevercsv/dict_read_write.py index f9cf8adb..521d10a9 100644 --- a/clevercsv/dict_read_write.py +++ b/clevercsv/dict_read_write.py @@ -9,6 +9,9 @@ Author: Gertjan van den Burg """ + +from __future__ import annotations + import warnings from collections import OrderedDict @@ -113,7 +116,7 @@ def __next__(self) -> "_DictReadMapping[Union[_T, Any], Union[str, Any]]": class DictWriter(Generic[_T]): def __init__( self, - f: "SupportsWrite[str]", + f: SupportsWrite[str], fieldnames: Collection[_T], restval: Optional[Any] = "", extrasaction: Literal["raise", "ignore"] = "raise", From aba1362eeaa263ecb62636ef18adb4907ca4951a Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Sun, 24 Sep 2023 15:01:45 +0100 Subject: [PATCH 5/5] More type hints --- clevercsv/__main__.py | 2 +- clevercsv/_regexes.py | 38 +++--- clevercsv/_types.py | 3 + clevercsv/consistency.py | 8 +- clevercsv/console/__init__.py | 2 +- clevercsv/console/application.py | 4 +- clevercsv/console/commands/_utils.py | 21 +++- clevercsv/console/commands/code.py | 4 +- clevercsv/console/commands/detect.py | 6 +- clevercsv/console/commands/explore.py | 4 +- clevercsv/console/commands/standardize.py | 61 ++++++++-- clevercsv/console/commands/view.py | 6 +- clevercsv/detect.py | 67 +++++++---- clevercsv/detect_pattern.py | 2 +- clevercsv/detect_type.py | 22 ++-- clevercsv/dialect.py | 9 +- clevercsv/dict_read_write.py | 2 +- clevercsv/normal_form.py | 78 ++++++++----- clevercsv/potential_dialects.py | 33 ++++-- clevercsv/utils.py | 13 ++- clevercsv/wrappers.py | 12 +- clevercsv/write.py | 8 +- pyproject.toml | 16 ++- stubs/pythonfuzz/main.pyi | 2 +- stubs/wilderness/__init__.pyi | 12 +- .../test_dialect_detection.py | 39 +++++-- tests/test_unit/test_abstraction.py | 4 +- tests/test_unit/test_consistency.py | 2 +- tests/test_unit/test_console.py | 102 ++++++++-------- tests/test_unit/test_cparser.py | 109 ++++++++++-------- tests/test_unit/test_detect.py | 21 +++- tests/test_unit/test_detect_pattern.py | 40 +++---- tests/test_unit/test_detect_type.py | 22 ++-- tests/test_unit/test_dict.py | 56 ++++----- tests/test_unit/test_encoding.py | 66 ++++++----- tests/test_unit/test_fuzzing.py | 2 +- tests/test_unit/test_normal_forms.py | 10 +- tests/test_unit/test_potential_dialects.py | 8 +- tests/test_unit/test_reader.py | 29 +++-- tests/test_unit/test_wrappers.py | 51 +++++--- tests/test_unit/test_write.py | 28 +++-- 41 files changed, 647 insertions(+), 377 deletions(-) diff --git a/clevercsv/__main__.py b/clevercsv/__main__.py index 4fa62598..9aafa17c 100644 --- a/clevercsv/__main__.py +++ b/clevercsv/__main__.py @@ -10,7 +10,7 @@ from ._optional import import_optional_dependency -def main(): +def main() -> None: # Check that necessary dependencies are available import_optional_dependency("wilderness") diff --git a/clevercsv/_regexes.py b/clevercsv/_regexes.py index 6e5a537e..b0d84c2f 100644 --- a/clevercsv/_regexes.py +++ b/clevercsv/_regexes.py @@ -10,7 +10,7 @@ # Regular expressions for number formats # ########################################## -PATTERN_NUMBER_1: Pattern = regex.compile( +PATTERN_NUMBER_1: Pattern[str] = regex.compile( r"^(?=[+-\.\d])" r"[+-]?" r"(?:0|[1-9]\d*)?" @@ -35,11 +35,11 @@ r"$" ) -PATTERN_NUMBER_2: Pattern = regex.compile( +PATTERN_NUMBER_2: Pattern[str] = regex.compile( r"[+-]?(?:[1-9]|[1-9]\d{0,2})(?:\,\d{3})+\.\d*" ) -PATTERN_NUMBER_3: Pattern = regex.compile( +PATTERN_NUMBER_3: Pattern[str] = regex.compile( r"[+-]?(?:[1-9]|[1-9]\d{0,2})(?:\.\d{3})+\,\d*" ) @@ -47,7 +47,7 @@ # Regular expressions for url, email, and ip # ############################################## -PATTERN_URL: Pattern = regex.compile( +PATTERN_URL: Pattern[str] = regex.compile( r"(" r"(https?|ftp):\/\/(?!\-)" r")?" @@ -62,17 +62,17 @@ r"(\.[a-z]+)?" ) -PATTERN_EMAIL: Pattern = regex.compile( +PATTERN_EMAIL: Pattern[str] = regex.compile( r"(^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$)" ) -PATTERN_IPV4: Pattern = regex.compile(r"(?:\d{1,3}\.){3}\d{1,3}") +PATTERN_IPV4: Pattern[str] = regex.compile(r"(?:\d{1,3}\.){3}\d{1,3}") ################################################# # Regular expressions related to time notations # ################################################# -PATTERN_TIME_HHMMSSZZ: Pattern = regex.compile( +PATTERN_TIME_HHMMSSZZ: Pattern[str] = regex.compile( r"(0[0-9]|1[0-9]|2[0-3])" r":" r"([0-5][0-9])" @@ -84,21 +84,23 @@ r"([0-5][0-9])" ) -PATTERN_TIME_HHMMSS: Pattern = regex.compile( +PATTERN_TIME_HHMMSS: Pattern[str] = regex.compile( r"(0[0-9]|1[0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])" ) -PATTERN_TIME_HHMM_1: Pattern = regex.compile( +PATTERN_TIME_HHMM_1: Pattern[str] = regex.compile( r"(0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])" ) -PATTERN_TIME_HHMM_2: Pattern = regex.compile( +PATTERN_TIME_HHMM_2: Pattern[str] = regex.compile( r"(0[0-9]|1[0-9]|2[0-3])([0-5][0-9])" ) -PATTERN_TIME_HH: Pattern = regex.compile(r"(0[0-9]|1[0-9]|2[0-3])([0-5][0-9])") +PATTERN_TIME_HH: Pattern[str] = regex.compile( + r"(0[0-9]|1[0-9]|2[0-3])([0-5][0-9])" +) -PATTERN_TIME_HMM: Pattern = regex.compile( +PATTERN_TIME_HMM: Pattern[str] = regex.compile( r"([0-9]|1[0-9]|2[0-3]):([0-5][0-9])" ) @@ -109,7 +111,7 @@ # Regex for various date formats. See # https://github.com/alan-turing-institute/CleverCSV/blob/master/notes/date_regex/dateregex_annotated.txt # for an explanation. -PATTERN_DATE: Pattern = regex.compile( +PATTERN_DATE: Pattern[str] = regex.compile( r"(" r"(0[1-9]|1[0-2])" r"(" @@ -238,7 +240,7 @@ ALPHANUM_SPECIALS: str = regex.escape(r"".join(SPECIALS_ALLOWED)) # Regex for alphanumeric text -PATTERN_ALPHANUM: Pattern = regex.compile( +PATTERN_ALPHANUM: Pattern[str] = regex.compile( r"(" r"\p{N}?\p{L}+" r"[" @@ -254,7 +256,7 @@ r"".join(SPECIALS_ALLOWED) + r"".join(QUOTED_SPECIALS_ALLOWED) ) # Regex for alphanumeric text in quoted strings -PATTERN_ALPHANUM_QUOTED: Pattern = regex.compile( +PATTERN_ALPHANUM_QUOTED: Pattern[str] = regex.compile( r"(" r"\p{N}?\p{L}+" r"[" @@ -270,13 +272,13 @@ # Regular expression for currency # ################################### -PATTERN_CURRENCY: Pattern = regex.compile(r"\p{Sc}\s?(.*)") +PATTERN_CURRENCY: Pattern[str] = regex.compile(r"\p{Sc}\s?(.*)") ##################################### # Regular expression for unix paths # ##################################### -PATTERN_UNIX_PATH: Pattern = regex.compile( +PATTERN_UNIX_PATH: Pattern[str] = regex.compile( r"[~.]?(?:\/[a-zA-Z0-9\.\-\_]+)+\/?" ) @@ -284,7 +286,7 @@ # Map of regular expresions for type detection # ################################################ -DEFAULT_TYPE_REGEXES: Dict[str, Pattern] = { +DEFAULT_TYPE_REGEXES: Dict[str, Pattern[str]] = { "number_1": PATTERN_NUMBER_1, "number_2": PATTERN_NUMBER_2, "number_3": PATTERN_NUMBER_3, diff --git a/clevercsv/_types.py b/clevercsv/_types.py index cc684bfd..48398374 100644 --- a/clevercsv/_types.py +++ b/clevercsv/_types.py @@ -10,14 +10,17 @@ from typing import Any from typing import Mapping from typing import Type +from typing import TypeVar from typing import Union from clevercsv.dialect import SimpleDialect AnyPath = Union[str, bytes, "os.PathLike[str]", "os.PathLike[bytes]"] +StrPath = Union[str, "os.PathLike[str]"] _OpenFile = Union[AnyPath, int] _DictRow = Mapping[str, Any] _DialectLike = Union[str, csv.Dialect, Type[csv.Dialect], SimpleDialect] +_T = TypeVar("_T") if sys.version_info >= (3, 8): from typing import Dict as _DictReadMapping diff --git a/clevercsv/consistency.py b/clevercsv/consistency.py index 71d2afdb..a7454f30 100644 --- a/clevercsv/consistency.py +++ b/clevercsv/consistency.py @@ -88,7 +88,7 @@ def cached_is_known_type(cell: str, is_quoted: bool) -> bool: self._cached_is_known_type = cached_is_known_type def detect( - self, data: str, delimiters: Optional[Iterable[str]] = None + self, data: str, delimiters: Optional[List[str]] = None ) -> Optional[SimpleDialect]: """Detect the dialect using the consistency measure @@ -192,7 +192,7 @@ def get_best_dialects( return [d for d, score in scores.items() if score.Q == Qmax] def compute_type_score( - self, data: str, dialect: SimpleDialect, eps=DEFAULT_EPS_TYPE + self, data: str, dialect: SimpleDialect, eps: float = DEFAULT_EPS_TYPE ) -> float: """Compute the type score""" total = known = 0 @@ -211,8 +211,10 @@ def detect_dialect_consistency( delimiters: Optional[Iterable[str]] = None, skip: bool = True, verbose: bool = False, -): +) -> Optional[SimpleDialect]: """Helper function that wraps ConsistencyDetector""" # Mostly kept for backwards compatibility consistency_detector = ConsistencyDetector(skip=skip, verbose=verbose) + if delimiters is not None: + delimiters = list(delimiters) return consistency_detector.detect(data, delimiters=delimiters) diff --git a/clevercsv/console/__init__.py b/clevercsv/console/__init__.py index 4fa1df8a..95554677 100644 --- a/clevercsv/console/__init__.py +++ b/clevercsv/console/__init__.py @@ -3,6 +3,6 @@ from .application import build_application -def main(): +def main() -> int: app = build_application() return app.run() diff --git a/clevercsv/console/application.py b/clevercsv/console/application.py index 71643fa8..9464783c 100644 --- a/clevercsv/console/application.py +++ b/clevercsv/console/application.py @@ -64,7 +64,7 @@ class CleverCSVApplication(Application): ), } - def __init__(self): + def __init__(self) -> None: super().__init__( "clevercsv", version=__version__, @@ -74,7 +74,7 @@ def __init__(self): extra_sections=self._extra, ) - def register(self): + def register(self) -> None: self.add_argument( "-V", "--version", diff --git a/clevercsv/console/commands/_utils.py b/clevercsv/console/commands/_utils.py index 36ada229..20c1861a 100644 --- a/clevercsv/console/commands/_utils.py +++ b/clevercsv/console/commands/_utils.py @@ -1,9 +1,14 @@ # -*- coding: utf-8 -*- +from typing import Any +from typing import List +from typing import Optional + from clevercsv import __version__ +from clevercsv.dialect import SimpleDialect -def parse_int(val, name): +def parse_int(val: Any, name: str) -> Optional[int]: """Parse a number to an integer if possible""" if val is None: return val @@ -15,7 +20,13 @@ def parse_int(val, name): ) -def generate_code(filename, dialect, encoding, use_pandas=False): +def generate_code( + filename: str, + dialect: SimpleDialect, + encoding: Optional[str], + use_pandas: bool = False, +) -> List[str]: + assert dialect.quotechar is not None d = '"\\t"' if dialect.delimiter == "\t" else f'"{dialect.delimiter}"' q = '"%s"' % (dialect.quotechar.replace('"', '\\"')) e = repr(f"{dialect.escapechar}").replace("'", '"') @@ -26,7 +37,8 @@ def generate_code(filename, dialect, encoding, use_pandas=False): "import clevercsv", ] if use_pandas: - return base + [ + return [ + *base, "", f'df = clevercsv.read_dataframe("{filename}", delimiter={d}, ' f"quotechar={q}, escapechar={e})", @@ -34,7 +46,8 @@ def generate_code(filename, dialect, encoding, use_pandas=False): ] enc = "None" if encoding is None else f'"{encoding}"' - lines = base + [ + lines = [ + *base, "", f'with open("{filename}", "r", newline="", encoding={enc}) as fp:', " reader = clevercsv.reader(fp, " diff --git a/clevercsv/console/commands/code.py b/clevercsv/console/commands/code.py index edc5dddd..de88d127 100644 --- a/clevercsv/console/commands/code.py +++ b/clevercsv/console/commands/code.py @@ -21,7 +21,7 @@ class CodeCommand(Command): "and copy the generated code to a Python script." ) - def __init__(self): + def __init__(self) -> None: super().__init__( name="code", title="Generate Python code to import a CSV file", @@ -29,7 +29,7 @@ def __init__(self): extra_sections={"CleverCSV": "Part of the CleverCSV suite"}, ) - def register(self): + def register(self) -> None: self.add_argument("path", help="Path to the CSV file") self.add_argument( "-e", diff --git a/clevercsv/console/commands/detect.py b/clevercsv/console/commands/detect.py index 74775d90..58ccefb1 100644 --- a/clevercsv/console/commands/detect.py +++ b/clevercsv/console/commands/detect.py @@ -18,7 +18,7 @@ class DetectCommand(Command): _description = "Detect the dialect of a CSV file." - def __init__(self): + def __init__(self) -> None: super().__init__( name="detect", title="Detect the dialect of a CSV file", @@ -26,7 +26,7 @@ def __init__(self): extra_sections={"CleverCSV": "Part of the CleverCSV suite"}, ) - def register(self): + def register(self) -> None: self.add_argument("path", help="Path to the CSV file") self.add_argument( "-c", @@ -100,7 +100,7 @@ def register(self): help="Add the runtime of the detection to the detection output.", ) - def handle(self): + def handle(self) -> int: verbose = self.args.verbose num_chars = parse_int(self.args.num_chars, "num-chars") method = "consistency" if self.args.consistency else "auto" diff --git a/clevercsv/console/commands/explore.py b/clevercsv/console/commands/explore.py index 6dc7540c..3634f021 100644 --- a/clevercsv/console/commands/explore.py +++ b/clevercsv/console/commands/explore.py @@ -26,7 +26,7 @@ class ExploreCommand(Command): "to read the file as a Pandas dataframe." ) - def __init__(self): + def __init__(self) -> None: super().__init__( name="explore", title="Explore the CSV file in an interactive Python shell", @@ -34,7 +34,7 @@ def __init__(self): extra_sections={"CleverCSV": "Part of the CleverCSV suite"}, ) - def register(self): + def register(self) -> None: self.add_argument("path", help="Path to the CSV file") self.add_argument( "-e", diff --git a/clevercsv/console/commands/standardize.py b/clevercsv/console/commands/standardize.py index 9b120989..002eb7bf 100644 --- a/clevercsv/console/commands/standardize.py +++ b/clevercsv/console/commands/standardize.py @@ -1,19 +1,29 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import io import os import shutil import sys import tempfile +from typing import TYPE_CHECKING +from typing import Optional + from wilderness import Command +from clevercsv._types import StrPath +from clevercsv.dialect import SimpleDialect from clevercsv.encoding import get_encoding from clevercsv.read import reader from clevercsv.utils import sha1sum from clevercsv.wrappers import detect_dialect from clevercsv.write import writer +if TYPE_CHECKING: + from clevercsv._types import SupportsWrite + from ._docs import FLAG_DESCRIPTIONS from ._utils import parse_int @@ -28,7 +38,7 @@ class StandardizeCommand(Command): "[1]: https://tools.ietf.org/html/rfc4180" ) - def __init__(self): + def __init__(self) -> None: super().__init__( name="standardize", title="Convert a CSV file to one that conforms to RFC-4180", @@ -36,7 +46,7 @@ def __init__(self): extra_sections={"CleverCSV": "Part of the CleverCSV suite"}, ) - def register(self): + def register(self) -> None: self.add_argument( "path", help="Path to one or more CSV file(s)", nargs="+" ) @@ -152,7 +162,12 @@ def handle(self) -> int: return global_retval def handle_path( - self, path, output, encoding=None, num_chars=None, verbose=False + self, + path: StrPath, + output: Optional[StrPath], + encoding: Optional[str] = None, + num_chars: Optional[int] = None, + verbose: bool = False, ) -> int: encoding = encoding or get_encoding(path) dialect = detect_dialect( @@ -168,7 +183,13 @@ def handle_path( return self._to_stdout(path, dialect, encoding) return self._to_file(path, output, dialect, encoding) - def _write_transposed(self, path, stream, dialect, encoding): + def _write_transposed( + self, + path: StrPath, + stream: SupportsWrite[str], + dialect: SimpleDialect, + encoding: Optional[str], + ) -> None: with open(path, "r", newline="", encoding=encoding) as fp: read = reader(fp, dialect=dialect) rows = list(read) @@ -177,20 +198,34 @@ def _write_transposed(self, path, stream, dialect, encoding): for row in rows: write.writerow(row) - def _write_direct(self, path, stream, dialect, encoding): + def _write_direct( + self, + path: StrPath, + stream: SupportsWrite[str], + dialect: SimpleDialect, + encoding: Optional[str], + ) -> None: with open(path, "r", newline="", encoding=encoding) as fp: read = reader(fp, dialect=dialect) write = writer(stream, dialect="excel") for row in read: write.writerow(row) - def _write_to_stream(self, path, stream, dialect, encoding): + def _write_to_stream( + self, + path: StrPath, + stream: SupportsWrite[str], + dialect: SimpleDialect, + encoding: Optional[str], + ) -> None: if self.args.transpose: self._write_transposed(path, stream, dialect, encoding) else: self._write_direct(path, stream, dialect, encoding) - def _in_place(self, path, dialect, encoding): + def _in_place( + self, path: StrPath, dialect: SimpleDialect, encoding: Optional[str] + ) -> int: """In-place mode overwrites the input file, if necessary The return value of this method is to be used as the status code of @@ -213,14 +248,22 @@ def _in_place(self, path, dialect, encoding): shutil.move(tmpfname, path) return 2 - def _to_stdout(self, path, dialect, encoding): + def _to_stdout( + self, path: StrPath, dialect: SimpleDialect, encoding: Optional[str] + ) -> int: stream = io.StringIO(newline="") self._write_to_stream(path, stream, dialect, encoding) print(stream.getvalue(), end="") stream.close() return 0 - def _to_file(self, path, output, dialect, encoding): + def _to_file( + self, + path: StrPath, + output: StrPath, + dialect: SimpleDialect, + encoding: Optional[str], + ) -> int: with open(output, "w", newline="", encoding=encoding) as fp: self._write_to_stream(path, fp, dialect, encoding) return 0 diff --git a/clevercsv/console/commands/view.py b/clevercsv/console/commands/view.py index afaaa44d..60ae57c5 100644 --- a/clevercsv/console/commands/view.py +++ b/clevercsv/console/commands/view.py @@ -22,7 +22,7 @@ class ViewCommand(Command): "the command line." ) - def __init__(self): + def __init__(self) -> None: super().__init__( name="view", title="View the CSV file on the command line using TabView", @@ -30,7 +30,7 @@ def __init__(self): extra_sections={"CleverCSV": "Part of the CleverCSV suite"}, ) - def register(self): + def register(self) -> None: self.add_argument("path", help="Path to the CSV file") self.add_argument( "-e", @@ -52,7 +52,7 @@ def register(self): help="Transpose the columns of the input file before viewing", ) - def _tabview(self, rows) -> None: + def _tabview(self, rows: List[List[str]]) -> None: if sys.platform == "win32": print( "Error: unfortunately Tabview is not available on Windows, so " diff --git a/clevercsv/detect.py b/clevercsv/detect.py index 57c1e619..d6b8742e 100644 --- a/clevercsv/detect.py +++ b/clevercsv/detect.py @@ -7,17 +7,37 @@ """ +from enum import Enum from io import StringIO from typing import Dict +from typing import Iterable from typing import Optional from typing import Union from .consistency import ConsistencyDetector +from .dialect import SimpleDialect +from .exceptions import NoDetectionResult from .normal_form import detect_dialect_normal from .read import reader +class DetectionMethod(str, Enum): + """Possible detection methods + + Valid options are `"auto"` (the default for :class:`Detector.detect`), + `"normal"`, or `"consistency"`. The `"auto"` option first attempts to + detect the dialect using normal-form detection, and uses the consistency + measure if normal-form detection is inconclusive. The `"normal"` method + uses normal-form detection excllusively, and the `"consistency"` method + uses the consistency measure exclusively. + """ + + AUTO = "auto" + NORMAL = "normal" + CONSISTENCY = "consistency" + + class Detector: """ Detect the Dialect of CSV files with normal forms or the data consistency @@ -32,18 +52,23 @@ class Detector: """ - def sniff(self, sample, delimiters=None, verbose=False): + def sniff( + self, + sample: str, + delimiters: Optional[Iterable[str]] = None, + verbose: bool = False, + ) -> Optional[SimpleDialect]: # Compatibility method for Python return self.detect(sample, delimiters=delimiters, verbose=verbose) def detect( self, - sample, - delimiters=None, - verbose=False, - method="auto", - skip=True, - ): + sample: str, + delimiters: Optional[Iterable[str]] = None, + verbose: bool = False, + method: Union[DetectionMethod, str] = DetectionMethod.AUTO, + skip: bool = True, + ) -> Optional[SimpleDialect]: """Detect the dialect of a CSV file This method detects the dialect of the CSV file using the specified @@ -64,14 +89,10 @@ def detect( verbose : bool Enable verbose mode. - method : str - The method to use for dialect detection. Valid options are `"auto"` - (the default), `"normal"`, or `"consistency"`. The `"auto"` option - first attempts to detect the dialect using normal-form detection, - and uses the consistency measure if normal-form detection is - inconclusive. The `"normal"` method uses normal-form detection - excllusively, and the `"consistency"` method uses the consistency - measure exclusively. + method : Union[DetectionMethod, str] + The method to use for dialect detection. Possible values are + :class:`DetectionMethod` instances or strings that can be cast to + as such an enum. skip : bool Whether to skip potential dialects that have too low a pattern @@ -86,10 +107,10 @@ def detect( inconclusive. """ - if method not in ("auto", "normal", "consistency"): - raise ValueError(f"Unknown detection method: {method}") - - if method == "normal" or method == "auto": + method = DetectionMethod(method) if isinstance(method, str) else method + if delimiters is not None: + delimiters = list(delimiters) + if method == DetectionMethod.NORMAL or method == DetectionMethod.AUTO: if verbose: print("Running normal form detection ...", flush=True) dialect = detect_dialect_normal( @@ -99,7 +120,7 @@ def detect( self.method_ = "normal" return dialect - self.method_ = "consistency" + self.method_ = DetectionMethod.CONSISTENCY consistency_detector = ConsistencyDetector(skip=skip, verbose=verbose) if verbose: print("Running data consistency measure ...", flush=True) @@ -122,7 +143,11 @@ def has_header(self, sample): # Finally, a 'vote' is taken at the end for each column, adding or # subtracting from the likelihood of the first row being a header. - rdr = reader(StringIO(sample), self.sniff(sample)) + dialect = self.sniff(sample) + if dialect is None: + raise NoDetectionResult + + rdr = reader(StringIO(sample), dialect) header = next(rdr) # assume first row is header diff --git a/clevercsv/detect_pattern.py b/clevercsv/detect_pattern.py index dc8b49bd..8133ea08 100644 --- a/clevercsv/detect_pattern.py +++ b/clevercsv/detect_pattern.py @@ -19,7 +19,7 @@ DEFAULT_EPS_PAT: float = 1e-3 -RE_MULTI_C: Pattern = re.compile(r"C{2,}") +RE_MULTI_C: Pattern[str] = re.compile(r"C{2,}") def pattern_score( diff --git a/clevercsv/detect_type.py b/clevercsv/detect_type.py index 97fe116e..e0635f3f 100644 --- a/clevercsv/detect_type.py +++ b/clevercsv/detect_type.py @@ -16,21 +16,22 @@ from ._regexes import DEFAULT_TYPE_REGEXES from .cparser_util import parse_string +from .dialect import SimpleDialect -DEFAULT_EPS_TYPE = 1e-10 +DEFAULT_EPS_TYPE: float = 1e-10 class TypeDetector: def __init__( self, - patterns: Optional[Dict[str, Pattern]] = None, - strip_whitespace=True, - ): + patterns: Optional[Dict[str, Pattern[str]]] = None, + strip_whitespace: bool = True, + ) -> None: self.patterns = patterns or DEFAULT_TYPE_REGEXES.copy() self.strip_whitespace = strip_whitespace self._register_type_tests() - def _register_type_tests(self): + def _register_type_tests(self) -> None: self._type_tests = [ ("empty", self.is_empty), ("url", self.is_url), @@ -55,7 +56,7 @@ def list_known_types(self) -> List[str]: def is_known_type(self, cell: str, is_quoted: bool = False) -> bool: return self.detect_type(cell, is_quoted=is_quoted) is not None - def detect_type(self, cell: str, is_quoted: bool = False): + def detect_type(self, cell: str, is_quoted: bool = False) -> Optional[str]: cell = cell.strip() if self.strip_whitespace else cell for name, func in self._type_tests: if func(cell, is_quoted=is_quoted): @@ -216,7 +217,9 @@ def gen_known_type(cells): yield td.is_known_type(cell) -def type_score(data, dialect, eps=DEFAULT_EPS_TYPE): +def type_score( + data: str, dialect: SimpleDialect, eps: float = DEFAULT_EPS_TYPE +) -> float: """ Compute the type score as the ratio of cells with a known type. @@ -231,6 +234,11 @@ def type_score(data, dialect, eps=DEFAULT_EPS_TYPE): eps: float the minimum value of the type score + Returns + ------- + type_score: float + The computed type score + """ total = 0 known = 0 diff --git a/clevercsv/dialect.py b/clevercsv/dialect.py index 023622a3..815fc1d8 100644 --- a/clevercsv/dialect.py +++ b/clevercsv/dialect.py @@ -15,6 +15,7 @@ from typing import Any from typing import Dict from typing import Optional +from typing import Tuple from typing import Type from typing import Union @@ -142,18 +143,20 @@ def __repr__(self) -> str: self.escapechar, ) - def __key(self): + def __key( + self, + ) -> Tuple[Optional[str], Optional[str], Optional[str], bool]: return (self.delimiter, self.quotechar, self.escapechar, self.strict) def __hash__(self) -> int: return hash(self.__key()) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, SimpleDialect): return False return self.__key() == other.__key() - def __lt__(self, other): + def __lt__(self, other: Any) -> bool: # This provides a partial order on dialect objects with the goal of # speeding up the consistency measure. if not isinstance(other, SimpleDialect): diff --git a/clevercsv/dict_read_write.py b/clevercsv/dict_read_write.py index 521d10a9..a32264ca 100644 --- a/clevercsv/dict_read_write.py +++ b/clevercsv/dict_read_write.py @@ -61,7 +61,7 @@ def __init__( self.dialect = dialect self.line_num = 0 - def __iter__(self) -> "DictReader": + def __iter__(self) -> "DictReader[_T]": return self @property diff --git a/clevercsv/normal_form.py b/clevercsv/normal_form.py index b496d1e5..c56b0634 100644 --- a/clevercsv/normal_form.py +++ b/clevercsv/normal_form.py @@ -14,19 +14,28 @@ import itertools +from typing import Callable +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + import regex from .dialect import SimpleDialect from .escape import is_potential_escapechar from .utils import pairwise -DELIMS = [",", ";", "|", "\t"] -QUOTECHARS = ["'", '"'] +DELIMS: List[str] = [",", ";", "|", "\t"] +QUOTECHARS: List[str] = ["'", '"'] def detect_dialect_normal( - data, encoding="UTF-8", delimiters=None, verbose=False -): + data: str, + encoding: str = "UTF-8", + delimiters: Optional[Iterable[str]] = None, + verbose: bool = False, +) -> Optional[SimpleDialect]: """Detect the normal form of a file from a given sample Parameters @@ -52,7 +61,9 @@ def detect_dialect_normal( print("Not normal, has potential escapechar.") return None - form_and_dialect = [] + form_and_dialect: List[ + Tuple[int, Callable[[str, SimpleDialect], bool], SimpleDialect] + ] = [] for delim in delimiters: dialect = SimpleDialect(delimiter=delim, quotechar="", escapechar="") @@ -70,6 +81,7 @@ def detect_dialect_normal( delimiter="", quotechar=quotechar, escapechar="" ) form_and_dialect.append((4, is_form_4, dialect)) + form_and_dialect.append( ( 4, @@ -85,19 +97,20 @@ def detect_dialect_normal( return dialect if verbose: print("Didn't match any normal forms.") + return None -def is_quoted_cell(cell, quotechar): +def is_quoted_cell(cell: str, quotechar: str) -> bool: if len(cell) < 2: return False return cell[0] == quotechar and cell[-1] == quotechar -def is_any_quoted_cell(cell): +def is_any_quoted_cell(cell: str) -> bool: return is_quoted_cell(cell, "'") or is_quoted_cell(cell, '"') -def is_any_partial_quoted_cell(cell): +def is_any_partial_quoted_cell(cell: str) -> bool: if len(cell) < 1: return False return ( @@ -105,15 +118,15 @@ def is_any_partial_quoted_cell(cell): ) -def is_empty_quoted(cell, quotechar): +def is_empty_quoted(cell: str, quotechar: str) -> bool: return len(cell) == 2 and is_quoted_cell(cell, quotechar) -def is_empty_unquoted(cell): +def is_empty_unquoted(cell: str) -> bool: return cell == "" -def is_any_empty(cell): +def is_any_empty(cell: str) -> bool: return ( is_empty_unquoted(cell) or is_empty_quoted(cell, "'") @@ -121,15 +134,17 @@ def is_any_empty(cell): ) -def has_delimiter(string, delim): +def has_delimiter(string: str, delim: str) -> bool: return delim in string -def has_nested_quotes(string, quotechar): +def has_nested_quotes(string: str, quotechar: str) -> bool: return quotechar in string[1:-1] -def maybe_has_escapechar(data, encoding, delim, quotechar): +def maybe_has_escapechar( + data: str, encoding: str, delim: str, quotechar: str +) -> bool: if delim not in data and quotechar not in data: return False for u, v in pairwise(data): @@ -138,7 +153,7 @@ def maybe_has_escapechar(data, encoding, delim, quotechar): return False -def strip_trailing_crnl(data): +def strip_trailing_crnl(data: str) -> str: while data.endswith("\n"): data = data.rstrip("\n") while data.endswith("\r"): @@ -146,28 +161,29 @@ def strip_trailing_crnl(data): return data -def every_row_has_delim(rows, dialect): +def every_row_has_delim(rows: List[str], dialect: SimpleDialect) -> bool: + assert dialect.delimiter is not None for row in rows: if not has_delimiter(row, dialect.delimiter): return False return True -def is_elementary(cell): +def is_elementary(cell: str) -> bool: return ( regex.fullmatch(r"[a-zA-Z0-9\.\_\&\-\@\+\%\(\)\ \/]+", cell) is not None ) -def even_rows(rows, dialect): +def even_rows(rows: List[str], dialect: SimpleDialect) -> bool: cells_per_row = set() for row in rows: cells_per_row.add(len(split_row(row, dialect))) return len(cells_per_row) == 1 -def split_file(data): +def split_file(data: str) -> List[str]: data = strip_trailing_crnl(data) if "\r\n" in data: return data.split("\r\n") @@ -175,14 +191,14 @@ def split_file(data): return data.split("\n") elif "\r" in data: return data.split("\r") - else: - return [data] + return [data] -def split_row(row, dialect): +def split_row(row: str, dialect: SimpleDialect) -> List[str]: # no nested quotes - if dialect.quotechar == "" or dialect.quotechar not in row: - if dialect.delimiter == "": + assert dialect.quotechar is not None + if (not dialect.quotechar) or (dialect.quotechar not in row): + if not dialect.delimiter: return [row] return row.split(dialect.delimiter) @@ -203,9 +219,10 @@ def split_row(row, dialect): return cells -def is_form_1(data, dialect=None): +def is_form_1(data: str, dialect: SimpleDialect) -> bool: # All cells quoted, quoted empty allowed, no nested quotes, more than one # column + assert dialect.quotechar is not None rows = split_file(data) @@ -234,7 +251,7 @@ def is_form_1(data, dialect=None): return True -def is_form_2(data, dialect): +def is_form_2(data: str, dialect: SimpleDialect) -> bool: # All unquoted, empty allowed, all elementary rows = split_file(data) @@ -261,8 +278,9 @@ def is_form_2(data, dialect): return True -def is_form_3(data, dialect): +def is_form_3(data: str, dialect: SimpleDialect) -> bool: # some quoted, some not quoted, no empty, no nested quotes + assert dialect.quotechar is not None rows = split_file(data) @@ -297,8 +315,10 @@ def is_form_3(data, dialect): return True -def is_form_4(data, dialect): +def is_form_4(data: str, dialect: SimpleDialect) -> bool: # no delim, single column (either entirely quoted or entirely unquoted) + assert dialect.quotechar is not None + rows = split_file(data) if len(rows) <= 1: @@ -322,7 +342,7 @@ def is_form_4(data, dialect): return True -def is_form_5(data, dialect): +def is_form_5(data: str, dialect: SimpleDialect) -> bool: # all rows quoted, no nested quotes # basically form 2 but with quotes around each row diff --git a/clevercsv/potential_dialects.py b/clevercsv/potential_dialects.py index 5cd15c73..f8b74e68 100644 --- a/clevercsv/potential_dialects.py +++ b/clevercsv/potential_dialects.py @@ -12,6 +12,10 @@ import unicodedata from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Set from ._regexes import PATTERN_URL from .dialect import SimpleDialect @@ -20,8 +24,11 @@ def get_dialects( - data, encoding="UTF-8", delimiters=None, test_masked_by_quotes=False -): + data: str, + encoding: str = "UTF-8", + delimiters: Optional[List[str]] = None, + test_masked_by_quotes: bool = False, +) -> List[SimpleDialect]: """Return the possible dialects for the given data. We consider as escape characters those characters for which @@ -56,7 +63,7 @@ def get_dialects( Returns ------- - dialects: list + dialects: List[SimpleDialect] List of SimpleDialect objects that are considered potential dialects. """ @@ -97,7 +104,7 @@ def get_dialects( return dialects -def unicode_category(x, encoding=None): +def unicode_category(x: str, encoding: str) -> str: """Return the Unicode category of a character Parameters @@ -118,14 +125,18 @@ def unicode_category(x, encoding=None): return unicodedata.category(as_unicode) -def filter_urls(data): +def filter_urls(data: str) -> str: """Filter URLs from the data""" return PATTERN_URL.sub("U", data) def get_delimiters( - data, encoding, delimiters=None, block_cat=None, block_char=None -): + data: str, + encoding: str, + delimiters: Optional[List[str]] = None, + block_cat: Optional[List[str]] = None, + block_char: Optional[List[str]] = None, +) -> Set[str]: """Get potential delimiters The set of potential delimiters is constructed as follows. For each unique @@ -200,7 +211,9 @@ def get_delimiters( return D -def get_quotechars(data, quote_chars=None): +def get_quotechars( + data: str, quote_chars: Optional[Iterable[str]] = None +) -> Set[str]: """Get potential quote characters Quote characters are those that occur in the ``quote_chars`` set and are @@ -233,7 +246,9 @@ def get_quotechars(data, quote_chars=None): return Q -def masked_by_quotechar(data, quotechar, escapechar, test_char): +def masked_by_quotechar( + data: str, quotechar: str, escapechar: str, test_char: str +) -> bool: """Test if a character is always masked by quote characters This function tests if a given character is always within quoted segments diff --git a/clevercsv/utils.py b/clevercsv/utils.py index 891a4887..f969f362 100644 --- a/clevercsv/utils.py +++ b/clevercsv/utils.py @@ -9,8 +9,17 @@ import hashlib +from typing import Iterable +from typing import Iterator +from typing import Tuple +from typing import TypeVar -def pairwise(iterable): +from clevercsv._types import AnyPath + +T = TypeVar("T") + + +def pairwise(iterable: Iterable[T]) -> Iterator[Tuple[T, T]]: "s - > (s0, s1), (s1, s2), (s2, s3), ..." a = iter(iterable) b = iter(iterable) @@ -18,7 +27,7 @@ def pairwise(iterable): return zip(a, b) -def sha1sum(filename): +def sha1sum(filename: AnyPath) -> str: """Compute the SHA1 checksum of a given file Parameters diff --git a/clevercsv/wrappers.py b/clevercsv/wrappers.py index 807c33f9..18dc6d7f 100644 --- a/clevercsv/wrappers.py +++ b/clevercsv/wrappers.py @@ -95,6 +95,10 @@ def stream_dicts( data = fid.read(num_chars) if num_chars else fid.read() dialect = Detector().detect(data, verbose=verbose) fid.seek(0) + + if dialect is None: + raise NoDetectionResult + reader: DictReader = DictReader(fid, dialect=dialect) for row in reader: yield row @@ -326,6 +330,10 @@ def read_dataframe( with open(filename, "r", newline="", encoding=enc) as fid: data = fid.read(num_chars) if num_chars else fid.read() dialect = Detector().detect(data) + + if dialect is None: + raise NoDetectionResult + csv_dialect = dialect.to_csv_dialect() # This is used to catch pandas' warnings when a dialect is supplied. @@ -346,7 +354,7 @@ def detect_dialect( verbose: bool = False, method: str = "auto", skip: bool = True, -) -> SimpleDialect: +) -> Optional[SimpleDialect]: """Detect the dialect of a CSV file This is a utility function that simply returns the detected dialect of a @@ -379,7 +387,7 @@ def detect_dialect( Returns ------- - dialect : SimpleDialect + dialect : Optional[SimpleDialect] The detected dialect as a :class:`SimpleDialect`, or None if detection failed. diff --git a/clevercsv/write.py b/clevercsv/write.py index a0e3403b..71a3577a 100644 --- a/clevercsv/write.py +++ b/clevercsv/write.py @@ -40,10 +40,10 @@ class writer: def __init__( self, - csvfile: SupportsWrite, + csvfile: SupportsWrite[str], dialect: _DialectLike = "excel", - **fmtparams, - ): + **fmtparams: Any, + ) -> None: self.original_dialect = dialect self.dialect: Type[csv.Dialect] = self._make_python_dialect( dialect, **fmtparams @@ -51,7 +51,7 @@ def __init__( self._writer = csv.writer(csvfile, dialect=self.dialect) def _make_python_dialect( - self, dialect: _DialectLike, **fmtparams + self, dialect: _DialectLike, **fmtparams: Any ) -> Type[csv.Dialect]: d: _DialectLike = "" if isinstance(dialect, str): diff --git a/pyproject.toml b/pyproject.toml index a834bd6a..b9ec849a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,18 @@ exclude = ["stubs"] [tool.mypy] python_version = 3.8 warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true +strict_equality = true +strict_concatenate = true +check_untyped_defs = true +disallow_subclassing_any = true +disallow_untyped_decorators = true +disallow_any_generics = true +disallow_untyped_calls = true +disallow_incomplete_defs = true +disallow_untyped_defs = false -# [[tool.mypy.overrides]] -# packages = ["stubs", "clevercsv", "tests"] +[[tool.mypy.overrides]] +packages = ["stubs", "clevercsv"] +disallow_incomplete_defs = true diff --git a/stubs/pythonfuzz/main.pyi b/stubs/pythonfuzz/main.pyi index f13dbe8b..046b9bc0 100644 --- a/stubs/pythonfuzz/main.pyi +++ b/stubs/pythonfuzz/main.pyi @@ -2,5 +2,5 @@ from typing import Any from typing import Callable class PythonFuzz: - def __init__(self, func: Callable) -> None: ... + def __init__(self, func: Callable[[bytes], Any]) -> None: ... def __call__(self, *args: Any, **kwargs: Any) -> None: ... diff --git a/stubs/wilderness/__init__.pyi b/stubs/wilderness/__init__.pyi index 47e061a2..a766f99c 100644 --- a/stubs/wilderness/__init__.pyi +++ b/stubs/wilderness/__init__.pyi @@ -19,11 +19,11 @@ class DocumentableMixin(metaclass=abc.ABCMeta): @property def parser(self) -> argparse.ArgumentParser: ... @parser.setter - def parser(self, parser: argparse.ArgumentParser): ... + def parser(self, parser: argparse.ArgumentParser) -> None: ... @property def args(self) -> argparse.Namespace: ... @args.setter - def args(self, args: argparse.Namespace): ... + def args(self, args: argparse.Namespace) -> None: ... @property def argument_help(self) -> Dict[str, Optional[str]]: ... @@ -145,16 +145,16 @@ class ArgumentGroup: @property def command(self) -> Optional[Command]: ... @command.setter - def command(self, command: Command): ... - def add_argument(self, *args, **kwargs): ... + def command(self, command: Command) -> None: ... + def add_argument(self, *args, **kwargs) -> None: ... class MutuallyExclusiveGroup: def __init__(self, meg: argparse._MutuallyExclusiveGroup) -> None: ... @property def command(self) -> Optional[Command]: ... @command.setter - def command(self, command: Command): ... - def add_argument(self, *args, **kwargs): ... + def command(self, command: Command) -> None: ... + def add_argument(self, *args, **kwargs) -> None: ... class Tester: def __init__(self, app: Application) -> None: ... diff --git a/tests/test_integration/test_dialect_detection.py b/tests/test_integration/test_dialect_detection.py index 7308ea73..c97d191e 100644 --- a/tests/test_integration/test_dialect_detection.py +++ b/tests/test_integration/test_dialect_detection.py @@ -15,11 +15,19 @@ import time import warnings +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + import chardet import termcolor import clevercsv +from clevercsv.dialect import SimpleDialect + THIS_DIR = os.path.abspath(os.path.dirname(__file__)) SOURCE_DIR = os.path.join(THIS_DIR, "data") TEST_FILES = os.path.join(SOURCE_DIR, "files") @@ -41,7 +49,7 @@ N_BYTES_PARTIAL = 10000 -def log_result(name, kind, verbose, partial): +def log_result(name: str, kind: str, verbose: int, partial: bool) -> None: table = { "error": (LOG_ERROR, LOG_ERROR_PARTIAL, "yellow"), "success": (LOG_SUCCESS, LOG_SUCCESS_PARTIAL, "green"), @@ -57,19 +65,21 @@ def log_result(name, kind, verbose, partial): termcolor.cprint(name, color=color) -def log_method(name, method, partial): +def log_method(name: str, method: str, partial: bool) -> None: fname = LOG_METHOD_PARTIAL if partial else LOG_METHOD with open(fname, "a") as fp: fp.write(f"{name},{method}\n") -def log_runtime(name, runtime, partial): +def log_runtime(name: str, runtime: float, partial: bool) -> None: fname = LOG_RUNTIME_PARTIAL if partial else LOG_RUNTIME with open(fname, "a") as fp: fp.write(f"{name},{runtime}\n") -def worker(args, return_dict, **kwargs): +def worker( + args: List[Any], return_dict: Dict[str, Any], **kwargs: Any +) -> None: det = clevercsv.Detector() filename, encoding, partial = args return_dict["error"] = False @@ -87,7 +97,9 @@ def worker(args, return_dict, **kwargs): return_dict["error"] = True -def run_with_timeout(args, kwargs, limit): +def run_with_timeout( + args: Tuple[Any, ...], kwargs: Dict[str, Any], limit: Optional[int] +) -> Tuple[Optional[SimpleDialect], bool, Optional[str], float]: manager = multiprocessing.Manager() return_dict = manager.dict() p = multiprocessing.Process( @@ -106,7 +118,13 @@ def run_with_timeout(args, kwargs, limit): ) -def run_test(name, gz_filename, annotation, verbose=1, partial=False): +def run_test( + name: str, + gz_filename: str, + annotation: Dict[str, Any], + verbose: int = 1, + partial: bool = False, +) -> None: if "encoding" in annotation: enc = annotation["encoding"] else: @@ -131,11 +149,12 @@ def run_test(name, gz_filename, annotation, verbose=1, partial=False): else: log_result(name, "success", verbose, partial) + assert method is not None log_method(name, method, partial) log_runtime(name, runtime, partial) -def load_test_cases(): +def load_test_cases() -> List[Tuple[str, str, Dict[str, Any]]]: cases = [] for f in sorted(os.listdir(TEST_FILES)): base = f[: -len(".csv.gz")] @@ -157,7 +176,7 @@ def load_test_cases(): return cases -def clear_output_files(partial): +def clear_output_files(partial: bool) -> None: files = { True: [ LOG_SUCCESS_PARTIAL, @@ -173,7 +192,7 @@ def clear_output_files(partial): os.unlink(filename) -def parse_args(): +def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument( "--partial", @@ -184,7 +203,7 @@ def parse_args(): return parser.parse_args() -def main(): +def main() -> None: args = parse_args() clear_output_files(args.partial) cases = load_test_cases() diff --git a/tests/test_unit/test_abstraction.py b/tests/test_unit/test_abstraction.py index 264b7943..4fb5cd25 100644 --- a/tests/test_unit/test_abstraction.py +++ b/tests/test_unit/test_abstraction.py @@ -22,7 +22,7 @@ class AbstractionTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: here = Path(__file__) this_dir = here.parent data_dir = this_dir / "data" @@ -40,7 +40,7 @@ def _load_cases(filename: Path) -> List[Dict[str, Any]]: cases.append(json.loads(line)) return cases - def test_abstraction_multi(self): + def test_abstraction_multi(self) -> None: if not self._cases: self.skipTest("no abstraction test cases found") diff --git a/tests/test_unit/test_consistency.py b/tests/test_unit/test_consistency.py index 74e86ff4..82c2df1d 100644 --- a/tests/test_unit/test_consistency.py +++ b/tests/test_unit/test_consistency.py @@ -15,7 +15,7 @@ class ConsistencyTestCase(unittest.TestCase): - def test_get_best_set_1(self): + def test_get_best_set_1(self) -> None: scores = { SimpleDialect(",", None, None): ConsistencyScore(P=1, T=1, Q=1), SimpleDialect(";", None, None): ConsistencyScore( diff --git a/tests/test_unit/test_console.py b/tests/test_unit/test_console.py index 29595575..ef67f2df 100644 --- a/tests/test_unit/test_console.py +++ b/tests/test_unit/test_console.py @@ -11,19 +11,29 @@ import tempfile import unittest +from typing import Any from typing import List -from typing import Union +from typing import Optional from wilderness import Tester from clevercsv import __version__ +from clevercsv._types import _DialectLike from clevercsv.console import build_application from clevercsv.dialect import SimpleDialect from clevercsv.write import writer +TableType = List[List[Any]] + class ConsoleTestCase(unittest.TestCase): - def _build_file(self, table, dialect, encoding=None, newline=None) -> str: + def _build_file( + self, + table: TableType, + dialect: _DialectLike, + encoding: Optional[str] = None, + newline: Optional[str] = None, + ) -> str: tmpfd, tmpfname = tempfile.mkstemp( prefix="ccsv_", suffix=".csv", @@ -34,7 +44,9 @@ def _build_file(self, table, dialect, encoding=None, newline=None) -> str: tmpid.close() return tmpfname - def _detect_test_wrap(self, table, dialect): + def _detect_test_wrap( + self, table: TableType, dialect: _DialectLike + ) -> None: tmpfname = self._build_file(table, dialect) exp = "Detected: " + str(dialect) @@ -51,8 +63,8 @@ def _detect_test_wrap(self, table, dialect): finally: os.unlink(tmpfname) - def test_detect_base(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_detect_base(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") with self.subTest(name="simple"): self._detect_test_wrap(table, dialect) @@ -72,8 +84,8 @@ def test_detect_base(self): with self.subTest(name="double"): self._detect_test_wrap(table, dialect) - def test_detect_opts_1(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_detect_opts_1(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") encoding = "windows-1252" tmpfname = self._build_file(table, dialect, encoding=encoding) @@ -93,8 +105,8 @@ def test_detect_opts_1(self): finally: os.unlink(tmpfname) - def test_detect_opts_2(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_detect_opts_2(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") tmpfname = self._build_file(table, dialect) @@ -113,8 +125,8 @@ def test_detect_opts_2(self): finally: os.unlink(tmpfname) - def test_detect_opts_3(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_detect_opts_3(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") tmpfname = self._build_file(table, dialect) @@ -135,8 +147,8 @@ def test_detect_opts_3(self): finally: os.unlink(tmpfname) - def test_detect_opts_4(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_detect_opts_4(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") tmpfname = self._build_file(table, dialect) @@ -157,8 +169,8 @@ def test_detect_opts_4(self): finally: os.unlink(tmpfname) - def test_code_1(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_code_1(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") tmpfname = self._build_file(table, dialect) @@ -198,8 +210,8 @@ def test_code_1(self): finally: os.unlink(tmpfname) - def test_code_2(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_code_2(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") tmpfname = self._build_file(table, dialect) @@ -222,8 +234,8 @@ def test_code_2(self): finally: os.unlink(tmpfname) - def test_code_3(self): - table = [["Å", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_code_3(self) -> None: + table: TableType = [["Å", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") encoding = "ISO-8859-1" tmpfname = self._build_file(table, dialect, encoding=encoding) @@ -263,8 +275,8 @@ def test_code_3(self): finally: os.unlink(tmpfname) - def test_code_4(self): - table = [["Å", "B,D", "C"], [1, 2, 3], [4, 5, 6]] + def test_code_4(self) -> None: + table: TableType = [["Å", "B,D", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=",", quotechar="", escapechar="\\") encoding = "ISO-8859-1" tmpfname = self._build_file(table, dialect, encoding=encoding) @@ -304,8 +316,8 @@ def test_code_4(self): finally: os.unlink(tmpfname) - def test_code_5(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_code_5(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter="\t", quotechar="", escapechar="") tmpfname = self._build_file(table, dialect) @@ -346,8 +358,8 @@ def test_code_5(self): finally: os.unlink(tmpfname) - def test_standardize_1(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_standardize_1(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") tmpfname = self._build_file(table, dialect) @@ -366,8 +378,8 @@ def test_standardize_1(self): finally: os.unlink(tmpfname) - def test_standardize_2(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_standardize_2(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") tmpfname = self._build_file(table, dialect) @@ -390,8 +402,8 @@ def test_standardize_2(self): os.unlink(tmpfname) os.unlink(tmpoutname) - def test_standardize_3(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_standardize_3(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") tmpfname = self._build_file(table, dialect) @@ -411,8 +423,8 @@ def test_standardize_3(self): finally: os.unlink(tmpfname) - def test_standardize_in_place(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_standardize_in_place(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") tmpfname = self._build_file(table, dialect) @@ -435,8 +447,8 @@ def test_standardize_in_place(self): finally: os.unlink(tmpfname) - def test_standardize_in_place_noop(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + def test_standardize_in_place_noop(self) -> None: + table: TableType = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = "excel" tmpfname = self._build_file(table, dialect, newline="") @@ -459,8 +471,8 @@ def test_standardize_in_place_noop(self): finally: os.unlink(tmpfname) - def test_standardize_multi(self): - table: List[List[Union[str, int]]] = [ + def test_standardize_multi(self) -> None: + table: TableType = [ ["A", "B", "C"], [1, 2, 3], [4, 5, 6], @@ -497,8 +509,8 @@ def test_standardize_multi(self): any(map(os.unlink, tmpfnames)) any(map(os.unlink, tmpoutnames)) - def test_standardize_multi_errors(self): - table: List[List[Union[str, int]]] = [ + def test_standardize_multi_errors(self) -> None: + table: TableType = [ ["A", "B", "C"], [1, 2, 3], [4, 5, 6], @@ -532,8 +544,8 @@ def test_standardize_multi_errors(self): any(map(os.unlink, tmpfnames)) any(map(os.unlink, tmpoutnames)) - def test_standardize_multi_encoding(self): - table: List[List[Union[str, int]]] = [ + def test_standardize_multi_encoding(self) -> None: + table: TableType = [ ["Å", "B", "C"], [1, 2, 3], [4, 5, 6], @@ -576,8 +588,8 @@ def test_standardize_multi_encoding(self): any(map(os.unlink, tmpfnames)) any(map(os.unlink, tmpoutnames)) - def test_standardize_in_place_multi(self): - table: List[List[Union[str, int]]] = [ + def test_standardize_in_place_multi(self) -> None: + table: TableType = [ ["Å", "B", "C"], [1, 2, 3], [4, 5, 6], @@ -591,7 +603,7 @@ def test_standardize_in_place_multi(self): application = build_application() tester = Tester(application) - tester.test_command("standardize", ["-i", "-e", encoding] + tmpfnames) + tester.test_command("standardize", ["-i", "-e", encoding, *tmpfnames]) self.assertEqual(tester.get_return_code(), 2) @@ -605,8 +617,8 @@ def test_standardize_in_place_multi(self): finally: any(map(os.unlink, tmpfnames)) - def test_standardize_in_place_multi_noop(self): - table: List[List[Union[str, int]]] = [ + def test_standardize_in_place_multi_noop(self) -> None: + table: TableType = [ ["Å", "B", "C"], [1, 2, 3], [4, 5, 6], @@ -616,7 +628,7 @@ def test_standardize_in_place_multi_noop(self): application = build_application() tester = Tester(application) - tester.test_command("standardize", ["-i"] + tmpfnames) + tester.test_command("standardize", ["-i", *tmpfnames]) self.assertEqual(tester.get_return_code(), 0) diff --git a/tests/test_unit/test_cparser.py b/tests/test_unit/test_cparser.py index 1f6160f3..3fc9f95f 100644 --- a/tests/test_unit/test_cparser.py +++ b/tests/test_unit/test_cparser.py @@ -10,8 +10,15 @@ import io import unittest +from typing import Any +from typing import List +from typing import Tuple +from typing import TypeVar + from clevercsv.cparser_util import parse_data +T = TypeVar("T", str, Tuple[str, bool]) + class ParserTestCase(unittest.TestCase): @@ -19,12 +26,14 @@ class ParserTestCase(unittest.TestCase): Testing splitting on delimiter with or without quotes """ - def _parse_test(self, string, expect, **kwargs): + def _parse_test( + self, string: str, expect: List[List[T]], **kwargs: Any + ) -> None: buf = io.StringIO(string, newline="") result = list(parse_data(buf, **kwargs)) self.assertEqual(result, expect) - def test_parse_simple_1(self): + def test_parse_simple_1(self) -> None: self._parse_test( "A,B,C,D,E", [["A", "B", "C", "D", "E"]], @@ -32,7 +41,7 @@ def test_parse_simple_1(self): quotechar='"', ) - def test_parse_simple_2(self): + def test_parse_simple_2(self) -> None: self._parse_test( "A,B,C,D,E", [["A", "B", "C", "D", "E"]], @@ -40,10 +49,10 @@ def test_parse_simple_2(self): quotechar="", ) - def test_parse_simple_3(self): + def test_parse_simple_3(self) -> None: self._parse_test("A,B,C,D,E", [["A,B,C,D,E"]]) - def test_parse_simple_4(self): + def test_parse_simple_4(self) -> None: self._parse_test( 'A,"B",C,D,E', [["A", "B", "C", "D", "E"]], @@ -51,7 +60,7 @@ def test_parse_simple_4(self): quotechar='"', ) - def test_parse_simple_5(self): + def test_parse_simple_5(self) -> None: self._parse_test( 'A,"B,C",D,E', [["A", "B,C", "D", "E"]], @@ -59,7 +68,7 @@ def test_parse_simple_5(self): quotechar='"', ) - def test_parse_simple_6(self): + def test_parse_simple_6(self) -> None: self._parse_test( 'A,"B,C",D,E', [["A", '"B', 'C"', "D", "E"]], @@ -67,7 +76,7 @@ def test_parse_simple_6(self): quotechar="", ) - def test_parse_simple_7(self): + def test_parse_simple_7(self) -> None: self._parse_test( '"A","B","C",,,,', [['"A"', '"B"', '"C"', "", "", "", ""]], @@ -79,36 +88,36 @@ def test_parse_simple_7(self): Testing splitting on rows only: """ - def test_parse_no_delim_1(self): + def test_parse_no_delim_1(self) -> None: self._parse_test( 'A"B"C\rA"B""C""D"', [['A"B"C'], ['A"B""C""D"']], quotechar="" ) - def test_parse_no_delim_2(self): + def test_parse_no_delim_2(self) -> None: self._parse_test( 'A"B"C\nA"B""C""D"', [['A"B"C'], ['A"B""C""D"']], quotechar="" ) - def test_parse_no_delim_3(self): + def test_parse_no_delim_3(self) -> None: self._parse_test( 'A"B"C\r\nA"B""C""D"', [['A"B"C'], ['A"B""C""D"']], quotechar="" ) - def test_parse_no_delim_4(self): + def test_parse_no_delim_4(self) -> None: self._parse_test( 'A"B\r\nB"C\r\nD"E"F\r\nG', [['A"B\r\nB"C'], ['D"E"F'], ["G"]], quotechar='"', ) - def test_parse_no_delim_5(self): + def test_parse_no_delim_5(self) -> None: self._parse_test( 'A"B\nB"C\nD"E"F\nG', [['A"B\nB"C'], ['D"E"F'], ["G"]], quotechar='"', ) - def test_parse_no_delim_6(self): + def test_parse_no_delim_6(self) -> None: self._parse_test( 'A"B\nB\rB"C\nD"E"F\nG', [['A"B\nB\rB"C'], ['D"E"F'], ["G"]], @@ -119,25 +128,25 @@ def test_parse_no_delim_6(self): Tests from Pythons builtin CSV module: """ - def test_parse_builtin_1(self): + def test_parse_builtin_1(self) -> None: self._parse_test("", []) - def test_parse_builtin_2(self): + def test_parse_builtin_2(self) -> None: self._parse_test("a,b\r", [["a", "b"]], delimiter=",") - def test_parse_builtin_3(self): + def test_parse_builtin_3(self) -> None: self._parse_test("a,b\n", [["a", "b"]], delimiter=",") - def test_parse_builtin_4(self): + def test_parse_builtin_4(self) -> None: self._parse_test("a,b\r\n", [["a", "b"]], delimiter=",") - def test_parse_builtin_5(self): + def test_parse_builtin_5(self) -> None: self._parse_test('a,"', [["a", ""]], delimiter=",", quotechar='"') - def test_parse_builtin_6(self): + def test_parse_builtin_6(self) -> None: self._parse_test('"a', [["a"]], delimiter=",", quotechar='"') - def test_parse_builtin_7(self): + def test_parse_builtin_7(self) -> None: # differs from Python (1) self._parse_test( "a,|b,c", @@ -147,7 +156,7 @@ def test_parse_builtin_7(self): escapechar="|", ) - def test_parse_builtin_8(self): + def test_parse_builtin_8(self) -> None: self._parse_test( "a,b|,c", [["a", "b,c"]], @@ -156,7 +165,7 @@ def test_parse_builtin_8(self): escapechar="|", ) - def test_parse_builtin_9(self): + def test_parse_builtin_9(self) -> None: # differs from Python (1) self._parse_test( 'a,"b,|c"', @@ -166,7 +175,7 @@ def test_parse_builtin_9(self): escapechar="|", ) - def test_parse_builtin_10(self): + def test_parse_builtin_10(self) -> None: self._parse_test( 'a,"b,c|""', [["a", 'b,c"']], @@ -175,7 +184,7 @@ def test_parse_builtin_10(self): escapechar="|", ) - def test_parse_builtin_11(self): + def test_parse_builtin_11(self) -> None: # differs from Python (2) self._parse_test( 'a,"b,c"|', @@ -185,12 +194,12 @@ def test_parse_builtin_11(self): escapechar="|", ) - def test_parse_builtin_12(self): + def test_parse_builtin_12(self) -> None: self._parse_test( '1,",3,",5', [["1", ",3,", "5"]], delimiter=",", quotechar='"' ) - def test_parse_builtin_13(self): + def test_parse_builtin_13(self) -> None: self._parse_test( '1,",3,",5', [["1", '"', "3", '"', "5"]], @@ -198,7 +207,7 @@ def test_parse_builtin_13(self): quotechar="", ) - def test_parse_builtin_14(self): + def test_parse_builtin_14(self) -> None: self._parse_test( ',3,"5",7.3, 9', [["", "3", "5", "7.3", " 9"]], @@ -206,7 +215,7 @@ def test_parse_builtin_14(self): quotechar='"', ) - def test_parse_builtin_15(self): + def test_parse_builtin_15(self) -> None: self._parse_test( '"a\nb", 7', [["a\nb", " 7"]], delimiter=",", quotechar='"' ) @@ -215,12 +224,12 @@ def test_parse_builtin_15(self): Double quotes: """ - def test_parse_dq_1(self): + def test_parse_dq_1(self) -> None: self._parse_test( 'a,"a""b""c"', [["a", 'a"b"c']], delimiter=",", quotechar='"' ) - def test_parse_dq_2(self): + def test_parse_dq_2(self) -> None: self._parse_test( 'a,"a""b,c""d",e', [["a", 'a"b,c"d', "e"]], @@ -232,7 +241,7 @@ def test_parse_dq_2(self): Mix double and escapechar: """ - def test_parse_mix_double_escape_1(self): + def test_parse_mix_double_escape_1(self) -> None: self._parse_test( 'a,"bc""d"",|"f|""', [["a", 'bc"d","f"']], @@ -245,12 +254,12 @@ def test_parse_mix_double_escape_1(self): Other tests: """ - def test_parse_other_1(self): + def test_parse_other_1(self) -> None: self._parse_test( 'a,b "c" d,e', [["a", 'b "c" d', "e"]], delimiter=",", quotechar="" ) - def test_parse_other_2(self): + def test_parse_other_2(self) -> None: self._parse_test( 'a,b "c" d,e', [["a", 'b "c" d', "e"]], @@ -258,22 +267,22 @@ def test_parse_other_2(self): quotechar='"', ) - def test_parse_other_3(self): + def test_parse_other_3(self) -> None: self._parse_test("a,\rb,c", [["a", ""], ["b", "c"]], delimiter=",") - def test_parse_other_4(self): + def test_parse_other_4(self) -> None: self._parse_test( "a,b\r\n\r\nc,d\r\n", [["a", "b"], [], ["c", "d"]], delimiter="," ) - def test_parse_other_5(self): + def test_parse_other_5(self) -> None: self._parse_test( "\r\na,b\rc,d\n\re,f\r\n", [[], ["a", "b"], ["c", "d"], [], ["e", "f"]], delimiter=",", ) - def test_parse_other_6(self): + def test_parse_other_6(self) -> None: self._parse_test( "a,b\n\nc,d", [["a", "b"], [], ["c", "d"]], delimiter="," ) @@ -282,7 +291,7 @@ def test_parse_other_6(self): Further escape char tests: """ - def test_parse_escape_1(self): + def test_parse_escape_1(self) -> None: self._parse_test( "a,b,c||d", [["a", "b", "c|d"]], @@ -291,7 +300,7 @@ def test_parse_escape_1(self): escapechar="|", ) - def test_parse_escape_2(self): + def test_parse_escape_2(self) -> None: self._parse_test( "a,b,c||d,e|,d", [["a", "b", "c|d", "e,d"]], @@ -304,7 +313,7 @@ def test_parse_escape_2(self): Quote mismatch until EOF: """ - def test_parse_quote_mismatch_1(self): + def test_parse_quote_mismatch_1(self) -> None: self._parse_test( 'a,b,c"d,e\n', [["a", "b", 'c"d,e\n']], @@ -312,7 +321,7 @@ def test_parse_quote_mismatch_1(self): quotechar='"', ) - def test_parse_quote_mismatch_2(self): + def test_parse_quote_mismatch_2(self) -> None: self._parse_test( 'a,b,c"d,e\n', [["a", "b", 'c"d', "e"]], @@ -320,12 +329,12 @@ def test_parse_quote_mismatch_2(self): quotechar="", ) - def test_parse_quote_mismatch_3(self): + def test_parse_quote_mismatch_3(self) -> None: self._parse_test( 'a,b,"c,d', [["a", "b", "c,d"]], delimiter=",", quotechar='"' ) - def test_parse_quote_mismatch_4(self): + def test_parse_quote_mismatch_4(self) -> None: self._parse_test( 'a,b,"c,d\n', [["a", "b", "c,d\n"]], delimiter=",", quotechar='"' ) @@ -334,7 +343,7 @@ def test_parse_quote_mismatch_4(self): Single column: """ - def test_parse_single_1(self): + def test_parse_single_1(self) -> None: self._parse_test("a\rb\rc\n", [["a"], ["b"], ["c"]]) """ @@ -342,12 +351,12 @@ def test_parse_single_1(self): case would return ``[['a', 'abc', 'd']]``. """ - def test_parse_differ_1(self): + def test_parse_differ_1(self) -> None: self._parse_test( 'a,"ab"c,d', [["a", '"ab"c', "d"]], delimiter=",", quotechar="" ) - def test_parse_differ_2(self): + def test_parse_differ_2(self) -> None: self._parse_test( 'a,"ab"c,d', [["a", '"ab"c', "d"]], delimiter=",", quotechar='"' ) @@ -356,7 +365,7 @@ def test_parse_differ_2(self): Return quoted """ - def test_parse_return_quoted_1(self): + def test_parse_return_quoted_1(self) -> None: self._parse_test( "a,b,c", [[("a", False), ("b", False), ("c", False)]], @@ -365,7 +374,7 @@ def test_parse_return_quoted_1(self): return_quoted=True, ) - def test_parse_return_quoted_2(self): + def test_parse_return_quoted_2(self) -> None: self._parse_test( 'a,"b,c",d', [[("a", False), ("b,c", True), ("d", False)]], @@ -374,7 +383,7 @@ def test_parse_return_quoted_2(self): return_quoted=True, ) - def test_parse_return_quoted_3(self): + def test_parse_return_quoted_3(self) -> None: self._parse_test( 'a,b,"c,d', [[("a", False), ("b", False), ("c,d", True)]], diff --git a/tests/test_unit/test_detect.py b/tests/test_unit/test_detect.py index 7f165c75..7b6303fc 100644 --- a/tests/test_unit/test_detect.py +++ b/tests/test_unit/test_detect.py @@ -84,57 +84,70 @@ class DetectorTestCase(unittest.TestCase): "{""fake"": ""json"", ""fake2"":""json2""}",00:02:51,20:04:45-06:00 """ - def test_detect(self): + def test_detect(self) -> None: # Adapted from CPython detector = Detector() dialect = detector.detect(self.sample1) + assert dialect is not None self.assertEqual(dialect.delimiter, ",") self.assertEqual(dialect.quotechar, "") self.assertEqual(dialect.escapechar, "") dialect = detector.detect(self.sample2) + assert dialect is not None self.assertEqual(dialect.delimiter, ":") self.assertEqual(dialect.quotechar, "'") self.assertEqual(dialect.escapechar, "") - def test_delimiters(self): + def test_delimiters(self) -> None: # Adapted from CPython detector = Detector() dialect = detector.detect(self.sample3) + assert dialect is not None self.assertIn(dialect.delimiter, self.sample3) dialect = detector.detect(self.sample3, delimiters="?,") + assert dialect is not None self.assertEqual(dialect.delimiter, "?") dialect = detector.detect(self.sample3, delimiters="/,") + assert dialect is not None self.assertEqual(dialect.delimiter, "/") dialect = detector.detect(self.sample4) + assert dialect is not None self.assertEqual(dialect.delimiter, ";") dialect = detector.detect(self.sample5) + assert dialect is not None self.assertEqual(dialect.delimiter, "\t") dialect = detector.detect(self.sample6) + assert dialect is not None self.assertEqual(dialect.delimiter, "|") dialect = detector.detect(self.sample7) + assert dialect is not None self.assertEqual(dialect.delimiter, "|") self.assertEqual(dialect.quotechar, "'") dialect = detector.detect(self.sample8) + assert dialect is not None self.assertEqual(dialect.delimiter, "+") dialect = detector.detect(self.sample9) + assert dialect is not None self.assertEqual(dialect.delimiter, "+") self.assertEqual(dialect.quotechar, "'") dialect = detector.detect(self.sample10) + assert dialect is not None self.assertEqual(dialect.delimiter, ",") self.assertEqual(dialect.quotechar, "") dialect = detector.detect(self.sample11) + assert dialect is not None self.assertEqual(dialect.delimiter, ",") self.assertEqual(dialect.quotechar, '"') - def test_has_header(self): + def test_has_header(self) -> None: detector = Detector() self.assertEqual(detector.has_header(self.sample1), False) self.assertEqual( detector.has_header(self.header1 + self.sample1), True ) - def test_has_header_regex_special_delimiter(self): + def test_has_header_regex_special_delimiter(self) -> None: detector = Detector() self.assertEqual(detector.has_header(self.sample8), False) self.assertEqual( diff --git a/tests/test_unit/test_detect_pattern.py b/tests/test_unit/test_detect_pattern.py index fd77793c..0f0e08fc 100644 --- a/tests/test_unit/test_detect_pattern.py +++ b/tests/test_unit/test_detect_pattern.py @@ -19,14 +19,14 @@ class PatternTestCase(unittest.TestCase): Abstraction tests """ - def test_abstraction_1(self): + def test_abstraction_1(self) -> None: out = detect_pattern.make_abstraction( "A,B,C", SimpleDialect(delimiter=",", quotechar="", escapechar="") ) exp = "CDCDC" self.assertEqual(exp, out) - def test_abstraction_2(self): + def test_abstraction_2(self) -> None: out = detect_pattern.make_abstraction( "A,\rA,A,A\r", SimpleDialect(delimiter=",", quotechar="", escapechar=""), @@ -34,7 +34,7 @@ def test_abstraction_2(self): exp = "CDCRCDCDC" self.assertEqual(exp, out) - def test_abstraction_3(self): + def test_abstraction_3(self) -> None: out = detect_pattern.make_abstraction( "a,a,\n,a,a\ra,a,a\r\n", SimpleDialect(delimiter=",", quotechar="", escapechar=""), @@ -42,7 +42,7 @@ def test_abstraction_3(self): exp = "CDCDCRCDCDCRCDCDC" self.assertEqual(exp, out) - def test_abstraction_4(self): + def test_abstraction_4(self) -> None: out = detect_pattern.make_abstraction( 'a,"bc""d""e""f""a",\r\n', SimpleDialect(delimiter=",", quotechar='"', escapechar=""), @@ -50,7 +50,7 @@ def test_abstraction_4(self): exp = "CDCDC" self.assertEqual(exp, out) - def test_abstraction_5(self): + def test_abstraction_5(self) -> None: out = detect_pattern.make_abstraction( 'a,"bc""d"",|"f|""', SimpleDialect(delimiter=",", quotechar='"', escapechar="|"), @@ -58,21 +58,21 @@ def test_abstraction_5(self): exp = "CDC" self.assertEqual(exp, out) - def test_abstraction_6(self): + def test_abstraction_6(self) -> None: out = detect_pattern.make_abstraction( ",,,", SimpleDialect(delimiter=",", quotechar="", escapechar="") ) exp = "CDCDCDC" self.assertEqual(exp, out) - def test_abstraction_7(self): + def test_abstraction_7(self) -> None: out = detect_pattern.make_abstraction( ',"",,', SimpleDialect(delimiter=",", quotechar='"', escapechar="") ) exp = "CDCDCDC" self.assertEqual(exp, out) - def test_abstraction_8(self): + def test_abstraction_8(self) -> None: out = detect_pattern.make_abstraction( ',"",,\r\n', SimpleDialect(delimiter=",", quotechar='"', escapechar=""), @@ -84,7 +84,7 @@ def test_abstraction_8(self): Escape char tests """ - def test_abstraction_9(self): + def test_abstraction_9(self) -> None: out = detect_pattern.make_abstraction( "A,B|,C", SimpleDialect(delimiter=",", quotechar="", escapechar="|"), @@ -92,7 +92,7 @@ def test_abstraction_9(self): exp = "CDC" self.assertEqual(exp, out) - def test_abstraction_10(self): + def test_abstraction_10(self) -> None: out = detect_pattern.make_abstraction( 'A,"B,C|"D"', SimpleDialect(delimiter=",", quotechar='"', escapechar="|"), @@ -100,7 +100,7 @@ def test_abstraction_10(self): exp = "CDC" self.assertEqual(exp, out) - def test_abstraction_11(self): + def test_abstraction_11(self) -> None: out = detect_pattern.make_abstraction( "a,|b,c", SimpleDialect(delimiter=",", quotechar="", escapechar="|"), @@ -108,7 +108,7 @@ def test_abstraction_11(self): exp = "CDCDC" self.assertEqual(exp, out) - def test_abstraction_12(self): + def test_abstraction_12(self) -> None: out = detect_pattern.make_abstraction( "a,b|,c", SimpleDialect(delimiter=",", quotechar="", escapechar="|"), @@ -116,7 +116,7 @@ def test_abstraction_12(self): exp = "CDC" self.assertEqual(exp, out) - def test_abstraction_13(self): + def test_abstraction_13(self) -> None: out = detect_pattern.make_abstraction( 'a,"b,c|""', SimpleDialect(delimiter=",", quotechar='"', escapechar="|"), @@ -124,7 +124,7 @@ def test_abstraction_13(self): exp = "CDC" self.assertEqual(exp, out) - def test_abstraction_14(self): + def test_abstraction_14(self) -> None: out = detect_pattern.make_abstraction( "a,b||c", SimpleDialect(delimiter=",", quotechar="", escapechar="|"), @@ -132,7 +132,7 @@ def test_abstraction_14(self): exp = "CDC" self.assertEqual(exp, out) - def test_abstraction_15(self): + def test_abstraction_15(self) -> None: out = detect_pattern.make_abstraction( 'a,"b|"c||d|"e"', SimpleDialect(delimiter=",", quotechar='"', escapechar="|"), @@ -140,7 +140,7 @@ def test_abstraction_15(self): exp = "CDC" self.assertEqual(exp, out) - def test_abstraction_16(self): + def test_abstraction_16(self) -> None: out = detect_pattern.make_abstraction( 'a,"b|"c||d","e"', SimpleDialect(delimiter=",", quotechar='"', escapechar="|"), @@ -152,7 +152,7 @@ def test_abstraction_16(self): Fill empties """ - def test_fill_empties_1(self): + def test_fill_empties_1(self) -> None: out = detect_pattern.fill_empties("DDD") exp = "CDCDCDC" self.assertEqual(exp, out) @@ -161,7 +161,7 @@ def test_fill_empties_1(self): Pattern Score tests """ - def test_pattern_score_1(self): + def test_pattern_score_1(self) -> None: # theta_1 from paper data = ( "7,5; Mon, Jan 12;6,40\n100; Fri, Mar 21;8,23\n8,2; Thu, Sep 17;" @@ -172,7 +172,7 @@ def test_pattern_score_1(self): exp = 7 / 4 self.assertAlmostEqual(exp, out) - def test_pattern_score_2(self): + def test_pattern_score_2(self) -> None: # theta_2 from paper data = ( "7,5; Mon, Jan 12;6,40\n100; Fri, Mar 21;8,23\n8,2; Thu, Sep 17;" @@ -183,7 +183,7 @@ def test_pattern_score_2(self): exp = 10 / 3 self.assertAlmostEqual(exp, out) - def test_pattern_score_3(self): + def test_pattern_score_3(self) -> None: # theta_3 from paper data = ( "7,5; Mon, Jan 12;6,40\n100; Fri, Mar 21;8,23\n8,2; Thu, Sep 17;" diff --git a/tests/test_unit/test_detect_type.py b/tests/test_unit/test_detect_type.py index 83fcb42e..bdea6ee9 100644 --- a/tests/test_unit/test_detect_type.py +++ b/tests/test_unit/test_detect_type.py @@ -17,12 +17,12 @@ class TypeDetectorTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.td = TypeDetector() # NUMBERS - def test_number(self): + def test_number(self) -> None: yes_number: List[str] = [ "1", "2", @@ -112,7 +112,7 @@ def test_number(self): # DATES - def test_date(self): + def test_date(self) -> None: yes_date: List[str] = [ "031219", "03122019", @@ -178,7 +178,7 @@ def test_date(self): # DATETIME - def test_datetime(self): + def test_datetime(self) -> None: yes_dt: List[str] = [ "2019-01-12T04:01:23Z", "2021-09-26T12:13:31+01:00", @@ -193,7 +193,7 @@ def test_datetime(self): # URLs - def test_url(self): + def test_url(self) -> None: # Some cases copied from https://mathiasbynens.be/demo/url-regex yes_url: List[str] = [ "Cocoal.icio.us", @@ -306,7 +306,7 @@ def test_url(self): # Unicode_alphanum - def test_unicode_alphanum(self): + def test_unicode_alphanum(self) -> None: # These tests are by no means inclusive and ought to be extended in the # future. @@ -344,7 +344,7 @@ def test_unicode_alphanum(self): ) ) - def test_bytearray(self): + def test_bytearray(self) -> None: yes_bytearray: List[str] = [ "bytearray(b'')", "bytearray(b'abc,*&@\"')", @@ -367,7 +367,7 @@ def test_bytearray(self): # Unix path - def test_unix_path(self): + def test_unix_path(self) -> None: yes_path: List[str] = [ "/Users/person/abc/def-ghi/blabla.csv.test", "/home/username/share/a/_b/c_d/e.py", @@ -389,7 +389,7 @@ def test_unix_path(self): Type Score tests """ - def test_type_score_1(self): + def test_type_score_1(self) -> None: # theta_1 from paper cells = [ ["7", "5; Mon", " Jan 12;6", "40"], @@ -404,7 +404,7 @@ def test_type_score_1(self): exp = 8 / 17 self.assertAlmostEqual(exp, out) - def test_type_score_2(self): + def test_type_score_2(self) -> None: # theta_2 from paper cells = [ ["7,5", " Mon, Jan 12", "6,40"], @@ -419,7 +419,7 @@ def test_type_score_2(self): exp = 10 / 15 self.assertAlmostEqual(exp, out) - def test_type_score_3(self): + def test_type_score_3(self) -> None: # theta_3 from paper cells = [ ["7,5", " Mon, Jan 12", "6,40"], diff --git a/tests/test_unit/test_dict.py b/tests/test_unit/test_dict.py index 43d45c5a..183b4448 100644 --- a/tests/test_unit/test_dict.py +++ b/tests/test_unit/test_dict.py @@ -3,7 +3,7 @@ """ Unit tests for the DictReader and DictWriter classes -Most of these are the same as in CPython, but we also test the cases where +Most of these are the same as in CPython, but we also test the cases where CleverCSV's behavior differs. """ @@ -24,13 +24,13 @@ class DictTestCase(unittest.TestCase): ############################ # Start tests from CPython # - def test_writeheader_return_value(self): + def test_writeheader_return_value(self) -> None: with tempfile.TemporaryFile("w+", newline="") as fp: writer = clevercsv.DictWriter(fp, fieldnames=["f1", "f2", "f3"]) writeheader_return_value = writer.writeheader() self.assertEqual(writeheader_return_value, 10) - def test_write_simple_dict(self): + def test_write_simple_dict(self) -> None: with tempfile.TemporaryFile("w+", newline="") as fp: writer = clevercsv.DictWriter(fp, fieldnames=["f1", "f2", "f3"]) writer.writeheader() @@ -41,7 +41,7 @@ def test_write_simple_dict(self): fp.readline() # header self.assertEqual(fp.read(), "10,,abc\r\n") - def test_write_multiple_dict_rows(self): + def test_write_multiple_dict_rows(self) -> None: fp = io.StringIO() writer = clevercsv.DictWriter(fp, fieldnames=["f1", "f2", "f3"]) writer.writeheader() @@ -54,11 +54,11 @@ def test_write_multiple_dict_rows(self): ) self.assertEqual(fp.getvalue(), "f1,f2,f3\r\n1,abc,f\r\n2,5,xyz\r\n") - def test_write_no_fields(self): + def test_write_no_fields(self) -> None: fp = io.StringIO() self.assertRaises(TypeError, clevercsv.DictWriter, fp) - def test_write_fields_not_in_fieldnames(self): + def test_write_fields_not_in_fieldnames(self) -> None: with tempfile.TemporaryFile("w+", newline="") as fp: writer = clevercsv.DictWriter(fp, fieldnames=["f1", "f2", "f3"]) # Of special note is the non-string key (CPython issue 19449) @@ -71,7 +71,7 @@ def test_write_fields_not_in_fieldnames(self): self.assertNotIn("'f2'", exception) self.assertIn("1", exception) - def test_typo_in_extrasaction_raises_error(self): + def test_typo_in_extrasaction_raises_error(self) -> None: fp = io.StringIO() self.assertRaises( ValueError, @@ -81,7 +81,7 @@ def test_typo_in_extrasaction_raises_error(self): extrasaction="raised", ) - def test_write_field_not_in_field_names_raise(self): + def test_write_field_not_in_field_names_raise(self) -> None: fp = io.StringIO() writer = clevercsv.DictWriter(fp, ["f1", "f2"], extrasaction="raise") dictrow = {"f0": 0, "f1": 1, "f2": 2, "f3": 3} @@ -89,53 +89,55 @@ def test_write_field_not_in_field_names_raise(self): ValueError, clevercsv.DictWriter.writerow, writer, dictrow ) - def test_write_field_not_in_field_names_ignore(self): + def test_write_field_not_in_field_names_ignore(self) -> None: fp = io.StringIO() writer = clevercsv.DictWriter(fp, ["f1", "f2"], extrasaction="ignore") dictrow = {"f0": 0, "f1": 1, "f2": 2, "f3": 3} clevercsv.DictWriter.writerow(writer, dictrow) self.assertEqual(fp.getvalue(), "1,2\r\n") - def test_read_dict_fields(self): + def test_read_dict_fields(self) -> None: with tempfile.TemporaryFile("w+") as fp: fp.write("1,2,abc\r\n") fp.seek(0) - reader = clevercsv.DictReader(fp, fieldnames=["f1", "f2", "f3"]) + reader: DictReader[str] = clevercsv.DictReader( + fp, fieldnames=["f1", "f2", "f3"] + ) self.assertEqual(next(reader), {"f1": "1", "f2": "2", "f3": "abc"}) - def test_read_dict_no_fieldnames(self): + def test_read_dict_no_fieldnames(self) -> None: with tempfile.TemporaryFile("w+") as fp: fp.write("f1,f2,f3\r\n1,2,abc\r\n") fp.seek(0) - reader: DictReader = clevercsv.DictReader(fp) + reader: DictReader[str] = clevercsv.DictReader(fp) self.assertEqual(next(reader), {"f1": "1", "f2": "2", "f3": "abc"}) self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) # Two test cases to make sure existing ways of implicitly setting # fieldnames continue to work. Both arise from discussion in issue3436. - def test_read_dict_fieldnames_from_file(self): + def test_read_dict_fieldnames_from_file(self) -> None: with tempfile.TemporaryFile("w+") as fp: fp.write("f1,f2,f3\r\n1,2,abc\r\n") fp.seek(0) - reader = clevercsv.DictReader( + reader: DictReader[str] = clevercsv.DictReader( fp, fieldnames=next(clevercsv.reader(fp)) ) self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) self.assertEqual(next(reader), {"f1": "1", "f2": "2", "f3": "abc"}) - def test_read_dict_fieldnames_chain(self): + def test_read_dict_fieldnames_chain(self) -> None: import itertools with tempfile.TemporaryFile("w+") as fp: fp.write("f1,f2,f3\r\n1,2,abc\r\n") fp.seek(0) - reader: DictReader = clevercsv.DictReader(fp) + reader: DictReader[str] = clevercsv.DictReader(fp) first = next(reader) for row in itertools.chain([first], reader): self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) self.assertEqual(row, {"f1": "1", "f2": "2", "f3": "abc"}) - def test_read_long(self): + def test_read_long(self) -> None: with tempfile.TemporaryFile("w+") as fp: fp.write("1,2,abc,4,5,6\r\n") fp.seek(0) @@ -145,7 +147,7 @@ def test_read_long(self): {"f1": "1", "f2": "2", None: ["abc", "4", "5", "6"]}, ) - def test_read_long_with_rest(self): + def test_read_long_with_rest(self) -> None: with tempfile.TemporaryFile("w+") as fp: fp.write("1,2,abc,4,5,6\r\n") fp.seek(0) @@ -157,18 +159,18 @@ def test_read_long_with_rest(self): {"f1": "1", "f2": "2", "_rest": ["abc", "4", "5", "6"]}, ) - def test_read_long_with_rest_no_fieldnames(self): + def test_read_long_with_rest_no_fieldnames(self) -> None: with tempfile.TemporaryFile("w+") as fp: fp.write("f1,f2\r\n1,2,abc,4,5,6\r\n") fp.seek(0) - reader: DictReader = clevercsv.DictReader(fp, restkey="_rest") + reader: DictReader[str] = clevercsv.DictReader(fp, restkey="_rest") self.assertEqual(reader.fieldnames, ["f1", "f2"]) self.assertEqual( next(reader), {"f1": "1", "f2": "2", "_rest": ["abc", "4", "5", "6"]}, ) - def test_read_short(self): + def test_read_short(self) -> None: with tempfile.TemporaryFile("w+") as fp: fp.write("1,2,abc,4,5,6\r\n1,2,abc\r\n") fp.seek(0) @@ -191,7 +193,7 @@ def test_read_short(self): }, ) - def test_read_multi(self): + def test_read_multi(self) -> None: sample = [ "2147483648,43.0e12,17,abc,def\r\n", "147483648,43.0e2,17,abc,def\r\n", @@ -212,7 +214,7 @@ def test_read_multi(self): }, ) - def test_read_with_blanks(self): + def test_read_with_blanks(self) -> None: reader = clevercsv.DictReader( ["1,2,abc,4,5,6\r\n", "\r\n", "1,2,abc,4,5,6\r\n"], fieldnames="1 2 3 4 5 6".split(), @@ -226,7 +228,7 @@ def test_read_with_blanks(self): {"1": "1", "2": "2", "3": "abc", "4": "4", "5": "5", "6": "6"}, ) - def test_read_semi_sep(self): + def test_read_semi_sep(self) -> None: reader = clevercsv.DictReader( ["1;2;abc;4;5;6\r\n"], fieldnames="1 2 3 4 5 6".split(), @@ -243,8 +245,8 @@ def test_read_semi_sep(self): ################################### # Start tests added for CleverCSV # - def test_read_duplicate_fieldnames(self): - reader: DictReader = clevercsv.DictReader( + def test_read_duplicate_fieldnames(self) -> None: + reader: DictReader[str] = clevercsv.DictReader( ["f1,f2,f1\r\n", "a", "b", "c"] ) with self.assertWarns(UserWarning): diff --git a/tests/test_unit/test_encoding.py b/tests/test_unit/test_encoding.py index 3750f26a..b900d036 100644 --- a/tests/test_unit/test_encoding.py +++ b/tests/test_unit/test_encoding.py @@ -13,38 +13,50 @@ import tempfile import unittest +from dataclasses import dataclass + +from typing import Any +from typing import List + from clevercsv._optional import import_optional_dependency +from clevercsv._types import AnyPath from clevercsv.encoding import get_encoding from clevercsv.write import writer class EncodingTestCase(unittest.TestCase): - cases = [ - { - "table": [["Å", "B", "C"], [1, 2, 3], [4, 5, 6]], - "encoding": "ISO-8859-1", - "cchardet_encoding": "WINDOWS-1252", - }, - { - "table": [["A", "B", "C"], [1, 2, 3], [4, 5, 6]], - "encoding": "ascii", - "cchardet_encoding": "ASCII", - }, - { - "table": [["亜唖", "娃阿", "哀愛"], [1, 2, 3], ["挨", "姶", "葵"]], - "encoding": "ISO-2022-JP", - "cchardet_encoding": "ISO-2022-JP", - }, + @dataclass + class Instance: + table: List[List[Any]] + encoding: str + cchardet_encoding: str + + cases: List[Instance] = [ + Instance( + table=[["Å", "B", "C"], [1, 2, 3], [4, 5, 6]], + encoding="ISO-8859-1", + cchardet_encoding="WINDOWS-1252", + ), + Instance( + table=[["A", "B", "C"], [1, 2, 3], [4, 5, 6]], + encoding="ascii", + cchardet_encoding="ASCII", + ), + Instance( + table=[["亜唖", "娃阿", "哀愛"], [1, 2, 3], ["挨", "姶", "葵"]], + encoding="ISO-2022-JP", + cchardet_encoding="ISO-2022-JP", + ), ] - def setUp(self): - self._tmpfiles = [] + def setUp(self) -> None: + self._tmpfiles: List[AnyPath] = [] - def tearDown(self): + def tearDown(self) -> None: for f in self._tmpfiles: os.unlink(f) - def _build_file(self, table, encoding): + def _build_file(self, table: List[List[str]], encoding: str) -> str: tmpfd, tmpfname = tempfile.mkstemp( prefix="ccsv_", suffix=".csv", @@ -56,26 +68,26 @@ def _build_file(self, table, encoding): self._tmpfiles.append(tmpfname) return tmpfname - def test_encoding_chardet(self): + def test_encoding_chardet(self) -> None: for case in self.cases: - table = case["table"] - encoding = case["encoding"] + table = case.table + encoding = case.encoding with self.subTest(encoding=encoding): tmpfname = self._build_file(table, encoding) detected = get_encoding(tmpfname, try_cchardet=False) self.assertEqual(encoding, detected) - def test_encoding_cchardet(self): + def test_encoding_cchardet(self) -> None: try: _ = import_optional_dependency("cchardet") except ImportError: self.skipTest("Failed to import cchardet, skipping this test") for case in self.cases: - table = case["table"] - encoding = case["encoding"] + table = case.table + encoding = case.encoding with self.subTest(encoding=encoding): - out_encoding = case["cchardet_encoding"] + out_encoding = case.cchardet_encoding tmpfname = self._build_file(table, encoding) detected = get_encoding(tmpfname, try_cchardet=True) self.assertEqual(out_encoding, detected) diff --git a/tests/test_unit/test_fuzzing.py b/tests/test_unit/test_fuzzing.py index 7a00b060..1462b9f5 100644 --- a/tests/test_unit/test_fuzzing.py +++ b/tests/test_unit/test_fuzzing.py @@ -11,7 +11,7 @@ class FuzzingTestCase(unittest.TestCase): - def test_sniffer_fuzzing(self): + def test_sniffer_fuzzing(self) -> None: strings = ['"""', "```", "\"'", "'@'", "'\"", "'''", "O##P~` "] for string in strings: with self.subTest(string=string): diff --git a/tests/test_unit/test_normal_forms.py b/tests/test_unit/test_normal_forms.py index 8ef6ad2a..3e83480b 100644 --- a/tests/test_unit/test_normal_forms.py +++ b/tests/test_unit/test_normal_forms.py @@ -18,7 +18,7 @@ class NormalFormTestCase(unittest.TestCase): - def test_form_1(self): + def test_form_1(self) -> None: dialect = SimpleDialect(delimiter=",", quotechar='"', escapechar="") self.assertTrue(is_form_1('"A","B","C"', dialect)) @@ -32,7 +32,7 @@ def test_form_1(self): self.assertFalse(is_form_1('"A",C', dialect)) self.assertFalse(is_form_1('"A"\n"b""A""c","B"', dialect)) - def test_form_2(self): + def test_form_2(self) -> None: dialect = SimpleDialect(delimiter=",", quotechar="", escapechar="") self.assertTrue(is_form_2("1,2,3", dialect)) @@ -47,7 +47,7 @@ def test_form_2(self): self.assertFalse(is_form_2('"a,3,3\n1,2,3', dialect)) self.assertFalse(is_form_2('a,"",3\n1,2,3', dialect)) - def test_form_3(self): + def test_form_3(self) -> None: A = SimpleDialect(delimiter=",", quotechar="'", escapechar="") Q = SimpleDialect(delimiter=",", quotechar='"', escapechar="") @@ -60,7 +60,7 @@ def test_form_3(self): self.assertFalse(is_form_3('A,B\n"C",D\n', A)) self.assertTrue(is_form_3('A,B\n"C",D\n', Q)) - def test_form_4(self): + def test_form_4(self) -> None: quoted = SimpleDialect(delimiter="", quotechar='"', escapechar="") unquoted = SimpleDialect(delimiter="", quotechar="", escapechar="") @@ -77,7 +77,7 @@ def test_form_4(self): self.assertFalse(is_form_4('A\n"-1"\n2', unquoted)) self.assertFalse(is_form_4("A B\n-1 3\n2 4", unquoted)) - def test_form_5(self): + def test_form_5(self) -> None: dialect = SimpleDialect(delimiter=",", quotechar='"', escapechar="") self.assertTrue(is_form_5('"A,B"\n"1,2"\n"3,4"', dialect)) diff --git a/tests/test_unit/test_potential_dialects.py b/tests/test_unit/test_potential_dialects.py index 0f06a882..37670dc8 100644 --- a/tests/test_unit/test_potential_dialects.py +++ b/tests/test_unit/test_potential_dialects.py @@ -16,24 +16,24 @@ class PotentialDialectTestCase(unittest.TestCase): - def test_masked_by_quotechar(self): + def test_masked_by_quotechar(self) -> None: self.assertTrue(masked_by_quotechar('A"B&C"A', '"', "", "&")) self.assertFalse(masked_by_quotechar('A"B&C"A&A', '"', "", "&")) self.assertFalse(masked_by_quotechar('A|"B&C"A', '"', "|", "&")) self.assertFalse(masked_by_quotechar('A"B"C', '"', "", "")) - def test_filter_urls(self): + def test_filter_urls(self) -> None: data = "A,B\nwww.google.com,10\nhttps://gertjanvandenburg.com,25\n" exp = "A,B\nU,10\nU,25\n" self.assertEqual(exp, filter_urls(data)) - def test_get_quotechars(self): + def test_get_quotechars(self) -> None: data = "A,B,'A',B\"D\"E" exp = set(['"', "'", ""]) out = get_quotechars(data) self.assertEqual(out, exp) - def test_get_delimiters(self): + def test_get_delimiters(self) -> None: data = "A,B|CD,E;F\tD123£123€10.,0" exp = set([",", "|", ";", "\t", "€", "£", ""]) out = get_delimiters(data, "UTF-8") diff --git a/tests/test_unit/test_reader.py b/tests/test_unit/test_reader.py index b9f7e57f..e7953668 100644 --- a/tests/test_unit/test_reader.py +++ b/tests/test_unit/test_reader.py @@ -9,16 +9,23 @@ import unittest +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List + import clevercsv class ReaderTestCase(unittest.TestCase): - def _read_test(self, input, expect, **kwargs): + def _read_test( + self, input: Iterable[str], expect: List[List[str]], **kwargs: Any + ) -> None: reader = clevercsv.reader(input, **kwargs) result = list(reader) self.assertEqual(result, expect) - def test_read_oddinputs(self): + def test_read_oddinputs(self) -> None: self._read_test([], []) self._read_test([""], [[]]) self.assertRaises( @@ -34,7 +41,7 @@ def test_read_oddinputs(self): # self._read_test(['"ab"c'], [["abc"]], doublequote=0) self.assertRaises(clevercsv.Error, self._read_test, [b"ab\0c"], None) - def test_read_eol(self): + def test_read_eol(self) -> None: self._read_test(["a,b"], [["a", "b"]]) self._read_test(["a,b\n"], [["a", "b"]]) self._read_test(["a,b\r\n"], [["a", "b"]]) @@ -44,13 +51,13 @@ def test_read_eol(self): self.assertRaises(clevercsv.Error, self._read_test, ["a,b\rc,d"], []) self.assertRaises(clevercsv.Error, self._read_test, ["a,b\rc,d"], []) - def test_read_eof(self): + def test_read_eof(self) -> None: self._read_test(['a,"'], [["a", ""]]) self._read_test(['"a'], [["a"]]) # we're not using escape characters in the same way. # self._read_test("^", [["\n"]], escapechar="^") - def test_read_escape(self): + def test_read_escape(self) -> None: # we don't drop the escapechar if it serves no purpose # so instead of this: # self._read_test("a,\\b,c", [["a", "b", "c"]], escapechar="\\") @@ -61,7 +68,7 @@ def test_read_escape(self): # the next test also differs from Python self._read_test(['a,"b,c"\\'], [["a", "b,c"]], escapechar="\\") - def test_read_bigfield(self): + def test_read_bigfield(self) -> None: limit = clevercsv.field_size_limit() try: size = 500 @@ -78,7 +85,7 @@ def test_read_bigfield(self): finally: clevercsv.field_size_limit(limit) - def test_read_linenum(self): + def test_read_linenum(self) -> None: r = clevercsv.reader(["line,1", "line,2", "line,3"]) self.assertEqual(r.line_num, 0) self.assertEqual(next(r), ["line", "1"]) @@ -90,8 +97,8 @@ def test_read_linenum(self): self.assertRaises(StopIteration, next, r) self.assertEqual(r.line_num, 3) - def test_with_gen(self): - def gen(x): + def test_with_gen(self) -> None: + def gen(x: Iterable[str]) -> Iterator[str]: for line in x: yield line @@ -100,7 +107,7 @@ def gen(x): self.assertEqual(next(r), ["line", "2"]) self.assertEqual(next(r), ["line", "3"]) - def test_simple(self): + def test_simple(self) -> None: self._read_test( ["A,B,C,D,E"], [["A", "B", "C", "D", "E"]], @@ -139,7 +146,7 @@ def test_simple(self): quotechar="", ) - def test_no_delim(self): + def test_no_delim(self) -> None: self._read_test( ['A"B"C', 'A"B""C""D"'], [['A"B"C'], ['A"B""C""D"']], diff --git a/tests/test_unit/test_wrappers.py b/tests/test_unit/test_wrappers.py index e054f002..3020c584 100644 --- a/tests/test_unit/test_wrappers.py +++ b/tests/test_unit/test_wrappers.py @@ -12,6 +12,9 @@ import types import unittest +from typing import Any +from typing import Dict +from typing import Iterable from typing import List from typing import Union @@ -24,7 +27,9 @@ class WrappersTestCase(unittest.TestCase): - def _df_test(self, table, dialect, **kwargs): + def _df_test( + self, table: List[List[Any]], dialect: SimpleDialect, **kwargs: Any + ) -> None: tmpfd, tmpfname = tempfile.mkstemp(prefix="ccsv_", suffix=".csv") tmpid = os.fdopen(tmpfd, "w", encoding=kwargs.get("encoding")) w = writer(tmpid, dialect=dialect) @@ -39,7 +44,9 @@ def _df_test(self, table, dialect, **kwargs): finally: os.unlink(tmpfname) - def _write_tmpfile(self, table, dialect): + def _write_tmpfile( + self, table: Iterable[Iterable[Any]], dialect: SimpleDialect + ) -> str: """Write a table to a temporary file using specified dialect""" tmpfd, tmpfname = tempfile.mkstemp(prefix="ccsv_", suffix=".csv") tmpid = os.fdopen(tmpfd, "w") @@ -48,7 +55,9 @@ def _write_tmpfile(self, table, dialect): tmpid.close() return tmpfname - def _read_test(self, table, dialect): + def _read_test( + self, table: Iterable[Iterable[Any]], dialect: SimpleDialect + ) -> None: tmpfname = self._write_tmpfile(table, dialect) exp = [list(map(str, r)) for r in table] try: @@ -56,7 +65,9 @@ def _read_test(self, table, dialect): finally: os.unlink(tmpfname) - def _stream_test(self, table, dialect): + def _stream_test( + self, table: Iterable[Iterable[str]], dialect: SimpleDialect + ) -> None: tmpfname = self._write_tmpfile(table, dialect) exp = [list(map(str, r)) for r in table] try: @@ -66,7 +77,9 @@ def _stream_test(self, table, dialect): finally: os.unlink(tmpfname) - def _read_test_rows(self, rows, expected): + def _read_test_rows( + self, rows: List[str], expected: List[List[str]] + ) -> None: contents = "\n".join(rows) tmpfd, tmpfname = tempfile.mkstemp(prefix="ccsv_", suffix=".csv") tmpid = os.fdopen(tmpfd, "w") @@ -78,7 +91,9 @@ def _read_test_rows(self, rows, expected): finally: os.unlink(tmpfname) - def _stream_test_rows(self, rows, expected): + def _stream_test_rows( + self, rows: Iterable[str], expected: List[List[str]] + ) -> None: contents = "\n".join(rows) tmpfd, tmpfname = tempfile.mkstemp(prefix="ccsv_", suffix=".csv") tmpid = os.fdopen(tmpfd, "w") @@ -92,7 +107,9 @@ def _stream_test_rows(self, rows, expected): finally: os.unlink(tmpfname) - def test_read_dataframe(self): + def test_read_dataframe(self) -> None: + table: List[List[Any]] + table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") with self.subTest(name="simple"): @@ -123,7 +140,9 @@ def test_read_dataframe(self): with self.subTest(name="simple_encoding"): self._df_test(table, dialect, num_char=10, encoding="latin1") - def test_read_table(self): + def test_read_table(self) -> None: + table: List[List[Any]] + table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") with self.subTest(name="simple"): @@ -158,7 +177,9 @@ def test_read_table(self): with self.assertRaises(NoDetectionResult): self._read_test_rows(rows, exp) - def test_stream_table(self): + def test_stream_table(self) -> None: + table: List[List[Any]] + table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] dialect = SimpleDialect(delimiter=";", quotechar="", escapechar="") with self.subTest(name="simple"): @@ -193,7 +214,9 @@ def test_stream_table(self): with self.assertRaises(NoDetectionResult): self._stream_test_rows(rows, exp) - def _write_test_table(self, table, expected, **kwargs): + def _write_test_table( + self, table: Iterable[Iterable[Any]], expected: str, **kwargs: Any + ) -> None: tmpfd, tmpfname = tempfile.mkstemp(prefix="ccsv_", suffix=".csv") wrappers.write_table(table, tmpfname, **kwargs) read_encoding = kwargs.get("encoding", None) @@ -206,7 +229,7 @@ def _write_test_table(self, table, expected, **kwargs): os.close(tmpfd) os.unlink(tmpfname) - def test_write_table(self): + def test_write_table(self) -> None: table: List[List[Union[str, int]]] = [ ["A", "B,C", "D"], [1, 2, 3], @@ -240,7 +263,9 @@ def test_write_table(self): with self.subTest(name="encoding_2"): self._write_test_table(table, exp, encoding="cp1252") - def _write_test_dicts(self, items, expected, **kwargs): + def _write_test_dicts( + self, items: Iterable[Dict[str, Any]], expected: str, **kwargs: Any + ) -> None: tmpfd, tmpfname = tempfile.mkstemp(prefix="ccsv_", suffix=".csv") wrappers.write_dicts(items, tmpfname, **kwargs) read_encoding = kwargs.get("encoding", None) @@ -253,7 +278,7 @@ def _write_test_dicts(self, items, expected, **kwargs): os.close(tmpfd) os.unlink(tmpfname) - def test_write_dicts(self): + def test_write_dicts(self) -> None: items = [{"A": 1, "B": 2, "C": 3}, {"A": 4, "B": 5, "C": 6}] exp = "A,B,C\r\n1,2,3\r\n4,5,6\r\n" with self.subTest(name="default"): diff --git a/tests/test_unit/test_write.py b/tests/test_unit/test_write.py index 2c8367c9..413e97cd 100644 --- a/tests/test_unit/test_write.py +++ b/tests/test_unit/test_write.py @@ -12,20 +12,28 @@ import tempfile import unittest +from typing import Any +from typing import Iterable +from typing import Type + import clevercsv from clevercsv.dialect import SimpleDialect class WriterTestCase(unittest.TestCase): - def _write_test(self, fields, expect, **kwargs): + def _write_test( + self, fields: Iterable[Any], expect: str, **kwargs: Any + ) -> None: with tempfile.TemporaryFile("w+", newline="", prefix="ccsv_") as fp: writer = clevercsv.writer(fp, **kwargs) writer.writerow(fields) fp.seek(0) self.assertEqual(fp.read(), expect + writer.dialect.lineterminator) - def _write_error_test(self, exc, fields, **kwargs): + def _write_error_test( + self, exc: Type[Exception], fields: Any, **kwargs: Any + ) -> None: with tempfile.TemporaryFile("w+", newline="", prefix="ccsv_") as fp: writer = clevercsv.writer(fp, **kwargs) with self.assertRaises(exc): @@ -33,7 +41,7 @@ def _write_error_test(self, exc, fields, **kwargs): fp.seek(0) self.assertEqual(fp.read(), "") - def test_write_arg_valid(self): + def test_write_arg_valid(self) -> None: self._write_error_test(clevercsv.Error, None) self._write_test((), "") self._write_test([None], '""') @@ -43,28 +51,28 @@ def test_write_arg_valid(self): # Check that exceptions are passed up the chain class BadList: - def __len__(self): + def __len__(self) -> int: return 10 - def __getitem__(self, i): + def __getitem__(self, i: int) -> None: if i > 2: raise OSError self._write_error_test(OSError, BadList()) class BadItem: - def __str__(self): + def __str__(self) -> str: raise OSError self._write_error_test(OSError, [BadItem()]) - def test_write_bigfield(self): + def test_write_bigfield(self) -> None: bigstring = "X" * 50000 self._write_test( [bigstring, bigstring], "%s,%s" % (bigstring, bigstring) ) - def test_write_quoting(self): + def test_write_quoting(self) -> None: self._write_test(["a", 1, "p,q"], 'a,1,"p,q"') self._write_error_test( clevercsv.Error, ["a", 1, "p,q"], quoting=clevercsv.QUOTE_NONE @@ -82,14 +90,14 @@ def test_write_quoting(self): ["a\nb", 1], '"a\nb","1"', quoting=clevercsv.QUOTE_ALL ) - def test_write_simpledialect(self): + def test_write_simpledialect(self) -> None: self._write_test( ["a", 1, "p,q"], "a,1,|p,q|", dialect=SimpleDialect(delimiter=",", quotechar="|", escapechar=""), ) - def test_write_csv_dialect(self): + def test_write_csv_dialect(self) -> None: self._write_test( ["a", 1, "p,q"], 'a,1,"p,q"',