Skip to content

Commit

Permalink
Add prefer_stubs configuration (#2437) (#2438)
Browse files Browse the repository at this point in the history
(cherry picked from commit ee06feb)

Co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com>
  • Loading branch information
github-actions[bot] and jacobtylerwalls committed May 16, 2024
1 parent a7ff092 commit 3650c34
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 14 deletions.
5 changes: 5 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ Release date: TBA

Closes pylint-dev/pylint#9139

* Add ``AstroidManager.prefer_stubs`` attribute to control the astroid 3.2.0 feature that prefers stubs.

Refs pylint-dev/#9626
Refs pylint-dev/#9623


What's New in astroid 3.2.0?
============================
Expand Down
8 changes: 2 additions & 6 deletions astroid/interpreter/_import/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,10 @@ def find_module(
pass
submodule_path = sys.path

# We're looping on pyi first because if a pyi exists there's probably a reason
# (i.e. the code is hard or impossible to parse), so we take pyi into account
# But we're not quite ready to do this for numpy, see https://github.com/pylint-dev/astroid/pull/2375
suffixes = (".pyi", ".py", importlib.machinery.BYTECODE_SUFFIXES[0])
numpy_suffixes = (".py", ".pyi", importlib.machinery.BYTECODE_SUFFIXES[0])
suffixes = (".py", ".pyi", importlib.machinery.BYTECODE_SUFFIXES[0])
for entry in submodule_path:
package_directory = os.path.join(entry, modname)
for suffix in numpy_suffixes if "numpy" in entry else suffixes:
for suffix in suffixes:
package_file_name = "__init__" + suffix
file_path = os.path.join(package_directory, package_file_name)
if os.path.isfile(file_path):
Expand Down
14 changes: 13 additions & 1 deletion astroid/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class AstroidManager:
"extension_package_whitelist": set(),
"module_denylist": set(),
"_transform": TransformVisitor(),
"prefer_stubs": False,
}

def __init__(self) -> None:
Expand All @@ -73,6 +74,7 @@ def __init__(self) -> None:
]
self.module_denylist = AstroidManager.brain["module_denylist"]
self._transform = AstroidManager.brain["_transform"]
self.prefer_stubs = AstroidManager.brain["prefer_stubs"]

@property
def always_load_extensions(self) -> bool:
Expand Down Expand Up @@ -111,6 +113,14 @@ def unregister_transform(self):
def builtins_module(self) -> nodes.Module:
return self.astroid_cache["builtins"]

@property
def prefer_stubs(self) -> bool:
return AstroidManager.brain["prefer_stubs"]

@prefer_stubs.setter
def prefer_stubs(self, value: bool) -> None:
AstroidManager.brain["prefer_stubs"] = value

def visit_transforms(self, node: nodes.NodeNG) -> InferenceResult:
"""Visit the transforms and apply them to the given *node*."""
return self._transform.visit(node)
Expand All @@ -136,7 +146,9 @@ def ast_from_file(
# Call get_source_file() only after a cache miss,
# since it calls os.path.exists().
try:
filepath = get_source_file(filepath, include_no_ext=True)
filepath = get_source_file(
filepath, include_no_ext=True, prefer_stubs=self.prefer_stubs
)
source = True
except NoSourceFile:
pass
Expand Down
15 changes: 9 additions & 6 deletions astroid/modutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@


if sys.platform.startswith("win"):
PY_SOURCE_EXTS = ("pyi", "pyw", "py")
PY_SOURCE_EXTS = ("py", "pyw", "pyi")
PY_SOURCE_EXTS_STUBS_FIRST = ("pyi", "pyw", "py")
PY_COMPILED_EXTS = ("dll", "pyd")
else:
PY_SOURCE_EXTS = ("pyi", "py")
PY_SOURCE_EXTS = ("py", "pyi")
PY_SOURCE_EXTS_STUBS_FIRST = ("pyi", "py")
PY_COMPILED_EXTS = ("so",)


Expand Down Expand Up @@ -484,7 +486,9 @@ def get_module_files(
return files


def get_source_file(filename: str, include_no_ext: bool = False) -> str:
def get_source_file(
filename: str, include_no_ext: bool = False, prefer_stubs: bool = False
) -> str:
"""Given a python module's file name return the matching source file
name (the filename will be returned identically if it's already an
absolute path to a python source file).
Expand All @@ -499,7 +503,7 @@ def get_source_file(filename: str, include_no_ext: bool = False) -> str:
base, orig_ext = os.path.splitext(filename)
if orig_ext == ".pyi" and os.path.exists(f"{base}{orig_ext}"):
return f"{base}{orig_ext}"
for ext in PY_SOURCE_EXTS if "numpy" not in filename else reversed(PY_SOURCE_EXTS):
for ext in PY_SOURCE_EXTS_STUBS_FIRST if prefer_stubs else PY_SOURCE_EXTS:
source_path = f"{base}.{ext}"
if os.path.exists(source_path):
return source_path
Expand Down Expand Up @@ -671,8 +675,7 @@ def _has_init(directory: str) -> str | None:
else return None.
"""
mod_or_pack = os.path.join(directory, "__init__")
exts = reversed(PY_SOURCE_EXTS) if "numpy" in directory else PY_SOURCE_EXTS
for ext in (*exts, "pyc", "pyo"):
for ext in (*PY_SOURCE_EXTS, "pyc", "pyo"):
if os.path.exists(mod_or_pack + "." + ext):
return mod_or_pack + "." + ext
return None
Expand Down
3 changes: 2 additions & 1 deletion tests/test_modutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ def test_pyi_preferred(self) -> None:
package = resources.find("pyi_data/find_test")
module = os.path.join(package, "__init__.py")
self.assertEqual(
modutils.get_source_file(module), os.path.normpath(module) + "i"
modutils.get_source_file(module, prefer_stubs=True),
os.path.normpath(module) + "i",
)


Expand Down

0 comments on commit 3650c34

Please sign in to comment.