Skip to content

Commit

Permalink
[red-knot] Add control flow support for match statement
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruvmanila committed Sep 9, 2024
1 parent b04948f commit d483753
Show file tree
Hide file tree
Showing 8 changed files with 319 additions and 34 deletions.
15 changes: 15 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::Db;

pub mod ast_ids;
mod builder;
pub(crate) mod constraint;
pub mod definition;
pub mod expression;
pub mod symbol;
Expand Down Expand Up @@ -1222,4 +1223,18 @@ match 1:

assert!(matches!(definition.node(&db), DefinitionKind::For(_)));
}

#[test]
fn debug_test() {
let TestCase { db, file } = test_case(
"
x = 0
if x == 0:
y = 3
elif x == 1:
y = 4
",
);
semantic_index(&db, file);
}
}
57 changes: 53 additions & 4 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
use crate::semantic_index::SemanticIndex;
use crate::Db;

use super::constraint::{Constraint, PatternConstraint};
use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef};

pub(super) struct SemanticIndexBuilder<'db> {
Expand Down Expand Up @@ -204,13 +205,39 @@ impl<'db> SemanticIndexBuilder<'db> {
definition
}

fn add_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> {
fn add_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> {
let expression = self.add_standalone_expression(constraint_node);
self.current_use_def_map_mut().record_constraint(expression);
self.current_use_def_map_mut()
.record_constraint(Constraint::Expression(expression));

expression
}

fn add_pattern_constraint(
&mut self,
subject: &ast::Expr,
pattern: &ast::Pattern,
) -> PatternConstraint<'db> {
#[allow(unsafe_code)]
let (subject, pattern) = unsafe {
(
AstNodeRef::new(self.module.clone(), subject),
AstNodeRef::new(self.module.clone(), pattern),
)
};
let pattern_constraint = PatternConstraint::new(
self.db,
self.file,
self.current_scope(),
subject,
pattern,
countme::Count::default(),
);
self.current_use_def_map_mut()
.record_constraint(Constraint::Pattern(pattern_constraint));
pattern_constraint
}

