Skip to content

Commit

Permalink
fixup! Let is_lambda stay False
Browse files Browse the repository at this point in the history
Add _get_yield_nodes_skip_functions()
  • Loading branch information
jacobtylerwalls committed Jun 22, 2023
1 parent 1a12169 commit 75cc881
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 1 deletion.
7 changes: 7 additions & 0 deletions astroid/nodes/_base_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,13 @@ def _get_return_nodes_skip_functions(self):
continue
yield from child_node._get_return_nodes_skip_functions()

def _get_yield_nodes_skip_functions(self):
for block in self._multi_line_blocks:
for child_node in block:
if child_node.is_function:
continue
yield from child_node._get_yield_nodes_skip_functions()

def _get_yield_nodes_skip_lambdas(self):
for block in self._multi_line_blocks:
for child_node in block:
Expand Down
25 changes: 25 additions & 0 deletions astroid/nodes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,9 @@ def get_children(self):
def _assign_nodes_in_scope(self) -> list[nodes.Assign]:
return [self, *self.value._assign_nodes_in_scope]

def _get_yield_nodes_skip_functions(self):
yield from self.value._get_yield_nodes_skip_functions()

def _get_yield_nodes_skip_lambdas(self):
yield from self.value._get_yield_nodes_skip_lambdas()

Expand Down Expand Up @@ -1203,6 +1206,11 @@ def get_children(self):
yield self.target
yield self.value

def _get_yield_nodes_skip_functions(self):
"""An AugAssign node can contain a Yield node in the value"""
yield from self.value._get_yield_nodes_skip_functions()
yield from super()._get_yield_nodes_skip_functions()

def _get_yield_nodes_skip_lambdas(self):
"""An AugAssign node can contain a Yield node in the value"""
yield from self.value._get_yield_nodes_skip_lambdas()
Expand Down Expand Up @@ -1984,6 +1992,10 @@ def postinit(self, value: NodeNG) -> None:
def get_children(self):
yield self.value

def _get_yield_nodes_skip_functions(self):
if not self.value.is_function:
yield from self.value._get_yield_nodes_skip_functions()

def _get_yield_nodes_skip_lambdas(self):
if not self.value.is_lambda:
yield from self.value._get_yield_nodes_skip_lambdas()
Expand Down Expand Up @@ -2432,6 +2444,11 @@ def get_children(self):
def has_elif_block(self):
return len(self.orelse) == 1 and isinstance(self.orelse[0], If)

def _get_yield_nodes_skip_functions(self):
"""An If node can contain a Yield node in the test"""
yield from self.test._get_yield_nodes_skip_functions()
yield from super()._get_yield_nodes_skip_functions()

def _get_yield_nodes_skip_lambdas(self):
"""An If node can contain a Yield node in the test"""
yield from self.test._get_yield_nodes_skip_lambdas()
Expand Down Expand Up @@ -3442,6 +3459,11 @@ def get_children(self):
yield from self.body
yield from self.orelse

def _get_yield_nodes_skip_functions(self):
"""A While node can contain a Yield node in the test"""
yield from self.test._get_yield_nodes_skip_functions()
yield from super()._get_yield_nodes_skip_functions()

def _get_yield_nodes_skip_lambdas(self):
"""A While node can contain a Yield node in the test"""
yield from self.test._get_yield_nodes_skip_lambdas()
Expand Down Expand Up @@ -3577,6 +3599,9 @@ def get_children(self):
if self.value is not None:
yield self.value

def _get_yield_nodes_skip_functions(self):
yield self

def _get_yield_nodes_skip_lambdas(self):
yield self

Expand Down
3 changes: 3 additions & 0 deletions astroid/nodes/node_ng.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,9 @@ def _get_name_nodes(self):
def _get_return_nodes_skip_functions(self):
yield from ()

def _get_yield_nodes_skip_functions(self):
yield from ()

def _get_yield_nodes_skip_lambdas(self):
yield from ()

Expand Down
2 changes: 1 addition & 1 deletion astroid/nodes/scoped_nodes/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,7 +1465,7 @@ def is_generator(self) -> bool:
:returns: Whether this is a generator function.
"""
return bool(next(self._get_yield_nodes_skip_lambdas(), False))
return bool(next(self._get_yield_nodes_skip_functions(), False))

def infer_yield_result(self, context: InferenceContext | None = None):
"""Infer what the function yields when called
Expand Down

0 comments on commit 75cc881

Please sign in to comment.