/// Record an expression that needs to be a Salsa ingredient, because we need to infer its type
/// standalone (type narrowing tests, RHS of an assignment.)
fn add_standalone_expression(&mut self, expression_node: &ast::Expr) -> Expression<'db> {
Expand Down Expand Up @@ -523,7 +550,7 @@ where
ast::Stmt::If(node) => {
self.visit_expr(&node.test);
let pre_if = self.flow_snapshot();
self.add_constraint(&node.test);
self.add_expression_constraint(&node.test);
self.visit_body(&node.body);
let mut post_clauses: Vec<FlowSnapshot> = vec![];
for clause in &node.elif_else_clauses {
Expand Down Expand Up @@ -615,9 +642,31 @@ where
}) => {
self.add_standalone_expression(subject);
self.visit_expr(subject);
for case in cases {

let after_subject = self.flow_snapshot();
let Some((first, remaining)) = cases.split_first() else {
// TODO(dhruvmanila): In case of error recovery, we should not panic here
unreachable!("Match statement must have at least one case block");
};
self.add_pattern_constraint(subject, &first.pattern);
self.visit_match_case(first);

let mut post_case_snapshots = vec![];
for case in remaining {
post_case_snapshots.push(self.flow_snapshot());
self.flow_restore(after_subject.clone());
self.add_pattern_constraint(subject, &case.pattern);
self.visit_match_case(case);
}
for post_clause_state in post_case_snapshots {
self.flow_merge(post_clause_state);
}
if !cases
.last()
.is_some_and(|case| case.guard.is_none() && case.pattern.is_wildcard())
{
self.flow_merge(after_subject);
}
}
_ => {
walk_stmt(self, stmt);
Expand Down
39 changes: 39 additions & 0 deletions crates/red_knot_python_semantic/src/semantic_index/constraint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use ruff_db::files::File;
use ruff_python_ast as ast;

use crate::ast_node_ref::AstNodeRef;
use crate::db::Db;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{FileScopeId, ScopeId};

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum Constraint<'db> {
Expression(Expression<'db>),
Pattern(PatternConstraint<'db>),
}

#[salsa::tracked]
pub(crate) struct PatternConstraint<'db> {
#[id]
pub(crate) file: File,

#[id]
pub(crate) file_scope: FileScopeId,

#[no_eq]
#[return_ref]
pub(crate) subject: AstNodeRef<ast::Expr>,

#[no_eq]
#[return_ref]
pub(crate) pattern: AstNodeRef<ast::Pattern>,

#[no_eq]
count: countme::Count<PatternConstraint<'static>>,
}

impl<'db> PatternConstraint<'db> {
pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
}
}
19 changes: 10 additions & 9 deletions crates/red_knot_python_semantic/src/semantic_index/use_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,11 @@ use self::symbol_state::{
};
use crate::semantic_index::ast_ids::ScopedUseId;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::ScopedSymbolId;
use ruff_index::IndexVec;

use super::constraint::Constraint;

mod bitset;
mod symbol_state;

Expand All @@ -159,8 +160,8 @@ pub(crate) struct UseDefMap<'db> {
/// Array of [`Definition`] in this scope.
all_definitions: IndexVec<ScopedDefinitionId, Definition<'db>>,

/// Array of constraints (as [`Expression`]) in this scope.
all_constraints: IndexVec<ScopedConstraintId, Expression<'db>>,
/// Array of [`Constraint`] in this scope.
all_constraints: IndexVec<ScopedConstraintId, Constraint<'db>>,

/// [`SymbolState`] visible at a [`ScopedUseId`].
definitions_by_use: IndexVec<ScopedUseId, SymbolState>,
Expand Down Expand Up @@ -204,7 +205,7 @@ impl<'db> UseDefMap<'db> {
#[derive(Debug)]
pub(crate) struct DefinitionWithConstraintsIterator<'map, 'db> {
all_definitions: &'map IndexVec<ScopedDefinitionId, Definition<'db>>,
all_constraints: &'map IndexVec<ScopedConstraintId, Expression<'db>>,
all_constraints: &'map IndexVec<ScopedConstraintId, Constraint<'db>>,
inner: DefinitionIdWithConstraintsIterator<'map>,
}

Expand Down Expand Up @@ -232,12 +233,12 @@ pub(crate) struct DefinitionWithConstraints<'map, 'db> {
}

pub(crate) struct ConstraintsIterator<'map, 'db> {
all_constraints: &'map IndexVec<ScopedConstraintId, Expression<'db>>,
all_constraints: &'map IndexVec<ScopedConstraintId, Constraint<'db>>,
constraint_ids: ConstraintIdIterator<'map>,
}

impl<'map, 'db> Iterator for ConstraintsIterator<'map, 'db> {
type Item = Expression<'db>;
type Item = Constraint<'db>;

fn next(&mut self) -> Option<Self::Item> {
self.constraint_ids
Expand All @@ -259,8 +260,8 @@ pub(super) struct UseDefMapBuilder<'db> {
/// Append-only array of [`Definition`]; None is unbound.
all_definitions: IndexVec<ScopedDefinitionId, Definition<'db>>,

/// Append-only array of constraints (as [`Expression`]).
all_constraints: IndexVec<ScopedConstraintId, Expression<'db>>,
/// Append-only array of [`Constraint`].
all_constraints: IndexVec<ScopedConstraintId, Constraint<'db>>,

/// Visible definitions at each so-far-recorded use.
definitions_by_use: IndexVec<ScopedUseId, SymbolState>,
Expand Down Expand Up @@ -290,7 +291,7 @@ impl<'db> UseDefMapBuilder<'db> {
self.definitions_by_symbol[symbol] = SymbolState::with(def_id);
}

pub(super) fn record_constraint(&mut self, constraint: Expression<'db>) {
pub(super) fn record_constraint(&mut self, constraint: Constraint<'db>) {
let constraint_id = self.all_constraints.push(constraint);
for definitions in &mut self.definitions_by_symbol {
definitions.add_constraint(constraint_id);
Expand Down
4 changes: 2 additions & 2 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ pub(crate) fn definitions_ty<'db>(
definition,
constraints,
}| {
let mut constraint_tys =
constraints.filter_map(|test| narrowing_constraint(db, test, definition));
let mut constraint_tys = constraints
.filter_map(|constraint| narrowing_constraint(db, constraint, definition));
let definition_ty = definition_ty(db, definition);
if let Some(first_constraint_ty) = constraint_tys.next() {
let mut builder = IntersectionBuilder::new(db);
Expand Down
86 changes: 86 additions & 0 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3483,6 +3483,65 @@ mod tests {
Ok(())
}

#[test]
fn match_with_wildcard() {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
match 0:
case 1:
y = 2
case _:
y = 3
",
)
.unwrap();

assert_public_ty(&db, "src/a.py", "y", "Literal[2, 3]");
}

#[test]
fn match_without_wildcard() {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
match 0:
case 1:
y = 2
case 2:
y = 3
",
)
.unwrap();

assert_public_ty(&db, "src/a.py", "y", "Unbound | Literal[2, 3]");
}

#[test]
fn match_stmt() {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
y = 1
y = 2
match 0:
case 1:
y = 3
case 2:
y = 4
",
)
.unwrap();

assert_public_ty(&db, "src/a.py", "y", "Literal[2, 3, 4]");
}

#[test]
fn import_cycle() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down Expand Up @@ -3797,6 +3856,33 @@ mod tests {
Ok(())
}

#[test]
fn narrow_singleton_pattern() {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
x = None if flag else 1
y = 0
match x:
case None:
y = x
",
)
.unwrap();

// TODO: The correct inferred type should be `Literal[0] | None` but currently the
// simplification logic doesn't account for this. The final type with parenthesis:
// `Literal[0] | (None | Literal[1] & None)`
assert_public_ty(
&db,
"/src/a.py",
"y",
"Literal[0] | None | Literal[1] & None",
);
}

#[test]
fn while_loop() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down
Loading

0 comments on commit d483753

Please sign in to comment.