diff --git a/Cargo.lock b/Cargo.lock index 018df705ac5ab..946cf29a1fc3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1907,6 +1907,7 @@ dependencies = [ "ruff_text_size", "rustc-hash 2.0.0", "salsa", + "textwrap", "tracing", ] @@ -2844,6 +2845,12 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "smawk" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c" + [[package]] name = "spin" version = "0.9.8" @@ -2995,6 +3002,17 @@ dependencies = [ "test-case-core", ] +[[package]] +name = "textwrap" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" +dependencies = [ + "smawk", + "unicode-linebreak", + "unicode-width", +] + [[package]] name = "thiserror" version = "1.0.62" @@ -3250,6 +3268,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-linebreak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" + [[package]] name = "unicode-normalization" version = "0.1.23" diff --git a/Cargo.toml b/Cargo.toml index 3604eb82493c4..7b96cd4101ca5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -127,6 +127,7 @@ strum_macros = { version = "0.26.0" } syn = { version = "2.0.55" } tempfile = { version = "3.9.0" } test-case = { version = "3.3.1" } +textwrap = { version = "0.16.1" } thiserror = { version = "1.0.58" } tikv-jemallocator = { version = "0.6.0" } toml = { version = "0.8.11" } diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index 7abe9b7b1bd53..5f70c032091a4 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -130,11 +130,7 @@ fn lint_bad_override(context: &SemanticLintContext, class: &ast::StmtClassDef) { return; }; - let Some(typing_override) = semantic.public_symbol(&typing, "override") else { - return; - }; - - let override_ty = semantic.public_symbol_ty(typing_override); + let override_ty = semantic.module_global_symbol_ty(&typing, "override"); let Type::Class(class_ty) = class.ty(semantic) else { return; @@ -154,7 +150,10 @@ fn lint_bad_override(context: &SemanticLintContext, class: &ast::StmtClassDef) { if ty.has_decorator(db, override_ty) { let method_name = ty.name(db); - if class_ty.inherited_class_member(db, &method_name).is_none() { + if class_ty + .inherited_class_member(db, &method_name) + .is_unbound() + { // TODO should have a qualname() method to support nested classes context.push_diagnostic( format!( diff --git a/crates/red_knot_python_semantic/Cargo.toml b/crates/red_knot_python_semantic/Cargo.toml index b314905d7aa64..07a2ed32208c7 100644 --- a/crates/red_knot_python_semantic/Cargo.toml +++ b/crates/red_knot_python_semantic/Cargo.toml @@ -27,6 +27,7 @@ hashbrown = { workspace = true } [dev-dependencies] anyhow = { workspace = true } ruff_python_parser = { workspace = true } +textwrap = { workspace = true } [lints] workspace = true diff --git a/crates/red_knot_python_semantic/src/ast_node_ref.rs b/crates/red_knot_python_semantic/src/ast_node_ref.rs index 118a1918a3634..94f7d5d268563 100644 --- a/crates/red_knot_python_semantic/src/ast_node_ref.rs +++ b/crates/red_knot_python_semantic/src/ast_node_ref.rs @@ -27,12 +27,13 @@ pub struct AstNodeRef { #[allow(unsafe_code)] impl AstNodeRef { - /// Creates a new `AstNodeRef` that reference `node`. The `parsed` is the [`ParsedModule`] to which - /// the `AstNodeRef` belongs. + /// Creates a new `AstNodeRef` that reference `node`. The `parsed` is the [`ParsedModule`] to + /// which the `AstNodeRef` belongs. /// /// ## Safety - /// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the [`ParsedModule`] to - /// which `node` belongs. It's the caller's responsibility to ensure that the invariant `node belongs to parsed` is upheld. + /// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the + /// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that + /// the invariant `node belongs to parsed` is upheld. pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self { Self { @@ -43,8 +44,8 @@ impl AstNodeRef { /// Returns a reference to the wrapped node. pub fn node(&self) -> &T { - // SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still alive - // and not moved. + // SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still + // alive and not moved. unsafe { self.node.as_ref() } } } diff --git a/crates/red_knot_python_semantic/src/db.rs b/crates/red_knot_python_semantic/src/db.rs index 5d375ad86f56c..6f901986017dd 100644 --- a/crates/red_knot_python_semantic/src/db.rs +++ b/crates/red_knot_python_semantic/src/db.rs @@ -4,27 +4,30 @@ use red_knot_module_resolver::Db as ResolverDb; use ruff_db::{Db as SourceDb, Upcast}; use crate::semantic_index::definition::Definition; -use crate::semantic_index::symbol::{public_symbols_map, PublicSymbolId, ScopeId}; -use crate::semantic_index::{root_scope, semantic_index, symbol_table}; +use crate::semantic_index::expression::Expression; +use crate::semantic_index::symbol::ScopeId; +use crate::semantic_index::{module_global_scope, semantic_index, symbol_table, use_def_map}; use crate::types::{ - infer_types, public_symbol_ty, ClassType, FunctionType, IntersectionType, UnionType, + infer_definition_types, infer_expression_types, infer_scope_types, ClassType, FunctionType, + IntersectionType, UnionType, }; #[salsa::jar(db=Db)] pub struct Jar( ScopeId<'_>, - PublicSymbolId<'_>, Definition<'_>, + Expression<'_>, FunctionType<'_>, ClassType<'_>, UnionType<'_>, IntersectionType<'_>, symbol_table, - root_scope, + use_def_map, + module_global_scope, semantic_index, - infer_types, - public_symbol_ty, - public_symbols_map, + infer_definition_types, + infer_expression_types, + infer_scope_types, ); /// Database giving access to semantic information about a Python program. diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 354b5d382527d..25476d6055c21 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -10,17 +10,20 @@ use ruff_index::{IndexSlice, IndexVec}; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::ast_ids::AstIds; use crate::semantic_index::builder::SemanticIndexBuilder; -use crate::semantic_index::definition::{Definition, DefinitionNodeKey, DefinitionNodeRef}; +use crate::semantic_index::definition::{Definition, DefinitionNodeKey}; +use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ - FileScopeId, NodeWithScopeKey, NodeWithScopeRef, PublicSymbolId, Scope, ScopeId, - ScopedSymbolId, SymbolTable, + FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolTable, }; +use crate::semantic_index::usedef::UseDefMap; use crate::Db; pub mod ast_ids; mod builder; pub mod definition; +pub mod expression; pub mod symbol; +pub mod usedef; type SymbolMap = hashbrown::HashMap; @@ -42,57 +45,63 @@ pub(crate) fn semantic_index(db: &dyn Db, file: File) -> SemanticIndex<'_> { /// Salsa can avoid invalidating dependent queries if this scope's symbol table /// is unchanged. #[salsa::tracked] -pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc> { +pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc { let _span = tracing::trace_span!("symbol_table", ?scope).entered(); let index = semantic_index(db, scope.file(db)); index.symbol_table(scope.file_scope_id(db)) } -/// Returns the root scope of `file`. +/// Returns the use-def map for a specific `scope`. +/// +/// Using [`use_def_map`] over [`semantic_index`] has the advantage that +/// Salsa can avoid invalidating dependent queries if this scope's use-def map +/// is unchanged. #[salsa::tracked] -pub(crate) fn root_scope(db: &dyn Db, file: File) -> ScopeId<'_> { - let _span = tracing::trace_span!("root_scope", ?file).entered(); +pub(crate) fn use_def_map<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc> { + let _span = tracing::trace_span!("use_def_map", ?scope).entered(); + let index = semantic_index(db, scope.file(db)); - FileScopeId::root().to_scope_id(db, file) + index.use_def_map(scope.file_scope_id(db)) } -/// Returns the symbol with the given name in `file`'s public scope or `None` if -/// no symbol with the given name exists. -pub(crate) fn public_symbol<'db>( - db: &'db dyn Db, - file: File, - name: &str, -) -> Option> { - let root_scope = root_scope(db, file); - let symbol_table = symbol_table(db, root_scope); - let local = symbol_table.symbol_id_by_name(name)?; - Some(local.to_public_symbol(db, file)) +/// Returns the module global scope of `file`. +#[salsa::tracked] +pub(crate) fn module_global_scope(db: &dyn Db, file: File) -> ScopeId<'_> { + let _span = tracing::trace_span!("module_global_scope", ?file).entered(); + + FileScopeId::module_global().to_scope_id(db, file) } -/// The symbol tables for an entire file. +/// The symbol tables and use-def maps for all scopes in a file. #[derive(Debug)] pub(crate) struct SemanticIndex<'db> { /// List of all symbol tables in this file, indexed by scope. - symbol_tables: IndexVec>>, + symbol_tables: IndexVec>, /// List of all scopes in this file. scopes: IndexVec, - /// Maps expressions to their corresponding scope. + /// Map expressions to their corresponding scope. /// We can't use [`ExpressionId`] here, because the challenge is how to get from /// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope). scopes_by_expression: FxHashMap, - /// Maps from a node creating a definition node to its definition. + /// Map from a node creating a definition to its definition. definitions_by_node: FxHashMap>, + /// Map from a standalone expression to its [`Expression`] ingredient. + expressions_by_node: FxHashMap>, + /// Map from nodes that create a scope to the scope they create. scopes_by_node: FxHashMap, /// Map from the file-local [`FileScopeId`] to the salsa-ingredient [`ScopeId`]. scope_ids_by_scope: IndexVec>, + /// Use-def map for each scope in this file. + use_def_maps: IndexVec>>, + /// Lookup table to map between node ids and ast nodes. /// /// Note: We should not depend on this map when analysing other files or @@ -105,10 +114,18 @@ impl<'db> SemanticIndex<'db> { /// /// Use the Salsa cached [`symbol_table`] query if you only need the /// symbol table for a single scope. - pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc> { + pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc { self.symbol_tables[scope_id].clone() } + /// Returns the use-def map for a specific scope. + /// + /// Use the Salsa cached [`use_def_map`] query if you only need the + /// use-def map for a single scope. + pub(super) fn use_def_map(&self, scope_id: FileScopeId) -> Arc { + self.use_def_maps[scope_id].clone() + } + pub(crate) fn ast_ids(&self, scope_id: FileScopeId) -> &AstIds { &self.ast_ids[scope_id] } @@ -157,16 +174,25 @@ impl<'db> SemanticIndex<'db> { } /// Returns an iterator over all ancestors of `scope`, starting with `scope` itself. + #[allow(unused)] pub(crate) fn ancestor_scopes(&self, scope: FileScopeId) -> AncestorsIter { AncestorsIter::new(self, scope) } - /// Returns the [`Definition`] salsa ingredient for `definition_node`. - pub(crate) fn definition<'def>( + /// Returns the [`Definition`] salsa ingredient for `definition_key`. + pub(crate) fn definition( &self, - definition_node: impl Into>, + definition_key: impl Into, ) -> Definition<'db> { - self.definitions_by_node[&definition_node.into().key()] + self.definitions_by_node[&definition_key.into()] + } + + /// Returns the [`Expression`] ingredient for an expression node. + pub(crate) fn expression( + &self, + expression_key: impl Into, + ) -> Expression<'db> { + self.expressions_by_node[&expression_key.into()] } /// Returns the id of the scope that `node` creates. This is different from [`Definition::scope`] which @@ -176,8 +202,6 @@ impl<'db> SemanticIndex<'db> { } } -/// ID that uniquely identifies an expression inside a [`Scope`]. - pub struct AncestorsIter<'a> { scopes: &'a IndexSlice, next_id: Option, @@ -278,7 +302,7 @@ mod tests { use crate::db::tests::TestDb; use crate::semantic_index::symbol::{FileScopeId, Scope, ScopeKind, SymbolTable}; - use crate::semantic_index::{root_scope, semantic_index, symbol_table}; + use crate::semantic_index::{module_global_scope, semantic_index, symbol_table, use_def_map}; use crate::Db; struct TestCase { @@ -305,95 +329,110 @@ mod tests { #[test] fn empty() { let TestCase { db, file } = test_case(""); - let root_table = symbol_table(&db, root_scope(&db, file)); + let module_global_table = symbol_table(&db, module_global_scope(&db, file)); - let root_names = names(&root_table); + let module_global_names = names(&module_global_table); - assert_eq!(root_names, Vec::<&str>::new()); + assert_eq!(module_global_names, Vec::<&str>::new()); } #[test] fn simple() { let TestCase { db, file } = test_case("x"); - let root_table = symbol_table(&db, root_scope(&db, file)); + let module_global_table = symbol_table(&db, module_global_scope(&db, file)); - assert_eq!(names(&root_table), vec!["x"]); + assert_eq!(names(&module_global_table), vec!["x"]); } #[test] fn annotation_only() { let TestCase { db, file } = test_case("x: int"); - let root_table = symbol_table(&db, root_scope(&db, file)); + let module_global_table = symbol_table(&db, module_global_scope(&db, file)); - assert_eq!(names(&root_table), vec!["int", "x"]); + assert_eq!(names(&module_global_table), vec!["int", "x"]); // TODO record definition } #[test] fn import() { let TestCase { db, file } = test_case("import foo"); - let root_table = symbol_table(&db, root_scope(&db, file)); + let scope = module_global_scope(&db, file); + let module_global_table = symbol_table(&db, scope); - assert_eq!(names(&root_table), vec!["foo"]); - let foo = root_table.symbol_by_name("foo").unwrap(); + assert_eq!(names(&module_global_table), vec!["foo"]); + let foo = module_global_table.symbol_id_by_name("foo").unwrap(); - assert_eq!(foo.definitions().len(), 1); + let use_def = use_def_map(&db, scope); + assert_eq!(use_def.public_definitions(foo).len(), 1); } #[test] fn import_sub() { let TestCase { db, file } = test_case("import foo.bar"); - let root_table = symbol_table(&db, root_scope(&db, file)); + let module_global_table = symbol_table(&db, module_global_scope(&db, file)); - assert_eq!(names(&root_table), vec!["foo"]); + assert_eq!(names(&module_global_table), vec!["foo"]); } #[test] fn import_as() { let TestCase { db, file } = test_case("import foo.bar as baz"); - let root_table = symbol_table(&db, root_scope(&db, file)); + let module_global_table = symbol_table(&db, module_global_scope(&db, file)); - assert_eq!(names(&root_table), vec!["baz"]); + assert_eq!(names(&module_global_table), vec!["baz"]); } #[test] fn import_from() { let TestCase { db, file } = test_case("from bar import foo"); - let root_table = symbol_table(&db, root_scope(&db, file)); + let scope = module_global_scope(&db, file); + let module_global_table = symbol_table(&db, scope); - assert_eq!(names(&root_table), vec!["foo"]); - assert_eq!( - root_table - .symbol_by_name("foo") - .unwrap() - .definitions() - .len(), - 1 - ); + assert_eq!(names(&module_global_table), vec!["foo"]); assert!( - root_table + module_global_table .symbol_by_name("foo") - .is_some_and(|symbol| { symbol.is_defined() || !symbol.is_used() }), + .is_some_and(|symbol| { symbol.is_defined() && !symbol.is_used() }), "symbols that are defined get the defined flag" ); + + let use_def = use_def_map(&db, scope); + assert_eq!( + use_def + .public_definitions( + module_global_table + .symbol_id_by_name("foo") + .expect("symbol exists") + ) + .len(), + 1 + ); } #[test] fn assign() { let TestCase { db, file } = test_case("x = foo"); - let root_table = symbol_table(&db, root_scope(&db, file)); + let scope = module_global_scope(&db, file); + let module_global_table = symbol_table(&db, scope); - assert_eq!(names(&root_table), vec!["foo", "x"]); - assert_eq!( - root_table.symbol_by_name("x").unwrap().definitions().len(), - 1 - ); + assert_eq!(names(&module_global_table), vec!["foo", "x"]); assert!( - root_table + module_global_table .symbol_by_name("foo") .is_some_and(|symbol| { !symbol.is_defined() && symbol.is_used() }), "a symbol used but not defined in a scope should have only the used flag" ); + let use_def = use_def_map(&db, scope); + assert_eq!( + use_def + .public_definitions( + module_global_table + .symbol_id_by_name("x") + .expect("symbol exists") + ) + .len(), + 1 + ); } #[test] @@ -405,13 +444,13 @@ class C: y = 2 ", ); - let root_table = symbol_table(&db, root_scope(&db, file)); + let module_global_table = symbol_table(&db, module_global_scope(&db, file)); - assert_eq!(names(&root_table), vec!["C", "y"]); + assert_eq!(names(&module_global_table), vec!["C", "y"]); let index = semantic_index(&db, file); - let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect(); + let scopes: Vec<_> = index.child_scopes(FileScopeId::module_global()).collect(); assert_eq!(scopes.len(), 1); let (class_scope_id, class_scope) = scopes[0]; @@ -421,8 +460,12 @@ y = 2 let class_table = index.symbol_table(class_scope_id); assert_eq!(names(&class_table), vec!["x"]); + + let use_def = index.use_def_map(class_scope_id); assert_eq!( - class_table.symbol_by_name("x").unwrap().definitions().len(), + use_def + .public_definitions(class_table.symbol_id_by_name("x").expect("symbol exists")) + .len(), 1 ); } @@ -437,11 +480,13 @@ y = 2 ", ); let index = semantic_index(&db, file); - let root_table = index.symbol_table(FileScopeId::root()); + let module_global_table = index.symbol_table(FileScopeId::module_global()); - assert_eq!(names(&root_table), vec!["func", "y"]); + assert_eq!(names(&module_global_table), vec!["func", "y"]); - let scopes = index.child_scopes(FileScopeId::root()).collect::>(); + let scopes = index + .child_scopes(FileScopeId::module_global()) + .collect::>(); assert_eq!(scopes.len(), 1); let (function_scope_id, function_scope) = scopes[0]; @@ -450,11 +495,15 @@ y = 2 let function_table = index.symbol_table(function_scope_id); assert_eq!(names(&function_table), vec!["x"]); + + let use_def = index.use_def_map(function_scope_id); assert_eq!( - function_table - .symbol_by_name("x") - .unwrap() - .definitions() + use_def + .public_definitions( + function_table + .symbol_id_by_name("x") + .expect("symbol exists") + ) .len(), 1 ); @@ -471,10 +520,10 @@ def func(): ", ); let index = semantic_index(&db, file); - let root_table = index.symbol_table(FileScopeId::root()); + let module_global_table = index.symbol_table(FileScopeId::module_global()); - assert_eq!(names(&root_table), vec!["func"]); - let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect(); + assert_eq!(names(&module_global_table), vec!["func"]); + let scopes: Vec<_> = index.child_scopes(FileScopeId::module_global()).collect(); assert_eq!(scopes.len(), 2); let (func_scope1_id, func_scope_1) = scopes[0]; @@ -490,13 +539,17 @@ def func(): let func2_table = index.symbol_table(func_scope2_id); assert_eq!(names(&func1_table), vec!["x"]); assert_eq!(names(&func2_table), vec!["y"]); + + let use_def = index.use_def_map(FileScopeId::module_global()); assert_eq!( - root_table - .symbol_by_name("func") - .unwrap() - .definitions() + use_def + .public_definitions( + module_global_table + .symbol_id_by_name("func") + .expect("symbol exists") + ) .len(), - 2 + 1 ); } @@ -510,11 +563,11 @@ def func[T](): ); let index = semantic_index(&db, file); - let root_table = index.symbol_table(FileScopeId::root()); + let module_global_table = index.symbol_table(FileScopeId::module_global()); - assert_eq!(names(&root_table), vec!["func"]); + assert_eq!(names(&module_global_table), vec!["func"]); - let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect(); + let scopes: Vec<_> = index.child_scopes(FileScopeId::module_global()).collect(); assert_eq!(scopes.len(), 1); let (ann_scope_id, ann_scope) = scopes[0]; @@ -542,11 +595,11 @@ class C[T]: ); let index = semantic_index(&db, file); - let root_table = index.symbol_table(FileScopeId::root()); + let module_global_table = index.symbol_table(FileScopeId::module_global()); - assert_eq!(names(&root_table), vec!["C"]); + assert_eq!(names(&module_global_table), vec!["C"]); - let scopes: Vec<_> = index.child_scopes(FileScopeId::root()).collect(); + let scopes: Vec<_> = index.child_scopes(FileScopeId::module_global()).collect(); assert_eq!(scopes.len(), 1); let (ann_scope_id, ann_scope) = scopes[0]; @@ -578,7 +631,7 @@ class C[T]: // let index = SemanticIndex::from_ast(ast); // let table = &index.symbol_table; // let x_sym = table - // .root_symbol_id_by_name("x") + // .module_global_symbol_id_by_name("x") // .expect("x symbol should exist"); // let ast::Stmt::Expr(ast::StmtExpr { value: x_use, .. }) = &ast.body[1] else { // panic!("should be an expr") @@ -616,7 +669,7 @@ class C[T]: let x = &x_stmt.targets[0]; assert_eq!(index.expression_scope(x).kind(), ScopeKind::Module); - assert_eq!(index.expression_scope_id(x), FileScopeId::root()); + assert_eq!(index.expression_scope_id(x), FileScopeId::module_global()); let def = ast.body[1].as_function_def_stmt().unwrap(); let y_stmt = def.body[0].as_assign_stmt().unwrap(); @@ -653,16 +706,20 @@ def x(): let index = semantic_index(&db, file); - let descendents = index.descendent_scopes(FileScopeId::root()); + let descendents = index.descendent_scopes(FileScopeId::module_global()); assert_eq!( scope_names(descendents, &db, file), vec!["Test", "foo", "bar", "baz", "x"] ); - let children = index.child_scopes(FileScopeId::root()); + let children = index.child_scopes(FileScopeId::module_global()); assert_eq!(scope_names(children, &db, file), vec!["Test", "x"]); - let test_class = index.child_scopes(FileScopeId::root()).next().unwrap().0; + let test_class = index + .child_scopes(FileScopeId::module_global()) + .next() + .unwrap() + .0; let test_child_scopes = index.child_scopes(test_class); assert_eq!( scope_names(test_child_scopes, &db, file), @@ -670,7 +727,7 @@ def x(): ); let bar_scope = index - .descendent_scopes(FileScopeId::root()) + .descendent_scopes(FileScopeId::module_global()) .nth(2) .unwrap() .0; diff --git a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs index 86f17216b8650..1aa0a869f716a 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/ast_ids.rs @@ -1,6 +1,6 @@ use rustc_hash::FxHashMap; -use ruff_index::{newtype_index, Idx}; +use ruff_index::newtype_index; use ruff_python_ast as ast; use ruff_python_ast::ExpressionRef; @@ -28,18 +28,54 @@ use crate::Db; pub(crate) struct AstIds { /// Maps expressions to their expression id. Uses `NodeKey` because it avoids cloning [`Parsed`]. expressions_map: FxHashMap, + /// Maps expressions which "use" a symbol (that is, [`ExprName`]) to a use id. + uses_map: FxHashMap, } impl AstIds { fn expression_id(&self, key: impl Into) -> ScopedExpressionId { self.expressions_map[&key.into()] } + + fn use_id(&self, key: impl Into) -> ScopedUseId { + self.uses_map[&key.into()] + } } fn ast_ids<'db>(db: &'db dyn Db, scope: ScopeId) -> &'db AstIds { semantic_index(db, scope.file(db)).ast_ids(scope.file_scope_id(db)) } +pub trait HasScopedUseId { + /// The type of the ID uniquely identifying the use. + type Id: Copy; + + /// Returns the ID that uniquely identifies the use in `scope`. + fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id; +} + +/// Uniquely identifies a use of a name in a [`crate::semantic_index::symbol::FileScopeId`]. +#[newtype_index] +pub struct ScopedUseId; + +impl HasScopedUseId for ast::ExprName { + type Id = ScopedUseId; + + fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id { + let expression_ref = ExpressionRef::from(self); + expression_ref.scoped_use_id(db, scope) + } +} + +impl HasScopedUseId for ast::ExpressionRef<'_> { + type Id = ScopedUseId; + + fn scoped_use_id(&self, db: &dyn Db, scope: ScopeId) -> Self::Id { + let ast_ids = ast_ids(db, scope); + ast_ids.use_id(*self) + } +} + pub trait HasScopedAstId { /// The type of the ID uniquely identifying the node. type Id: Copy; @@ -110,38 +146,43 @@ impl HasScopedAstId for ast::ExpressionRef<'_> { #[derive(Debug)] pub(super) struct AstIdsBuilder { - next_id: ScopedExpressionId, expressions_map: FxHashMap, + uses_map: FxHashMap, } impl AstIdsBuilder { pub(super) fn new() -> Self { Self { - next_id: ScopedExpressionId::new(0), expressions_map: FxHashMap::default(), + uses_map: FxHashMap::default(), } } - /// Adds `expr` to the AST ids map and returns its id. - /// - /// ## Safety - /// The function is marked as unsafe because it calls [`AstNodeRef::new`] which requires - /// that `expr` is a child of `parsed`. - #[allow(unsafe_code)] + /// Adds `expr` to the expression ids map and returns its id. pub(super) fn record_expression(&mut self, expr: &ast::Expr) -> ScopedExpressionId { - let expression_id = self.next_id; - self.next_id = expression_id + 1; + let expression_id = self.expressions_map.len().into(); self.expressions_map.insert(expr.into(), expression_id); expression_id } + /// Adds `expr` to the use ids map and returns its id. + pub(super) fn record_use(&mut self, expr: &ast::Expr) -> ScopedUseId { + let use_id = self.uses_map.len().into(); + + self.uses_map.insert(expr.into(), use_id); + + use_id + } + pub(super) fn finish(mut self) -> AstIds { self.expressions_map.shrink_to_fit(); + self.uses_map.shrink_to_fit(); AstIds { expressions_map: self.expressions_map, + uses_map: self.uses_map, } } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index e492098a7ee2d..4bc4321d1723b 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -9,55 +9,62 @@ use ruff_python_ast as ast; use ruff_python_ast::name::Name; use ruff_python_ast::visitor::{walk_expr, walk_stmt, Visitor}; +use crate::ast_node_ref::AstNodeRef; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::ast_ids::AstIdsBuilder; -use crate::semantic_index::definition::{Definition, DefinitionNodeKey, DefinitionNodeRef}; +use crate::semantic_index::definition::{ + AssignmentDefinitionNodeRef, Definition, DefinitionNodeKey, DefinitionNodeRef, + ImportFromDefinitionNodeRef, +}; +use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolFlags, SymbolTableBuilder, }; +use crate::semantic_index::usedef::{FlowSnapshot, UseDefMapBuilder}; use crate::semantic_index::SemanticIndex; use crate::Db; -pub(super) struct SemanticIndexBuilder<'db, 'ast> { +pub(super) struct SemanticIndexBuilder<'db> { // Builder state db: &'db dyn Db, file: File, module: &'db ParsedModule, scope_stack: Vec, - /// the target we're currently inferring - current_target: Option>, + /// the assignment we're currently visiting + current_assignment: Option>, // Semantic Index fields scopes: IndexVec, scope_ids_by_scope: IndexVec>, - symbol_tables: IndexVec>, + symbol_tables: IndexVec, ast_ids: IndexVec, + use_def_maps: IndexVec>, scopes_by_node: FxHashMap, scopes_by_expression: FxHashMap, definitions_by_node: FxHashMap>, + expressions_by_node: FxHashMap>, } -impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> -where - 'db: 'ast, -{ +impl<'db> SemanticIndexBuilder<'db> { pub(super) fn new(db: &'db dyn Db, file: File, parsed: &'db ParsedModule) -> Self { let mut builder = Self { db, file, module: parsed, scope_stack: Vec::new(), - current_target: None, + current_assignment: None, scopes: IndexVec::new(), symbol_tables: IndexVec::new(), ast_ids: IndexVec::new(), scope_ids_by_scope: IndexVec::new(), + use_def_maps: IndexVec::new(), scopes_by_expression: FxHashMap::default(), scopes_by_node: FxHashMap::default(), definitions_by_node: FxHashMap::default(), + expressions_by_node: FxHashMap::default(), }; builder.push_scope_with_parent(NodeWithScopeRef::Module, None); @@ -72,16 +79,12 @@ where .expect("Always to have a root scope") } - fn push_scope(&mut self, node: NodeWithScopeRef<'ast>) { + fn push_scope(&mut self, node: NodeWithScopeRef) { let parent = self.current_scope(); self.push_scope_with_parent(node, Some(parent)); } - fn push_scope_with_parent( - &mut self, - node: NodeWithScopeRef<'ast>, - parent: Option, - ) { + fn push_scope_with_parent(&mut self, node: NodeWithScopeRef, parent: Option) { let children_start = self.scopes.next_index() + 1; let scope = Scope { @@ -92,6 +95,7 @@ where let file_scope_id = self.scopes.push(scope); self.symbol_tables.push(SymbolTableBuilder::new()); + self.use_def_maps.push(UseDefMapBuilder::new()); let ast_id_scope = self.ast_ids.push(AstIdsBuilder::new()); #[allow(unsafe_code)] @@ -116,32 +120,54 @@ where id } - fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder<'db> { + fn current_symbol_table(&mut self) -> &mut SymbolTableBuilder { let scope_id = self.current_scope(); &mut self.symbol_tables[scope_id] } + fn current_use_def_map(&mut self) -> &mut UseDefMapBuilder<'db> { + let scope_id = self.current_scope(); + &mut self.use_def_maps[scope_id] + } + fn current_ast_ids(&mut self) -> &mut AstIdsBuilder { let scope_id = self.current_scope(); &mut self.ast_ids[scope_id] } + fn flow_snapshot(&mut self) -> FlowSnapshot { + self.current_use_def_map().snapshot() + } + + fn flow_set(&mut self, state: &FlowSnapshot) { + self.current_use_def_map().set(state); + } + + fn flow_merge(&mut self, state: &FlowSnapshot) { + self.current_use_def_map().merge(state); + } + fn add_or_update_symbol(&mut self, name: Name, flags: SymbolFlags) -> ScopedSymbolId { let symbol_table = self.current_symbol_table(); - symbol_table.add_or_update_symbol(name, flags) + let (symbol_id, added) = symbol_table.add_or_update_symbol(name, flags); + if added { + let use_def_map = self.current_use_def_map(); + use_def_map.add_symbol(symbol_id); + } + symbol_id } - fn add_definition( + fn add_definition<'a>( &mut self, - definition_node: impl Into>, - symbol_id: ScopedSymbolId, + symbol: ScopedSymbolId, + definition_node: impl Into>, ) -> Definition<'db> { let definition_node = definition_node.into(); let definition = Definition::new( self.db, self.file, self.current_scope(), - symbol_id, + symbol, #[allow(unsafe_code)] unsafe { definition_node.into_owned(self.module.clone()) @@ -150,26 +176,30 @@ where self.definitions_by_node .insert(definition_node.key(), definition); + self.current_use_def_map().record_def(symbol, definition); definition } - fn add_or_update_symbol_with_definition( - &mut self, - name: Name, - definition: impl Into>, - ) -> (ScopedSymbolId, Definition<'db>) { - let symbol_table = self.current_symbol_table(); - - let id = symbol_table.add_or_update_symbol(name, SymbolFlags::IS_DEFINED); - let definition = self.add_definition(definition, id); - self.current_symbol_table().add_definition(id, definition); - (id, definition) + /// Record an expression that needs to be a Salsa ingredient, because we need to infer its type + /// standalone (type narrowing tests, RHS of an assignment.) + fn add_standalone_expression(&mut self, expression_node: &ast::Expr) { + let expression = Expression::new( + self.db, + self.file, + self.current_scope(), + #[allow(unsafe_code)] + unsafe { + AstNodeRef::new(self.module.clone(), expression_node) + }, + ); + self.expressions_by_node + .insert(expression_node.into(), expression); } fn with_type_params( &mut self, - with_params: &WithTypeParams<'ast>, + with_params: &WithTypeParams, nested: impl FnOnce(&mut Self) -> FileScopeId, ) -> FileScopeId { let type_params = with_params.type_parameters(); @@ -213,7 +243,7 @@ where self.pop_scope(); assert!(self.scope_stack.is_empty()); - assert!(self.current_target.is_none()); + assert!(self.current_assignment.is_none()); let mut symbol_tables: IndexVec<_, _> = self .symbol_tables @@ -221,6 +251,12 @@ where .map(|builder| Arc::new(builder.finish())) .collect(); + let mut use_def_maps: IndexVec<_, _> = self + .use_def_maps + .into_iter() + .map(|builder| Arc::new(builder.finish())) + .collect(); + let mut ast_ids: IndexVec<_, _> = self .ast_ids .into_iter() @@ -228,8 +264,9 @@ where .collect(); self.scopes.shrink_to_fit(); - ast_ids.shrink_to_fit(); symbol_tables.shrink_to_fit(); + use_def_maps.shrink_to_fit(); + ast_ids.shrink_to_fit(); self.scopes_by_expression.shrink_to_fit(); self.definitions_by_node.shrink_to_fit(); @@ -240,17 +277,19 @@ where symbol_tables, scopes: self.scopes, definitions_by_node: self.definitions_by_node, + expressions_by_node: self.expressions_by_node, scope_ids_by_scope: self.scope_ids_by_scope, ast_ids, scopes_by_expression: self.scopes_by_expression, scopes_by_node: self.scopes_by_node, + use_def_maps, } } } -impl<'db, 'ast> Visitor<'ast> for SemanticIndexBuilder<'db, 'ast> +impl<'db, 'ast> Visitor<'ast> for SemanticIndexBuilder<'db> where - 'db: 'ast, + 'ast: 'db, { fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) { match stmt { @@ -259,10 +298,9 @@ where self.visit_decorator(decorator); } - self.add_or_update_symbol_with_definition( - function_def.name.id.clone(), - function_def, - ); + let symbol = self + .add_or_update_symbol(function_def.name.id.clone(), SymbolFlags::IS_DEFINED); + self.add_definition(symbol, function_def); self.with_type_params( &WithTypeParams::FunctionDef { node: function_def }, @@ -283,7 +321,9 @@ where self.visit_decorator(decorator); } - self.add_or_update_symbol_with_definition(class.name.id.clone(), class); + let symbol = + self.add_or_update_symbol(class.name.id.clone(), SymbolFlags::IS_DEFINED); + self.add_definition(symbol, class); self.with_type_params(&WithTypeParams::ClassDef { node: class }, |builder| { if let Some(arguments) = &class.arguments { @@ -296,41 +336,84 @@ where builder.pop_scope() }); } - ast::Stmt::Import(ast::StmtImport { names, .. }) => { - for alias in names { + ast::Stmt::Import(node) => { + for alias in &node.names { let symbol_name = if let Some(asname) = &alias.asname { asname.id.clone() } else { Name::new(alias.name.id.split('.').next().unwrap()) }; - self.add_or_update_symbol_with_definition(symbol_name, alias); + let symbol = self.add_or_update_symbol(symbol_name, SymbolFlags::IS_DEFINED); + self.add_definition(symbol, alias); } } - ast::Stmt::ImportFrom(ast::StmtImportFrom { - module: _, - names, - level: _, - .. - }) => { - for alias in names { + ast::Stmt::ImportFrom(node) => { + for (alias_index, alias) in node.names.iter().enumerate() { let symbol_name = if let Some(asname) = &alias.asname { &asname.id } else { &alias.name.id }; - self.add_or_update_symbol_with_definition(symbol_name.clone(), alias); + let symbol = + self.add_or_update_symbol(symbol_name.clone(), SymbolFlags::IS_DEFINED); + self.add_definition(symbol, ImportFromDefinitionNodeRef { node, alias_index }); } } ast::Stmt::Assign(node) => { - debug_assert!(self.current_target.is_none()); + debug_assert!(self.current_assignment.is_none()); self.visit_expr(&node.value); + self.add_standalone_expression(&node.value); + self.current_assignment = Some(node.into()); for target in &node.targets { - self.current_target = Some(CurrentTarget::Expr(target)); self.visit_expr(target); } - self.current_target = None; + self.current_assignment = None; + } + ast::Stmt::AnnAssign(node) => { + debug_assert!(self.current_assignment.is_none()); + // TODO deferred annotation visiting + self.visit_expr(&node.annotation); + match &node.value { + Some(value) => { + self.visit_expr(value); + self.current_assignment = Some(node.into()); + self.visit_expr(&node.target); + self.current_assignment = None; + } + None => { + // TODO annotation-only assignments + self.visit_expr(&node.target); + } + } + } + ast::Stmt::If(node) => { + self.visit_expr(&node.test); + let pre_if = self.flow_snapshot(); + self.visit_body(&node.body); + let mut last_clause_is_else = false; + let mut post_clauses: Vec = vec![self.flow_snapshot()]; + for clause in &node.elif_else_clauses { + // we can only take an elif/else clause if none of the previous ones were taken + self.flow_set(&pre_if); + self.visit_elif_else_clause(clause); + post_clauses.push(self.flow_snapshot()); + if clause.test.is_none() { + last_clause_is_else = true; + } + } + let mut post_clause_iter = post_clauses.iter(); + if last_clause_is_else { + // if the last clause was an else, the pre_if state can't directly reach the + // post-state; we have to enter one of the clauses. + self.flow_set(post_clause_iter.next().unwrap()); + } else { + self.flow_set(&pre_if); + } + for post_clause_state in post_clause_iter { + self.flow_merge(post_clause_state); + } } _ => { walk_stmt(self, stmt); @@ -344,57 +427,64 @@ where self.current_ast_ids().record_expression(expr); match expr { - ast::Expr::Name(ast::ExprName { id, ctx, .. }) => { + ast::Expr::Name(name_node) => { + let ast::ExprName { id, ctx, .. } = name_node; let flags = match ctx { ast::ExprContext::Load => SymbolFlags::IS_USED, ast::ExprContext::Store => SymbolFlags::IS_DEFINED, ast::ExprContext::Del => SymbolFlags::IS_DEFINED, ast::ExprContext::Invalid => SymbolFlags::empty(), }; - match self.current_target { - Some(target) if flags.contains(SymbolFlags::IS_DEFINED) => { - self.add_or_update_symbol_with_definition(id.clone(), target); - } - _ => { - self.add_or_update_symbol(id.clone(), flags); + let symbol = self.add_or_update_symbol(id.clone(), flags); + if flags.contains(SymbolFlags::IS_DEFINED) { + match self.current_assignment { + Some(CurrentAssignment::Assign(assignment)) => { + self.add_definition( + symbol, + AssignmentDefinitionNodeRef { + assignment, + target: name_node, + }, + ); + } + Some(CurrentAssignment::AnnAssign(ann_assign)) => { + self.add_definition(symbol, ann_assign); + } + Some(CurrentAssignment::Named(named)) => { + self.add_definition(symbol, named); + } + None => {} } } + if flags.contains(SymbolFlags::IS_USED) { + let use_id = self.current_ast_ids().record_use(expr); + self.current_use_def_map().record_use(symbol, use_id); + } + walk_expr(self, expr); } ast::Expr::Named(node) => { - debug_assert!(self.current_target.is_none()); - self.current_target = Some(CurrentTarget::ExprNamed(node)); + debug_assert!(self.current_assignment.is_none()); + self.current_assignment = Some(node.into()); // TODO walrus in comprehensions is implicitly nonlocal self.visit_expr(&node.target); - self.current_target = None; + self.current_assignment = None; self.visit_expr(&node.value); } ast::Expr::If(ast::ExprIf { body, test, orelse, .. }) => { // TODO detect statically known truthy or falsy test (via type inference, not naive - // AST inspection, so we can't simplify here, need to record test expression in CFG - // for later checking) - + // AST inspection, so we can't simplify here, need to record test expression for + // later checking) self.visit_expr(test); - - // let if_branch = self.flow_graph_builder.add_branch(self.current_flow_node()); - - // self.set_current_flow_node(if_branch); - // self.insert_constraint(test); + let pre_if = self.flow_snapshot(); self.visit_expr(body); - - // let post_body = self.current_flow_node(); - - // self.set_current_flow_node(if_branch); + let post_body = self.flow_snapshot(); + self.flow_set(&pre_if); self.visit_expr(orelse); - - // let post_else = self - // .flow_graph_builder - // .add_phi(self.current_flow_node(), post_body); - - // self.set_current_flow_node(post_else); + self.flow_merge(&post_body); } _ => { walk_expr(self, expr); @@ -418,16 +508,26 @@ impl<'node> WithTypeParams<'node> { } #[derive(Copy, Clone, Debug)] -enum CurrentTarget<'a> { - Expr(&'a ast::Expr), - ExprNamed(&'a ast::ExprNamed), +enum CurrentAssignment<'a> { + Assign(&'a ast::StmtAssign), + AnnAssign(&'a ast::StmtAnnAssign), + Named(&'a ast::ExprNamed), } -impl<'a> From> for DefinitionNodeRef<'a> { - fn from(val: CurrentTarget<'a>) -> Self { - match val { - CurrentTarget::Expr(expression) => DefinitionNodeRef::Target(expression), - CurrentTarget::ExprNamed(named) => DefinitionNodeRef::NamedExpression(named), - } +impl<'a> From<&'a ast::StmtAssign> for CurrentAssignment<'a> { + fn from(value: &'a ast::StmtAssign) -> Self { + Self::Assign(value) + } +} + +impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> { + fn from(value: &'a ast::StmtAnnAssign) -> Self { + Self::AnnAssign(value) + } +} + +impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> { + fn from(value: &'a ast::ExprNamed) -> Self { + Self::Named(value) } } diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index a9cf7cf1f0770..ff114a5856858 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -4,63 +4,111 @@ use ruff_python_ast as ast; use crate::ast_node_ref::AstNodeRef; use crate::node_key::NodeKey; -use crate::semantic_index::symbol::{FileScopeId, ScopedSymbolId}; +use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopedSymbolId}; +use crate::Db; #[salsa::tracked] pub struct Definition<'db> { - /// The file in which the definition is defined. + /// The file in which the definition occurs. #[id] - pub(super) file: File, + pub(crate) file: File, - /// The scope in which the definition is defined. + /// The scope in which the definition occurs. #[id] - pub(crate) scope: FileScopeId, + pub(crate) file_scope: FileScopeId, - /// The id of the corresponding symbol. Mainly used as ID. + /// The symbol defined. #[id] - symbol_id: ScopedSymbolId, + pub(crate) symbol: ScopedSymbolId, #[no_eq] #[return_ref] pub(crate) node: DefinitionKind, } +impl<'db> Definition<'db> { + pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> { + self.file_scope(db).to_scope_id(db, self.file(db)) + } +} + #[derive(Copy, Clone, Debug)] pub(crate) enum DefinitionNodeRef<'a> { - Alias(&'a ast::Alias), + Import(&'a ast::Alias), + ImportFrom(ImportFromDefinitionNodeRef<'a>), Function(&'a ast::StmtFunctionDef), Class(&'a ast::StmtClassDef), NamedExpression(&'a ast::ExprNamed), - Target(&'a ast::Expr), + Assignment(AssignmentDefinitionNodeRef<'a>), + AnnotatedAssignment(&'a ast::StmtAnnAssign), } -impl<'a> From<&'a ast::Alias> for DefinitionNodeRef<'a> { - fn from(node: &'a ast::Alias) -> Self { - Self::Alias(node) - } -} impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> { fn from(node: &'a ast::StmtFunctionDef) -> Self { Self::Function(node) } } + impl<'a> From<&'a ast::StmtClassDef> for DefinitionNodeRef<'a> { fn from(node: &'a ast::StmtClassDef) -> Self { Self::Class(node) } } + impl<'a> From<&'a ast::ExprNamed> for DefinitionNodeRef<'a> { fn from(node: &'a ast::ExprNamed) -> Self { Self::NamedExpression(node) } } +impl<'a> From<&'a ast::StmtAnnAssign> for DefinitionNodeRef<'a> { + fn from(node: &'a ast::StmtAnnAssign) -> Self { + Self::AnnotatedAssignment(node) + } +} + +impl<'a> From<&'a ast::Alias> for DefinitionNodeRef<'a> { + fn from(node_ref: &'a ast::Alias) -> Self { + Self::Import(node_ref) + } +} + +impl<'a> From> for DefinitionNodeRef<'a> { + fn from(node_ref: ImportFromDefinitionNodeRef<'a>) -> Self { + Self::ImportFrom(node_ref) + } +} + +impl<'a> From> for DefinitionNodeRef<'a> { + fn from(node_ref: AssignmentDefinitionNodeRef<'a>) -> Self { + Self::Assignment(node_ref) + } +} + +#[derive(Copy, Clone, Debug)] +pub(crate) struct ImportFromDefinitionNodeRef<'a> { + pub(crate) node: &'a ast::StmtImportFrom, + pub(crate) alias_index: usize, +} + +#[derive(Copy, Clone, Debug)] +pub(crate) struct AssignmentDefinitionNodeRef<'a> { + pub(crate) assignment: &'a ast::StmtAssign, + pub(crate) target: &'a ast::ExprName, +} + impl DefinitionNodeRef<'_> { #[allow(unsafe_code)] pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind { match self { - DefinitionNodeRef::Alias(alias) => { - DefinitionKind::Alias(AstNodeRef::new(parsed, alias)) + DefinitionNodeRef::Import(alias) => { + DefinitionKind::Import(AstNodeRef::new(parsed, alias)) + } + DefinitionNodeRef::ImportFrom(ImportFromDefinitionNodeRef { node, alias_index }) => { + DefinitionKind::ImportFrom(ImportFromDefinitionKind { + node: AstNodeRef::new(parsed, node), + alias_index, + }) } DefinitionNodeRef::Function(function) => { DefinitionKind::Function(AstNodeRef::new(parsed, function)) @@ -71,33 +119,111 @@ impl DefinitionNodeRef<'_> { DefinitionNodeRef::NamedExpression(named) => { DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named)) } - DefinitionNodeRef::Target(target) => { - DefinitionKind::Target(AstNodeRef::new(parsed, target)) + DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { assignment, target }) => { + DefinitionKind::Assignment(AssignmentDefinitionKind { + assignment: AstNodeRef::new(parsed.clone(), assignment), + target: AstNodeRef::new(parsed, target), + }) + } + DefinitionNodeRef::AnnotatedAssignment(assign) => { + DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign)) } } } -} -impl DefinitionNodeRef<'_> { pub(super) fn key(self) -> DefinitionNodeKey { match self { - Self::Alias(node) => DefinitionNodeKey(NodeKey::from_node(node)), - Self::Function(node) => DefinitionNodeKey(NodeKey::from_node(node)), - Self::Class(node) => DefinitionNodeKey(NodeKey::from_node(node)), - Self::NamedExpression(node) => DefinitionNodeKey(NodeKey::from_node(node)), - Self::Target(node) => DefinitionNodeKey(NodeKey::from_node(node)), + Self::Import(node) => node.into(), + Self::ImportFrom(ImportFromDefinitionNodeRef { node, alias_index }) => { + (&node.names[alias_index]).into() + } + Self::Function(node) => node.into(), + Self::Class(node) => node.into(), + Self::NamedExpression(node) => node.into(), + Self::Assignment(AssignmentDefinitionNodeRef { + assignment: _, + target, + }) => target.into(), + Self::AnnotatedAssignment(node) => node.into(), } } } #[derive(Clone, Debug)] pub enum DefinitionKind { - Alias(AstNodeRef), + Import(AstNodeRef), + ImportFrom(ImportFromDefinitionKind), Function(AstNodeRef), Class(AstNodeRef), NamedExpression(AstNodeRef), - Target(AstNodeRef), + Assignment(AssignmentDefinitionKind), + AnnotatedAssignment(AstNodeRef), +} + +#[derive(Clone, Debug)] +pub struct ImportFromDefinitionKind { + node: AstNodeRef, + alias_index: usize, +} + +impl ImportFromDefinitionKind { + pub(crate) fn import(&self) -> &ast::StmtImportFrom { + self.node.node() + } + + pub(crate) fn alias(&self) -> &ast::Alias { + &self.node.node().names[self.alias_index] + } +} + +#[derive(Clone, Debug)] +#[allow(dead_code)] +pub struct AssignmentDefinitionKind { + assignment: AstNodeRef, + target: AstNodeRef, +} + +impl AssignmentDefinitionKind { + pub(crate) fn assignment(&self) -> &ast::StmtAssign { + self.assignment.node() + } } #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] -pub(super) struct DefinitionNodeKey(NodeKey); +pub(crate) struct DefinitionNodeKey(NodeKey); + +impl From<&ast::Alias> for DefinitionNodeKey { + fn from(node: &ast::Alias) -> Self { + Self(NodeKey::from_node(node)) + } +} + +impl From<&ast::StmtFunctionDef> for DefinitionNodeKey { + fn from(node: &ast::StmtFunctionDef) -> Self { + Self(NodeKey::from_node(node)) + } +} + +impl From<&ast::StmtClassDef> for DefinitionNodeKey { + fn from(node: &ast::StmtClassDef) -> Self { + Self(NodeKey::from_node(node)) + } +} + +impl From<&ast::ExprName> for DefinitionNodeKey { + fn from(node: &ast::ExprName) -> Self { + Self(NodeKey::from_node(node)) + } +} + +impl From<&ast::ExprNamed> for DefinitionNodeKey { + fn from(node: &ast::ExprNamed) -> Self { + Self(NodeKey::from_node(node)) + } +} + +impl From<&ast::StmtAnnAssign> for DefinitionNodeKey { + fn from(node: &ast::StmtAnnAssign) -> Self { + Self(NodeKey::from_node(node)) + } +} diff --git a/crates/red_knot_python_semantic/src/semantic_index/expression.rs b/crates/red_knot_python_semantic/src/semantic_index/expression.rs new file mode 100644 index 0000000000000..23f48ca416fdf --- /dev/null +++ b/crates/red_knot_python_semantic/src/semantic_index/expression.rs @@ -0,0 +1,31 @@ +use crate::ast_node_ref::AstNodeRef; +use crate::db::Db; +use crate::semantic_index::symbol::{FileScopeId, ScopeId}; +use ruff_db::files::File; +use ruff_python_ast as ast; +use salsa; + +/// An independently type-inferable expression. +/// +/// Includes constraint expressions (e.g. if tests) and the RHS of an unpacking assignment. +#[salsa::tracked] +pub(crate) struct Expression<'db> { + /// The file in which the expression occurs. + #[id] + pub(crate) file: File, + + /// The scope in which the expression occurs. + #[id] + pub(crate) file_scope: FileScopeId, + + /// The expression node. + #[no_eq] + #[return_ref] + pub(crate) node: AstNodeRef, +} + +impl<'db> Expression<'db> { + pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> { + self.file_scope(db).to_scope_id(db, self.file(db)) + } +} diff --git a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs index ce4edecf3593a..6deab6ba10b70 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/symbol.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/symbol.rs @@ -12,33 +12,23 @@ use rustc_hash::FxHasher; use crate::ast_node_ref::AstNodeRef; use crate::node_key::NodeKey; -use crate::semantic_index::definition::Definition; -use crate::semantic_index::{root_scope, semantic_index, symbol_table, SymbolMap}; +use crate::semantic_index::{semantic_index, SymbolMap}; use crate::Db; #[derive(Eq, PartialEq, Debug)] -pub struct Symbol<'db> { +pub struct Symbol { name: Name, flags: SymbolFlags, - /// The nodes that define this symbol, in source order. - /// - /// TODO: Use smallvec here, but it creates the same lifetime issues as in [QualifiedName](https://github.com/astral-sh/ruff/blob/5109b50bb3847738eeb209352cf26bda392adf62/crates/ruff_python_ast/src/name.rs#L562-L569) - definitions: Vec>, } -impl<'db> Symbol<'db> { +impl Symbol { fn new(name: Name) -> Self { Self { name, flags: SymbolFlags::empty(), - definitions: Vec::new(), } } - fn push_definition(&mut self, definition: Definition<'db>) { - self.definitions.push(definition); - } - fn insert_flags(&mut self, flags: SymbolFlags) { self.flags.insert(flags); } @@ -57,10 +47,6 @@ impl<'db> Symbol<'db> { pub fn is_defined(&self) -> bool { self.flags.contains(SymbolFlags::IS_DEFINED) } - - pub fn definitions(&self) -> &[Definition] { - &self.definitions - } } bitflags! { @@ -75,15 +61,6 @@ bitflags! { } } -/// ID that uniquely identifies a public symbol defined in a module's root scope. -#[salsa::tracked] -pub struct PublicSymbolId<'db> { - #[id] - pub(crate) file: File, - #[id] - pub(crate) scoped_symbol_id: ScopedSymbolId, -} - /// ID that uniquely identifies a symbol in a file. #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] pub struct FileSymbolId { @@ -111,47 +88,6 @@ impl From for ScopedSymbolId { #[newtype_index] pub struct ScopedSymbolId; -impl ScopedSymbolId { - /// Converts the symbol to a public symbol. - /// - /// # Panics - /// May panic if the symbol does not belong to `file` or is not a symbol of `file`'s root scope. - pub(crate) fn to_public_symbol(self, db: &dyn Db, file: File) -> PublicSymbolId { - let symbols = public_symbols_map(db, file); - symbols.public(self) - } -} - -#[salsa::tracked(return_ref)] -pub(crate) fn public_symbols_map(db: &dyn Db, file: File) -> PublicSymbolsMap<'_> { - let _span = tracing::trace_span!("public_symbols_map", ?file).entered(); - - let module_scope = root_scope(db, file); - let symbols = symbol_table(db, module_scope); - - let public_symbols: IndexVec<_, _> = symbols - .symbol_ids() - .map(|id| PublicSymbolId::new(db, file, id)) - .collect(); - - PublicSymbolsMap { - symbols: public_symbols, - } -} - -/// Maps [`LocalSymbolId`] of a file's root scope to the corresponding [`PublicSymbolId`] (Salsa ingredients). -#[derive(Eq, PartialEq, Debug)] -pub(crate) struct PublicSymbolsMap<'db> { - symbols: IndexVec>, -} - -impl<'db> PublicSymbolsMap<'db> { - /// Resolve the [`PublicSymbolId`] for the module-level `symbol_id`. - fn public(&self, symbol_id: ScopedSymbolId) -> PublicSymbolId<'db> { - self.symbols[symbol_id] - } -} - /// A cross-module identifier of a scope that can be used as a salsa query parameter. #[salsa::tracked] pub struct ScopeId<'db> { @@ -185,8 +121,8 @@ impl<'db> ScopeId<'db> { pub struct FileScopeId; impl FileScopeId { - /// Returns the scope id of the Root scope. - pub fn root() -> Self { + /// Returns the scope id of the module-global scope. + pub fn module_global() -> Self { FileScopeId::from_u32(0) } @@ -223,15 +159,15 @@ pub enum ScopeKind { /// Symbol table for a specific [`Scope`]. #[derive(Debug)] -pub struct SymbolTable<'db> { +pub struct SymbolTable { /// The symbols in this scope. - symbols: IndexVec>, + symbols: IndexVec, /// The symbols indexed by name. symbols_by_name: SymbolMap, } -impl<'db> SymbolTable<'db> { +impl SymbolTable { fn new() -> Self { Self { symbols: IndexVec::new(), @@ -243,21 +179,22 @@ impl<'db> SymbolTable<'db> { self.symbols.shrink_to_fit(); } - pub(crate) fn symbol(&self, symbol_id: impl Into) -> &Symbol<'db> { + pub(crate) fn symbol(&self, symbol_id: impl Into) -> &Symbol { &self.symbols[symbol_id.into()] } - pub(crate) fn symbol_ids(&self) -> impl Iterator + 'db { + #[allow(unused)] + pub(crate) fn symbol_ids(&self) -> impl Iterator { self.symbols.indices() } - pub fn symbols(&self) -> impl Iterator> { + pub fn symbols(&self) -> impl Iterator { self.symbols.iter() } /// Returns the symbol named `name`. #[allow(unused)] - pub(crate) fn symbol_by_name(&self, name: &str) -> Option<&Symbol<'db>> { + pub(crate) fn symbol_by_name(&self, name: &str) -> Option<&Symbol> { let id = self.symbol_id_by_name(name)?; Some(self.symbol(id)) } @@ -281,21 +218,21 @@ impl<'db> SymbolTable<'db> { } } -impl PartialEq for SymbolTable<'_> { +impl PartialEq for SymbolTable { fn eq(&self, other: &Self) -> bool { // We don't need to compare the symbols_by_name because the name is already captured in `Symbol`. self.symbols == other.symbols } } -impl Eq for SymbolTable<'_> {} +impl Eq for SymbolTable {} #[derive(Debug)] -pub(super) struct SymbolTableBuilder<'db> { - table: SymbolTable<'db>, +pub(super) struct SymbolTableBuilder { + table: SymbolTable, } -impl<'db> SymbolTableBuilder<'db> { +impl SymbolTableBuilder { pub(super) fn new() -> Self { Self { table: SymbolTable::new(), @@ -306,7 +243,7 @@ impl<'db> SymbolTableBuilder<'db> { &mut self, name: Name, flags: SymbolFlags, - ) -> ScopedSymbolId { + ) -> (ScopedSymbolId, bool) { let hash = SymbolTable::hash_name(&name); let entry = self .table @@ -319,7 +256,7 @@ impl<'db> SymbolTableBuilder<'db> { let symbol = &mut self.table.symbols[*entry.key()]; symbol.insert_flags(flags); - *entry.key() + (*entry.key(), false) } RawEntryMut::Vacant(entry) => { let mut symbol = Symbol::new(name); @@ -329,16 +266,12 @@ impl<'db> SymbolTableBuilder<'db> { entry.insert_with_hasher(hash, id, (), |id| { SymbolTable::hash_name(self.table.symbols[*id].name().as_str()) }); - id + (id, true) } } } - pub(super) fn add_definition(&mut self, symbol: ScopedSymbolId, definition: Definition<'db>) { - self.table.symbols[symbol].push_definition(definition); - } - - pub(super) fn finish(mut self) -> SymbolTable<'db> { + pub(super) fn finish(mut self) -> SymbolTable { self.table.shrink_to_fit(); self.table } diff --git a/crates/red_knot_python_semantic/src/semantic_index/usedef.rs b/crates/red_knot_python_semantic/src/semantic_index/usedef.rs new file mode 100644 index 0000000000000..17b82690efdf4 --- /dev/null +++ b/crates/red_knot_python_semantic/src/semantic_index/usedef.rs @@ -0,0 +1,160 @@ +use crate::semantic_index::ast_ids::ScopedUseId; +use crate::semantic_index::definition::Definition; +use crate::semantic_index::symbol::ScopedSymbolId; +use ruff_index::IndexVec; +use std::ops::Range; + +/// All definitions that can reach a given use of a name. +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct UseDefMap<'db> { + // TODO store constraints with definitions for type narrowing + all_definitions: Vec>, + + /// Definitions that can reach a [`ScopedUseId`]. + definitions_by_use: IndexVec, + + /// Definitions of a symbol visible to other scopes. + public_definitions: IndexVec, +} + +impl<'db> UseDefMap<'db> { + pub(crate) fn use_definitions(&self, use_id: ScopedUseId) -> &[Definition<'db>] { + &self.all_definitions[self.definitions_by_use[use_id].definitions.clone()] + } + + pub(crate) fn use_may_be_unbound(&self, use_id: ScopedUseId) -> bool { + self.definitions_by_use[use_id].may_be_unbound + } + + pub(crate) fn public_definitions(&self, symbol: ScopedSymbolId) -> &[Definition<'db>] { + &self.all_definitions[self.public_definitions[symbol].definitions.clone()] + } + + pub(crate) fn public_may_be_unbound(&self, symbol: ScopedSymbolId) -> bool { + self.public_definitions[symbol].may_be_unbound + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct Definitions { + definitions: Range, + may_be_unbound: bool, +} + +impl Default for Definitions { + fn default() -> Self { + Self { + definitions: 0..0, + may_be_unbound: true, + } + } +} + +#[derive(Debug)] +pub(super) struct FlowSnapshot { + definitions_by_symbol: IndexVec, +} + +pub(super) struct UseDefMapBuilder<'db> { + all_definitions: Vec>, + + definitions_by_use: IndexVec, + + // builder state: currently visible definitions for each symbol + definitions_by_symbol: IndexVec, +} + +impl<'db> UseDefMapBuilder<'db> { + pub(super) fn new() -> Self { + Self { + all_definitions: Vec::new(), + definitions_by_use: IndexVec::new(), + definitions_by_symbol: IndexVec::new(), + } + } + + pub(super) fn add_symbol(&mut self, symbol: ScopedSymbolId) { + let new_symbol = self.definitions_by_symbol.push(Definitions::default()); + debug_assert_eq!(symbol, new_symbol); + } + + pub(super) fn record_def(&mut self, symbol: ScopedSymbolId, definition: Definition<'db>) { + let def_idx = self.all_definitions.len(); + self.all_definitions.push(definition); + self.definitions_by_symbol[symbol] = Definitions { + #[allow(clippy::range_plus_one)] + definitions: def_idx..(def_idx + 1), + may_be_unbound: false, + }; + } + + pub(super) fn record_use(&mut self, symbol: ScopedSymbolId, use_id: ScopedUseId) { + let new_use = self + .definitions_by_use + .push(self.definitions_by_symbol[symbol].clone()); + debug_assert_eq!(use_id, new_use); + } + + pub(super) fn snapshot(&self) -> FlowSnapshot { + FlowSnapshot { + definitions_by_symbol: self.definitions_by_symbol.clone(), + } + } + + pub(super) fn set(&mut self, state: &FlowSnapshot) { + let num_symbols = self.definitions_by_symbol.len(); + self.definitions_by_symbol = state.definitions_by_symbol.clone(); + self.definitions_by_symbol + .resize(num_symbols, Definitions::default()); + } + + pub(super) fn merge(&mut self, state: &FlowSnapshot) { + for (symbol_id, to_merge) in state.definitions_by_symbol.iter_enumerated() { + let current = self.definitions_by_symbol.get_mut(symbol_id).unwrap(); + // if the symbol can be unbound in either predecessor, it can be unbound + current.may_be_unbound |= to_merge.may_be_unbound; + // merge the definition ranges + if current.definitions == to_merge.definitions { + // ranges already identical, nothing to do! + } else if current.definitions.end == to_merge.definitions.start { + // ranges adjacent (current first), just merge them + current.definitions = (current.definitions.start)..(to_merge.definitions.end); + } else if current.definitions.start == to_merge.definitions.end { + // ranges adjacent (to_merge first), just merge them + current.definitions = (to_merge.definitions.start)..(current.definitions.end); + } else if current.definitions.end == self.all_definitions.len() { + // ranges not adjacent but current is at end, copy only to_merge + self.all_definitions + .extend_from_within(to_merge.definitions.clone()); + current.definitions.end = self.all_definitions.len(); + } else if to_merge.definitions.end == self.all_definitions.len() { + // ranges not adjacent but to_merge is at end, copy only current + self.all_definitions + .extend_from_within(current.definitions.clone()); + current.definitions.start = to_merge.definitions.start; + current.definitions.end = self.all_definitions.len(); + } else { + // ranges not adjacent and neither at end, must copy both + let start = self.all_definitions.len(); + self.all_definitions + .extend_from_within(current.definitions.clone()); + self.all_definitions + .extend_from_within(to_merge.definitions.clone()); + current.definitions.start = start; + current.definitions.end = self.all_definitions.len(); + } + } + } + + pub(super) fn finish(mut self) -> UseDefMap<'db> { + self.all_definitions.shrink_to_fit(); + self.definitions_by_symbol.shrink_to_fit(); + self.definitions_by_use.shrink_to_fit(); + + UseDefMap { + all_definitions: self.all_definitions, + definitions_by_use: self.definitions_by_use, + public_definitions: self.definitions_by_symbol, + } + } +} diff --git a/crates/red_knot_python_semantic/src/semantic_model.rs b/crates/red_knot_python_semantic/src/semantic_model.rs index 29433ba4ee7e9..851bc31832354 100644 --- a/crates/red_knot_python_semantic/src/semantic_model.rs +++ b/crates/red_knot_python_semantic/src/semantic_model.rs @@ -4,9 +4,8 @@ use ruff_python_ast as ast; use ruff_python_ast::{Expr, ExpressionRef, StmtClassDef}; use crate::semantic_index::ast_ids::HasScopedAstId; -use crate::semantic_index::symbol::PublicSymbolId; -use crate::semantic_index::{public_symbol, semantic_index}; -use crate::types::{infer_types, public_symbol_ty, Type}; +use crate::semantic_index::semantic_index; +use crate::types::{definition_ty, infer_scope_types, module_global_symbol_ty_by_name, Type}; use crate::Db; pub struct SemanticModel<'db> { @@ -29,12 +28,8 @@ impl<'db> SemanticModel<'db> { resolve_module(self.db.upcast(), module_name) } - pub fn public_symbol(&self, module: &Module, symbol_name: &str) -> Option> { - public_symbol(self.db, module.file(), symbol_name) - } - - pub fn public_symbol_ty(&self, symbol: PublicSymbolId<'db>) -> Type { - public_symbol_ty(self.db, symbol) + pub fn module_global_symbol_ty(&self, module: &Module, symbol_name: &str) -> Type<'db> { + module_global_symbol_ty_by_name(self.db, module.file(), symbol_name) } } @@ -53,7 +48,7 @@ impl HasTy for ast::ExpressionRef<'_> { let scope = file_scope.to_scope_id(model.db, model.file); let expression_id = self.scoped_ast_id(model.db, scope); - infer_types(model.db, scope).expression_ty(expression_id) + infer_scope_types(model.db, scope).expression_ty(expression_id) } } @@ -145,11 +140,7 @@ impl HasTy for ast::StmtFunctionDef { fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { let index = semantic_index(model.db, model.file); let definition = index.definition(self); - - let scope = definition.scope(model.db).to_scope_id(model.db, model.file); - let types = infer_types(model.db, scope); - - types.definition_ty(definition) + definition_ty(model.db, definition) } } @@ -157,11 +148,7 @@ impl HasTy for StmtClassDef { fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { let index = semantic_index(model.db, model.file); let definition = index.definition(self); - - let scope = definition.scope(model.db).to_scope_id(model.db, model.file); - let types = infer_types(model.db, scope); - - types.definition_ty(definition) + definition_ty(model.db, definition) } } @@ -169,11 +156,7 @@ impl HasTy for ast::Alias { fn ty<'db>(&self, model: &SemanticModel<'db>) -> Type<'db> { let index = semantic_index(model.db, model.file); let definition = index.definition(self); - - let scope = definition.scope(model.db).to_scope_id(model.db, model.file); - let types = infer_types(model.db, scope); - - types.definition_ty(definition) + definition_ty(model.db, definition) } } diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 517fb52a76e87..9cb7d78e6b3da 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1,91 +1,89 @@ use ruff_db::files::File; -use ruff_db::parsed::parsed_module; use ruff_python_ast::name::Name; -use crate::semantic_index::symbol::{NodeWithScopeKind, PublicSymbolId, ScopeId}; -use crate::semantic_index::{public_symbol, root_scope, semantic_index, symbol_table}; -use crate::types::infer::{TypeInference, TypeInferenceBuilder}; +use crate::semantic_index::definition::Definition; +use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId}; +use crate::semantic_index::{module_global_scope, symbol_table, use_def_map}; use crate::{Db, FxOrderSet}; mod display; mod infer; -/// Infers the type of a public symbol. -/// -/// This is a Salsa query to get symbol-level invalidation instead of file-level dependency invalidation. -/// Without this being a query, changing any public type of a module would invalidate the type inference -/// for the module scope of its dependents and the transitive dependents because. -/// -/// For example if we have -/// ```python -/// # a.py -/// import x from b -/// -/// # b.py -/// -/// x = 20 -/// ``` -/// -/// And x is now changed from `x = 20` to `x = 30`. The following happens: -/// -/// * The module level types of `b.py` change because `x` now is a `Literal[30]`. -/// * The module level types of `a.py` change because the imported symbol `x` now has a `Literal[30]` type -/// * The module level types of any dependents of `a.py` change because the imported symbol `x` now has a `Literal[30]` type -/// * And so on for all transitive dependencies. -/// -/// This being a query ensures that the invalidation short-circuits if the type of this symbol didn't change. -#[salsa::tracked] -pub(crate) fn public_symbol_ty<'db>(db: &'db dyn Db, symbol: PublicSymbolId<'db>) -> Type<'db> { - let _span = tracing::trace_span!("public_symbol_ty", ?symbol).entered(); - - let file = symbol.file(db); - let scope = root_scope(db, file); - - // TODO switch to inferring just the definition(s), not the whole scope - let inference = infer_types(db, scope); - inference.symbol_ty(symbol.scoped_symbol_id(db)) +pub(crate) use self::infer::{infer_definition_types, infer_expression_types, infer_scope_types}; + +/// Infer the public type of a symbol (its type as seen from outside its scope). +pub(crate) fn symbol_ty<'db>( + db: &'db dyn Db, + scope: ScopeId<'db>, + symbol: ScopedSymbolId, +) -> Type<'db> { + let _span = tracing::trace_span!("symbol_ty", ?symbol).entered(); + + let use_def = use_def_map(db, scope); + definitions_ty( + db, + use_def.public_definitions(symbol), + use_def.public_may_be_unbound(symbol), + ) } -/// Shorthand for `public_symbol_ty` that takes a symbol name instead of a [`PublicSymbolId`]. -pub(crate) fn public_symbol_ty_by_name<'db>( +/// Shorthand for `symbol_ty` that takes a symbol name instead of an ID. +pub(crate) fn symbol_ty_by_name<'db>( + db: &'db dyn Db, + scope: ScopeId<'db>, + name: &str, +) -> Type<'db> { + let table = symbol_table(db, scope); + table + .symbol_id_by_name(name) + .map(|symbol| symbol_ty(db, scope, symbol)) + .unwrap_or(Type::Unbound) +} + +/// Shorthand for `symbol_ty` that looks up a module-global symbol in a file. +pub(crate) fn module_global_symbol_ty_by_name<'db>( db: &'db dyn Db, file: File, name: &str, -) -> Option> { - let symbol = public_symbol(db, file, name)?; - Some(public_symbol_ty(db, symbol)) +) -> Type<'db> { + symbol_ty_by_name(db, module_global_scope(db, file), name) } -/// Infers all types for `scope`. -#[salsa::tracked(return_ref)] -pub(crate) fn infer_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> TypeInference<'db> { - let _span = tracing::trace_span!("infer_types", ?scope).entered(); +/// Infer the type of a [`Definition`]. +pub(crate) fn definition_ty<'db>(db: &'db dyn Db, definition: Definition<'db>) -> Type<'db> { + let inference = infer_definition_types(db, definition); + inference.definition_ty(definition) +} - let file = scope.file(db); - // Using the index here is fine because the code below depends on the AST anyway. - // The isolation of the query is by the return inferred types. - let index = semantic_index(db, file); +pub(crate) fn definitions_ty<'db>( + db: &'db dyn Db, + definitions: &[Definition<'db>], + may_be_unbound: bool, +) -> Type<'db> { + let unbound_iter = if may_be_unbound { + [Type::Unbound].iter() + } else { + [].iter() + }; + let def_types = definitions.iter().map(|def| definition_ty(db, *def)); + let mut all_types = unbound_iter.copied().chain(def_types); - let node = scope.node(db); + let Some(first) = all_types.next() else { + return Type::Unbound; + }; - let mut context = TypeInferenceBuilder::new(db, scope, index); + if let Some(second) = all_types.next() { + let mut builder = UnionTypeBuilder::new(db); + builder = builder.add(first).add(second); - match node { - NodeWithScopeKind::Module => { - let parsed = parsed_module(db.upcast(), file); - context.infer_module(parsed.syntax()); - } - NodeWithScopeKind::Function(function) => context.infer_function_body(function.node()), - NodeWithScopeKind::Class(class) => context.infer_class_body(class.node()), - NodeWithScopeKind::ClassTypeParameters(class) => { - context.infer_class_type_params(class.node()); + for variant in all_types { + builder = builder.add(variant); } - NodeWithScopeKind::FunctionTypeParameters(function) => { - context.infer_function_type_params(function.node()); - } - } - context.finish() + Type::Union(builder.build()) + } else { + first + } } /// unique ID for a type @@ -96,9 +94,10 @@ pub enum Type<'db> { /// the empty set of values Never, /// unknown type (no annotation) - /// equivalent to Any, or to object in strict mode + /// equivalent to Any, or possibly to object in strict mode Unknown, - /// name is not bound to any value + /// name does not exist or is not bound to any value (this represents an error, but with some + /// leniency options it could be silently resolved to Unknown in some cases) Unbound, /// the None object (TODO remove this in favor of Instance(types.NoneType) None, @@ -125,15 +124,16 @@ impl<'db> Type<'db> { matches!(self, Type::Unknown) } - pub fn member(&self, db: &'db dyn Db, name: &Name) -> Option> { + #[must_use] + pub fn member(&self, db: &'db dyn Db, name: &Name) -> Type<'db> { match self { - Type::Any => Some(Type::Any), + Type::Any => Type::Any, Type::Never => todo!("attribute lookup on Never type"), - Type::Unknown => Some(Type::Unknown), - Type::Unbound => todo!("attribute lookup on Unbound type"), + Type::Unknown => Type::Unknown, + Type::Unbound => Type::Unbound, Type::None => todo!("attribute lookup on None type"), Type::Function(_) => todo!("attribute lookup on Function type"), - Type::Module(file) => public_symbol_ty_by_name(db, *file, name), + Type::Module(file) => module_global_symbol_ty_by_name(db, *file, name), Type::Class(class) => class.class_member(db, name), Type::Instance(_) => { // TODO MRO? get_own_instance_member, get_instance_member @@ -152,7 +152,7 @@ impl<'db> Type<'db> { } Type::IntLiteral(_) => { // TODO raise error - Some(Type::Unknown) + Type::Unknown } } } @@ -188,32 +188,30 @@ impl<'db> ClassType<'db> { /// Returns the class member of this class named `name`. /// /// The member resolves to a member of the class itself or any of its bases. - pub fn class_member(self, db: &'db dyn Db, name: &Name) -> Option> { - if let Some(member) = self.own_class_member(db, name) { - return Some(member); + pub fn class_member(self, db: &'db dyn Db, name: &Name) -> Type<'db> { + let member = self.own_class_member(db, name); + if !member.is_unbound() { + return member; } self.inherited_class_member(db, name) } /// Returns the inferred type of the class member named `name`. - pub fn own_class_member(self, db: &'db dyn Db, name: &Name) -> Option> { + pub fn own_class_member(self, db: &'db dyn Db, name: &Name) -> Type<'db> { let scope = self.body_scope(db); - let symbols = symbol_table(db, scope); - let symbol = symbols.symbol_id_by_name(name)?; - let types = infer_types(db, scope); - - Some(types.symbol_ty(symbol)) + symbol_ty_by_name(db, scope, name) } - pub fn inherited_class_member(self, db: &'db dyn Db, name: &Name) -> Option> { + pub fn inherited_class_member(self, db: &'db dyn Db, name: &Name) -> Type<'db> { for base in self.bases(db) { - if let Some(member) = base.member(db, name) { - return Some(member); + let member = base.member(db, name); + if !member.is_unbound() { + return member; } } - None + Type::Unbound } } @@ -268,165 +266,3 @@ pub struct IntersectionType<'db> { // the intersection type does not include any value in any of these types negative: FxOrderSet>, } - -#[cfg(test)] -mod tests { - use red_knot_module_resolver::{ - set_module_resolution_settings, RawModuleResolutionSettings, TargetVersion, - }; - use ruff_db::files::system_path_to_file; - use ruff_db::parsed::parsed_module; - use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; - use ruff_db::testing::{assert_function_query_was_not_run, assert_function_query_was_run}; - - use crate::db::tests::TestDb; - use crate::semantic_index::root_scope; - use crate::types::{infer_types, public_symbol_ty_by_name}; - use crate::{HasTy, SemanticModel}; - - fn setup_db() -> TestDb { - let mut db = TestDb::new(); - set_module_resolution_settings( - &mut db, - RawModuleResolutionSettings { - target_version: TargetVersion::Py38, - extra_paths: vec![], - workspace_root: SystemPathBuf::from("/src"), - site_packages: None, - custom_typeshed: None, - }, - ); - - db - } - - #[test] - fn local_inference() -> anyhow::Result<()> { - let mut db = setup_db(); - - db.write_file("/src/a.py", "x = 10")?; - let a = system_path_to_file(&db, "/src/a.py").unwrap(); - - let parsed = parsed_module(&db, a); - - let statement = parsed.suite().first().unwrap().as_assign_stmt().unwrap(); - let model = SemanticModel::new(&db, a); - - let literal_ty = statement.value.ty(&model); - - assert_eq!(format!("{}", literal_ty.display(&db)), "Literal[10]"); - - Ok(()) - } - - #[test] - fn dependency_public_symbol_type_change() -> anyhow::Result<()> { - let mut db = setup_db(); - - db.write_files([ - ("/src/a.py", "from foo import x"), - ("/src/foo.py", "x = 10\ndef foo(): ..."), - ])?; - - let a = system_path_to_file(&db, "/src/a.py").unwrap(); - let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap(); - - assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); - - // Change `x` to a different value - db.write_file("/src/foo.py", "x = 20\ndef foo(): ...")?; - - let a = system_path_to_file(&db, "/src/a.py").unwrap(); - - db.clear_salsa_events(); - let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap(); - - assert_eq!(x_ty_2.display(&db).to_string(), "Literal[20]"); - - let events = db.take_salsa_events(); - - let a_root_scope = root_scope(&db, a); - assert_function_query_was_run::( - &db, - |ty| &ty.function, - &a_root_scope, - &events, - ); - - Ok(()) - } - - #[test] - fn dependency_non_public_symbol_change() -> anyhow::Result<()> { - let mut db = setup_db(); - - db.write_files([ - ("/src/a.py", "from foo import x"), - ("/src/foo.py", "x = 10\ndef foo(): y = 1"), - ])?; - - let a = system_path_to_file(&db, "/src/a.py").unwrap(); - let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap(); - - assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); - - db.write_file("/src/foo.py", "x = 10\ndef foo(): pass")?; - - let a = system_path_to_file(&db, "/src/a.py").unwrap(); - - db.clear_salsa_events(); - - let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap(); - - assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]"); - - let events = db.take_salsa_events(); - - let a_root_scope = root_scope(&db, a); - - assert_function_query_was_not_run::( - &db, - |ty| &ty.function, - &a_root_scope, - &events, - ); - - Ok(()) - } - - #[test] - fn dependency_unrelated_public_symbol() -> anyhow::Result<()> { - let mut db = setup_db(); - - db.write_files([ - ("/src/a.py", "from foo import x"), - ("/src/foo.py", "x = 10\ny = 20"), - ])?; - - let a = system_path_to_file(&db, "/src/a.py").unwrap(); - let x_ty = public_symbol_ty_by_name(&db, a, "x").unwrap(); - - assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); - - db.write_file("/src/foo.py", "x = 10\ny = 30")?; - - let a = system_path_to_file(&db, "/src/a.py").unwrap(); - - db.clear_salsa_events(); - - let x_ty_2 = public_symbol_ty_by_name(&db, a, "x").unwrap(); - - assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]"); - - let events = db.take_salsa_events(); - - let a_root_scope = root_scope(&db, a); - assert_function_query_was_not_run::( - &db, - |ty| &ty.function, - &a_root_scope, - &events, - ); - Ok(()) - } -} diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index f8623ae37d699..7a04e0cc703c4 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1,43 +1,86 @@ use rustc_hash::FxHashMap; -use std::borrow::Cow; -use std::sync::Arc; +use salsa; use red_knot_module_resolver::{resolve_module, ModuleName}; use ruff_db::files::File; -use ruff_index::IndexVec; use ruff_python_ast as ast; use ruff_python_ast::{ExprContext, TypeParams}; -use crate::semantic_index::ast_ids::ScopedExpressionId; -use crate::semantic_index::definition::{Definition, DefinitionNodeRef}; -use crate::semantic_index::symbol::{ - FileScopeId, NodeWithScopeRef, ScopeId, ScopedSymbolId, SymbolTable, +use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId}; +use crate::semantic_index::definition::{Definition, DefinitionKind}; +use crate::semantic_index::expression::Expression; +use crate::semantic_index::symbol::{NodeWithScopeRef, ScopeId}; +use crate::semantic_index::SemanticIndex; +use crate::types::{ + definitions_ty, use_def_map, ClassType, FunctionType, Name, Type, UnionTypeBuilder, }; -use crate::semantic_index::{symbol_table, SemanticIndex}; -use crate::types::{infer_types, ClassType, FunctionType, Name, Type, UnionTypeBuilder}; use crate::Db; +use ruff_db::parsed::parsed_module; -/// The inferred types for a single scope. +use crate::semantic_index::semantic_index; +use crate::semantic_index::symbol::NodeWithScopeKind; + +/// Infer all types for a [`Definition`] (including sub-expressions). +#[salsa::tracked(return_ref)] +pub(crate) fn infer_definition_types<'db>( + db: &'db dyn Db, + definition: Definition<'db>, +) -> TypeInference<'db> { + let _span = tracing::trace_span!("infer_definition_types", ?definition).entered(); + + let index = semantic_index(db, definition.file(db)); + + TypeInferenceBuilder::new(db, InferenceRegion::Definition(definition), index).finish() +} + +/// Infer all types for an [`Expression`] (including sub-expressions). +#[allow(unused)] +#[salsa::tracked(return_ref)] +pub(crate) fn infer_expression_types<'db>( + db: &'db dyn Db, + expression: Expression<'db>, +) -> TypeInference<'db> { + let _span = tracing::trace_span!("infer_expression_types", ?expression).entered(); + + let index = semantic_index(db, expression.file(db)); + + TypeInferenceBuilder::new(db, InferenceRegion::Expression(expression), index).finish() +} + +/// Infer all types for `scope`. +#[salsa::tracked(return_ref)] +pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> TypeInference<'db> { + let _span = tracing::trace_span!("infer_scope_types", ?scope).entered(); + + let file = scope.file(db); + // Using the index here is fine because the code below depends on the AST anyway. + // The isolation of the query is by the return inferred types. + let index = semantic_index(db, file); + + TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index).finish() +} + +/// A region within which we can infer types. +pub(crate) enum InferenceRegion<'db> { + Expression(Expression<'db>), + Definition(Definition<'db>), + Scope(ScopeId<'db>), +} + +/// The inferred types for a single region. #[derive(Debug, Eq, PartialEq, Default, Clone)] pub(crate) struct TypeInference<'db> { - /// The types of every expression in this scope. - expressions: IndexVec>, + /// The types of every expression in this region. + expressions: FxHashMap>, - /// The public types of every symbol in this scope. - symbols: IndexVec>, - - /// The type of a definition. + /// The types of every definition in this region. definitions: FxHashMap, Type<'db>>, } impl<'db> TypeInference<'db> { #[allow(unused)] pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> { - self.expressions[expression] - } - - pub(super) fn symbol_ty(&self, symbol: ScopedSymbolId) -> Type<'db> { - self.symbols[symbol] + self.expressions[&expression] } pub(crate) fn definition_ty(&self, definition: Definition<'db>) -> Type<'db> { @@ -46,69 +89,134 @@ impl<'db> TypeInference<'db> { fn shrink_to_fit(&mut self) { self.expressions.shrink_to_fit(); - self.symbols.shrink_to_fit(); self.definitions.shrink_to_fit(); } } -/// Builder to infer all types in a [`ScopeId`]. -pub(super) struct TypeInferenceBuilder<'db> { +/// Builder to infer all types in a region. +struct TypeInferenceBuilder<'db> { db: &'db dyn Db, + index: &'db SemanticIndex<'db>, + region: InferenceRegion<'db>, // Cached lookups - index: &'db SemanticIndex<'db>, - file_scope_id: FileScopeId, - file_id: File, - symbol_table: Arc>, + file: File, + scope: ScopeId<'db>, /// The type inference results types: TypeInference<'db>, } impl<'db> TypeInferenceBuilder<'db> { - /// Creates a new builder for inferring the types of `scope`. + /// Creates a new builder for inferring types in a region. pub(super) fn new( db: &'db dyn Db, - scope: ScopeId<'db>, + region: InferenceRegion<'db>, index: &'db SemanticIndex<'db>, ) -> Self { - let file_scope_id = scope.file_scope_id(db); - let file = scope.file(db); - let symbol_table = index.symbol_table(file_scope_id); + let (file, scope) = match region { + InferenceRegion::Expression(expression) => (expression.file(db), expression.scope(db)), + InferenceRegion::Definition(definition) => (definition.file(db), definition.scope(db)), + InferenceRegion::Scope(scope) => (scope.file(db), scope), + }; Self { + db, index, - file_scope_id, - file_id: file, - symbol_table, + region, + + file, + scope, - db, types: TypeInference::default(), } } - /// Infers the types of a `module`. - pub(super) fn infer_module(&mut self, module: &ast::ModModule) { + fn extend(&mut self, inference: &TypeInference<'db>) { + self.types.definitions.extend(inference.definitions.iter()); + self.types.expressions.extend(inference.expressions.iter()); + } + + /// Infers types in the given [`InferenceRegion`]. + fn infer_region(&mut self) { + match self.region { + InferenceRegion::Scope(scope) => self.infer_region_scope(scope), + InferenceRegion::Definition(definition) => self.infer_region_definition(definition), + InferenceRegion::Expression(expression) => self.infer_region_expression(expression), + } + } + + fn infer_region_scope(&mut self, scope: ScopeId<'db>) { + let node = scope.node(self.db); + match node { + NodeWithScopeKind::Module => { + let parsed = parsed_module(self.db.upcast(), self.file); + self.infer_module(parsed.syntax()); + } + NodeWithScopeKind::Function(function) => self.infer_function_body(function.node()), + NodeWithScopeKind::Class(class) => self.infer_class_body(class.node()), + NodeWithScopeKind::ClassTypeParameters(class) => { + self.infer_class_type_params(class.node()); + } + NodeWithScopeKind::FunctionTypeParameters(function) => { + self.infer_function_type_params(function.node()); + } + } + } + + fn infer_region_definition(&mut self, definition: Definition<'db>) { + match definition.node(self.db) { + DefinitionKind::Function(function) => { + self.infer_function_definition(function.node(), definition); + } + DefinitionKind::Class(class) => self.infer_class_definition(class.node(), definition), + DefinitionKind::Import(import) => { + self.infer_import_definition(import.node(), definition); + } + DefinitionKind::ImportFrom(import_from) => { + self.infer_import_from_definition( + import_from.import(), + import_from.alias(), + definition, + ); + } + DefinitionKind::Assignment(assignment) => { + self.infer_assignment_definition(assignment.assignment(), definition); + } + DefinitionKind::AnnotatedAssignment(annotated_assignment) => { + self.infer_annotated_assignment_definition(annotated_assignment.node(), definition); + } + DefinitionKind::NamedExpression(named_expression) => { + self.infer_named_expression_definition(named_expression.node(), definition); + } + } + } + + fn infer_region_expression(&mut self, expression: Expression<'db>) { + self.infer_expression(expression.node(self.db)); + } + + fn infer_module(&mut self, module: &ast::ModModule) { self.infer_body(&module.body); } - pub(super) fn infer_class_type_params(&mut self, class: &ast::StmtClassDef) { + fn infer_class_type_params(&mut self, class: &ast::StmtClassDef) { if let Some(type_params) = class.type_params.as_deref() { self.infer_type_parameters(type_params); } } - pub(super) fn infer_class_body(&mut self, class: &ast::StmtClassDef) { + fn infer_class_body(&mut self, class: &ast::StmtClassDef) { self.infer_body(&class.body); } - pub(super) fn infer_function_type_params(&mut self, function: &ast::StmtFunctionDef) { + fn infer_function_type_params(&mut self, function: &ast::StmtFunctionDef) { if let Some(type_params) = function.type_params.as_deref() { self.infer_type_parameters(type_params); } } - pub(super) fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) { + fn infer_function_body(&mut self, function: &ast::StmtFunctionDef) { self.infer_body(&function.body); } @@ -139,6 +247,16 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_function_definition_statement(&mut self, function: &ast::StmtFunctionDef) { + let definition = self.index.definition(function); + let result = infer_definition_types(self.db, definition); + self.extend(result); + } + + fn infer_function_definition( + &mut self, + function: &ast::StmtFunctionDef, + definition: Definition<'db>, + ) { let ast::StmtFunctionDef { range: _, is_async: _, @@ -164,11 +282,16 @@ impl<'db> TypeInferenceBuilder<'db> { let function_ty = Type::Function(FunctionType::new(self.db, name.id.clone(), decorator_tys)); - let definition = self.index.definition(function); self.types.definitions.insert(definition, function_ty); } fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) { + let definition = self.index.definition(class); + let result = infer_definition_types(self.db, definition); + self.extend(result); + } + + fn infer_class_definition(&mut self, class: &ast::StmtClassDef, definition: Definition<'db>) { let ast::StmtClassDef { range: _, name, @@ -190,11 +313,10 @@ impl<'db> TypeInferenceBuilder<'db> { let body_scope = self .index .node_scope(NodeWithScopeRef::Class(class)) - .to_scope_id(self.db, self.file_id); + .to_scope_id(self.db, self.file); let class_ty = Type::Class(ClassType::new(self.db, name.id.clone(), bases, body_scope)); - let definition = self.index.definition(class); self.types.definitions.insert(definition, class_ty); } @@ -228,22 +350,46 @@ impl<'db> TypeInferenceBuilder<'db> { let ast::StmtAssign { range: _, targets, - value, + value: _, } = assignment; - let value_ty = self.infer_expression(value); - for target in targets { - self.infer_expression(target); - - self.types.definitions.insert( - self.index.definition(DefinitionNodeRef::Target(target)), - value_ty, - ); + match target { + ast::Expr::Name(name) => { + let definition = self.index.definition(name); + let result = infer_definition_types(self.db, definition); + self.extend(result); + } + _ => todo!("support unpacking assignment"), + } } } + fn infer_assignment_definition( + &mut self, + assignment: &ast::StmtAssign, + definition: Definition<'db>, + ) { + let expression = self.index.expression(assignment.value.as_ref()); + let result = infer_expression_types(self.db, expression); + self.extend(result); + let value_ty = self + .types + .expression_ty(assignment.value.scoped_ast_id(self.db, self.scope)); + self.types.definitions.insert(definition, value_ty); + } + fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) { + let definition = self.index.definition(assignment); + let result = infer_definition_types(self.db, definition); + self.extend(result); + } + + fn infer_annotated_assignment_definition( + &mut self, + assignment: &ast::StmtAnnAssign, + definition: Definition<'db>, + ) { let ast::StmtAnnAssign { range: _, target, @@ -257,12 +403,10 @@ impl<'db> TypeInferenceBuilder<'db> { } let annotation_ty = self.infer_expression(annotation); + self.infer_expression(target); - self.types.definitions.insert( - self.index.definition(DefinitionNodeRef::Target(target)), - annotation_ty, - ); + self.types.definitions.insert(definition, annotation_ty); } fn infer_for_statement(&mut self, for_statement: &ast::StmtFor) { @@ -285,54 +429,66 @@ impl<'db> TypeInferenceBuilder<'db> { let ast::StmtImport { range: _, names } = import; for alias in names { - let ast::Alias { - range: _, - name, - asname: _, - } = alias; - - let module_name = ModuleName::new(&name.id); - let module = module_name.and_then(|name| resolve_module(self.db.upcast(), name)); - let module_ty = module - .map(|module| Type::Module(module.file())) - .unwrap_or(Type::Unknown); - let definition = self.index.definition(alias); - - self.types.definitions.insert(definition, module_ty); + let result = infer_definition_types(self.db, definition); + self.extend(result); } } + fn infer_import_definition(&mut self, alias: &ast::Alias, definition: Definition<'db>) { + let ast::Alias { + range: _, + name, + asname: _, + } = alias; + + let module_ty = self.module_ty_from_name(name); + self.types.definitions.insert(definition, module_ty); + } + fn infer_import_from_statement(&mut self, import: &ast::StmtImportFrom) { let ast::StmtImportFrom { range: _, - module, + module: _, names, level: _, } = import; - let module_name = ModuleName::new(module.as_deref().expect("Support relative imports")); + for alias in names { + let definition = self.index.definition(alias); + let result = infer_definition_types(self.db, definition); + self.extend(result); + } + } - let module = - module_name.and_then(|module_name| resolve_module(self.db.upcast(), module_name)); - let module_ty = module - .map(|module| Type::Module(module.file())) - .unwrap_or(Type::Unknown); + fn infer_import_from_definition( + &mut self, + import_from: &ast::StmtImportFrom, + alias: &ast::Alias, + definition: Definition<'db>, + ) { + let ast::StmtImportFrom { module, .. } = import_from; + let module_ty = + self.module_ty_from_name(module.as_ref().expect("Support relative imports")); + + let ast::Alias { + range: _, + name, + asname: _, + } = alias; - for alias in names { - let ast::Alias { - range: _, - name, - asname: _, - } = alias; + let ty = module_ty.member(self.db, &Name::new(&name.id)); - let ty = module_ty - .member(self.db, &Name::new(&name.id)) - .unwrap_or(Type::Unknown); + self.types.definitions.insert(definition, ty); + } - let definition = self.index.definition(alias); - self.types.definitions.insert(definition, ty); - } + fn module_ty_from_name(&self, name: &ast::Identifier) -> Type<'db> { + let module_name = ModuleName::new(&name.id); + let module = + module_name.and_then(|module_name| resolve_module(self.db.upcast(), module_name)); + module + .map(|module| Type::Module(module.file())) + .unwrap_or(Type::Unbound) } fn infer_decorator(&mut self, decorator: &ast::Decorator) -> Type<'db> { @@ -378,7 +534,8 @@ impl<'db> TypeInferenceBuilder<'db> { _ => todo!("expression type resolution for {:?}", expression), }; - self.types.expressions.push(ty); + let expr_id = expression.scoped_ast_id(self.db, self.scope); + self.types.expressions.insert(expr_id, ty); ty } @@ -398,6 +555,17 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type<'db> { + let definition = self.index.definition(named); + let result = infer_definition_types(self.db, definition); + self.extend(result); + result.definition_ty(definition) + } + + fn infer_named_expression_definition( + &mut self, + named: &ast::ExprNamed, + definition: Definition<'db>, + ) -> Type<'db> { let ast::ExprNamed { range: _, target, @@ -407,9 +575,7 @@ impl<'db> TypeInferenceBuilder<'db> { let value_ty = self.infer_expression(value); self.infer_expression(target); - self.types - .definitions - .insert(self.index.definition(named), value_ty); + self.types.definitions.insert(definition, value_ty); value_ty } @@ -437,46 +603,21 @@ impl<'db> TypeInferenceBuilder<'db> { } fn infer_name_expression(&mut self, name: &ast::ExprName) -> Type<'db> { - let ast::ExprName { range: _, id, ctx } = name; + let ast::ExprName { + range: _, + id: _, + ctx, + } = name; match ctx { ExprContext::Load => { - let ancestors = self.index.ancestor_scopes(self.file_scope_id); - - for (ancestor_id, _) in ancestors { - // TODO: Skip over class scopes unless the they are a immediately-nested type param scope. - // TODO: Support built-ins - - let (symbol_table, ancestor_scope) = if ancestor_id == self.file_scope_id { - (Cow::Borrowed(&self.symbol_table), None) - } else { - let ancestor_scope = ancestor_id.to_scope_id(self.db, self.file_id); - ( - Cow::Owned(symbol_table(self.db, ancestor_scope)), - Some(ancestor_scope), - ) - }; - - if let Some(symbol_id) = symbol_table.symbol_id_by_name(id) { - let symbol = symbol_table.symbol(symbol_id); - - if !symbol.is_defined() { - continue; - } - - return if let Some(ancestor_scope) = ancestor_scope { - let types = infer_types(self.db, ancestor_scope); - types.symbol_ty(symbol_id) - } else { - self.local_definition_ty(symbol_id) - }; - } - } - Type::Unknown + let use_def = use_def_map(self.db, self.scope); + let use_id = name.scoped_use_id(self.db, self.scope); + let definitions = use_def.use_definitions(use_id); + definitions_ty(self.db, definitions, use_def.use_may_be_unbound(use_id)) } - ExprContext::Del => Type::None, + ExprContext::Store | ExprContext::Del => Type::None, ExprContext::Invalid => Type::Unknown, - ExprContext::Store => Type::None, } } @@ -489,9 +630,7 @@ impl<'db> TypeInferenceBuilder<'db> { } = attribute; let value_ty = self.infer_expression(value); - let member_ty = value_ty - .member(self.db, &Name::new(&attr.id)) - .unwrap_or(Type::Unknown); + let member_ty = value_ty.member(self.db, &Name::new(&attr.id)); match ctx { ExprContext::Load => member_ty, @@ -558,42 +697,10 @@ impl<'db> TypeInferenceBuilder<'db> { } pub(super) fn finish(mut self) -> TypeInference<'db> { - let symbol_tys: IndexVec<_, _> = self - .index - .symbol_table(self.file_scope_id) - .symbol_ids() - .map(|symbol| self.local_definition_ty(symbol)) - .collect(); - - self.types.symbols = symbol_tys; + self.infer_region(); self.types.shrink_to_fit(); self.types } - - fn local_definition_ty(&mut self, symbol: ScopedSymbolId) -> Type<'db> { - let symbol = self.symbol_table.symbol(symbol); - let mut definitions = symbol - .definitions() - .iter() - .filter_map(|definition| self.types.definitions.get(definition).copied()); - - let Some(first) = definitions.next() else { - return Type::Unbound; - }; - - if let Some(second) = definitions.next() { - let mut builder = UnionTypeBuilder::new(self.db); - builder = builder.add(first).add(second); - - for variant in definitions { - builder = builder.add(variant); - } - - Type::Union(builder.build()) - } else { - first - } - } } #[cfg(test)] @@ -601,12 +708,20 @@ mod tests { use red_knot_module_resolver::{ set_module_resolution_settings, RawModuleResolutionSettings, TargetVersion, }; - use ruff_db::files::system_path_to_file; + use ruff_db::files::{system_path_to_file, File}; + use ruff_db::parsed::parsed_module; use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; + use ruff_db::testing::{assert_function_query_was_not_run, assert_function_query_was_run}; use ruff_python_ast::name::Name; use crate::db::tests::TestDb; - use crate::types::{public_symbol_ty_by_name, Type}; + use crate::semantic_index::definition::Definition; + use crate::types::{ + infer_definition_types, module_global_scope, module_global_symbol_ty_by_name, symbol_table, + use_def_map, Type, + }; + use crate::{HasTy, SemanticModel}; + use textwrap::dedent; fn setup_db() -> TestDb { let mut db = TestDb::new(); @@ -628,10 +743,17 @@ mod tests { fn assert_public_ty(db: &TestDb, file_name: &str, symbol_name: &str, expected: &str) { let file = system_path_to_file(db, file_name).expect("Expected file to exist."); - let ty = public_symbol_ty_by_name(db, file, symbol_name).unwrap_or(Type::Unknown); + let ty = module_global_symbol_ty_by_name(db, file, symbol_name); assert_eq!(ty.display(db).to_string(), expected); } + impl TestDb { + fn write_dedented(&mut self, path: &str, content: &str) -> anyhow::Result<()> { + self.write_file(path, dedent(content))?; + Ok(()) + } + } + #[test] fn follow_import_to_class() -> anyhow::Result<()> { let mut db = setup_db(); @@ -650,18 +772,19 @@ mod tests { fn resolve_base_class_by_name() -> anyhow::Result<()> { let mut db = setup_db(); - db.write_file( + db.write_dedented( "src/mod.py", - r#" -class Base: - pass + " + class Base: + pass -class Sub(Base): - pass"#, + class Sub(Base): + pass + ", )?; let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist."); - let ty = public_symbol_ty_by_name(&db, mod_file, "Sub").expect("Symbol type to exist"); + let ty = module_global_symbol_ty_by_name(&db, mod_file, "Sub"); let Type::Class(class) = ty else { panic!("Sub is not a Class") @@ -682,16 +805,16 @@ class Sub(Base): fn resolve_method() -> anyhow::Result<()> { let mut db = setup_db(); - db.write_file( + db.write_dedented( "src/mod.py", " -class C: - def f(self): pass + class C: + def f(self): pass ", )?; let mod_file = system_path_to_file(&db, "src/mod.py").unwrap(); - let ty = public_symbol_ty_by_name(&db, mod_file, "C").unwrap(); + let ty = module_global_symbol_ty_by_name(&db, mod_file, "C"); let Type::Class(class_id) = ty else { panic!("C is not a Class"); @@ -699,7 +822,7 @@ class C: let member_ty = class_id.class_member(&db, &Name::new_static("f")); - let Some(Type::Function(func)) = member_ty else { + let Type::Function(func) = member_ty else { panic!("C.f is not a Function"); }; @@ -737,13 +860,13 @@ class C: fn resolve_union() -> anyhow::Result<()> { let mut db = setup_db(); - db.write_file( + db.write_dedented( "src/a.py", " -if flag: - x = 1 -else: - x = 2 + if flag: + x = 1 + else: + x = 2 ", )?; @@ -756,14 +879,14 @@ else: fn literal_int_arithmetic() -> anyhow::Result<()> { let mut db = setup_db(); - db.write_file( + db.write_dedented( "src/a.py", " -a = 2 + 1 -b = a - 4 -c = a * b -d = c / 3 -e = 5 % 3 + a = 2 + 1 + b = a - 4 + c = a * b + d = c / 3 + e = 5 % 3 ", )?; @@ -803,13 +926,14 @@ e = 5 % 3 fn ifexpr_walrus() -> anyhow::Result<()> { let mut db = setup_db(); - db.write_file( + db.write_dedented( "src/a.py", " -y = z = 0 -x = (y := 1) if flag else (z := 2) -a = y -b = z + y = 0 + z = 0 + x = (y := 1) if flag else (z := 2) + a = y + b = z ", )?; @@ -831,6 +955,19 @@ b = z Ok(()) } + #[test] + fn multi_target_assign() -> anyhow::Result<()> { + let db = setup_db(); + + db.memory_file_system() + .write_file("src/a.py", "x = y = 1")?; + + assert_public_ty(&db, "src/a.py", "x", "Literal[1]"); + assert_public_ty(&db, "src/a.py", "y", "Literal[1]"); + + Ok(()) + } + #[test] fn none() -> anyhow::Result<()> { let mut db = setup_db(); @@ -840,4 +977,263 @@ b = z assert_public_ty(&db, "src/a.py", "x", "Literal[1] | None"); Ok(()) } + + #[test] + fn simple_if() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + y = 1 + y = 2 + if flag: + y = 3 + x = y + ", + )?; + + assert_public_ty(&db, "src/a.py", "x", "Literal[2, 3]"); + Ok(()) + } + + #[test] + fn maybe_unbound() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + if flag: + y = 3 + x = y + ", + )?; + + assert_public_ty(&db, "src/a.py", "x", "Literal[3] | Unbound"); + Ok(()) + } + + #[test] + fn if_elif_else() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + y = 1 + y = 2 + if flag: + y = 3 + elif flag2: + y = 4 + else: + r = y + y = 5 + s = y + x = y + ", + )?; + + assert_public_ty(&db, "src/a.py", "x", "Literal[3, 4, 5]"); + assert_public_ty(&db, "src/a.py", "r", "Literal[2] | Unbound"); + assert_public_ty(&db, "src/a.py", "s", "Literal[5] | Unbound"); + Ok(()) + } + + #[test] + fn if_elif() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + y = 1 + y = 2 + if flag: + y = 3 + elif flag2: + y = 4 + x = y + ", + )?; + + assert_public_ty(&db, "src/a.py", "x", "Literal[2, 3, 4]"); + Ok(()) + } + + #[test] + fn import_cycle() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + class A: pass + import b + class C(b.B): pass + ", + )?; + db.write_dedented( + "src/b.py", + " + from a import A + class B(A): pass + ", + )?; + + let a = system_path_to_file(&db, "src/a.py").expect("Expected file to exist."); + let c_ty = module_global_symbol_ty_by_name(&db, a, "C"); + let Type::Class(c_class) = c_ty else { + panic!("C is not a Class") + }; + let c_bases = c_class.bases(&db); + let b_ty = c_bases.first().unwrap(); + let Type::Class(b_class) = b_ty else { + panic!("B is not a Class") + }; + assert_eq!(b_class.name(&db), "B"); + let b_bases = b_class.bases(&db); + let a_ty = b_bases.first().unwrap(); + let Type::Class(a_class) = a_ty else { + panic!("A is not a Class") + }; + assert_eq!(a_class.name(&db), "A"); + + Ok(()) + } + + #[test] + fn local_inference() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_file("/src/a.py", "x = 10")?; + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + + let parsed = parsed_module(&db, a); + + let statement = parsed.suite().first().unwrap().as_assign_stmt().unwrap(); + let model = SemanticModel::new(&db, a); + + let literal_ty = statement.value.ty(&model); + + assert_eq!(format!("{}", literal_ty.display(&db)), "Literal[10]"); + + Ok(()) + } + + fn first_public_def<'db>(db: &'db TestDb, file: File, name: &str) -> Definition<'db> { + let scope = module_global_scope(db, file); + *use_def_map(db, scope) + .public_definitions(symbol_table(db, scope).symbol_id_by_name(name).unwrap()) + .first() + .unwrap() + } + + #[test] + fn dependency_public_symbol_type_change() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_files([ + ("/src/a.py", "from foo import x"), + ("/src/foo.py", "x = 10\ndef foo(): ..."), + ])?; + + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + let x_ty = module_global_symbol_ty_by_name(&db, a, "x"); + + assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); + + // Change `x` to a different value + db.write_file("/src/foo.py", "x = 20\ndef foo(): ...")?; + + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + + db.clear_salsa_events(); + let x_ty_2 = module_global_symbol_ty_by_name(&db, a, "x"); + + assert_eq!(x_ty_2.display(&db).to_string(), "Literal[20]"); + + let events = db.take_salsa_events(); + + assert_function_query_was_run::( + &db, + |ty| &ty.function, + &first_public_def(&db, a, "x"), + &events, + ); + + Ok(()) + } + + #[test] + fn dependency_internal_symbol_change() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_files([ + ("/src/a.py", "from foo import x"), + ("/src/foo.py", "x = 10\ndef foo(): y = 1"), + ])?; + + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + let x_ty = module_global_symbol_ty_by_name(&db, a, "x"); + + assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); + + db.write_file("/src/foo.py", "x = 10\ndef foo(): pass")?; + + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + + db.clear_salsa_events(); + + let x_ty_2 = module_global_symbol_ty_by_name(&db, a, "x"); + + assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]"); + + let events = db.take_salsa_events(); + + assert_function_query_was_not_run::( + &db, + |ty| &ty.function, + &first_public_def(&db, a, "x"), + &events, + ); + + Ok(()) + } + + #[test] + fn dependency_unrelated_symbol() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_files([ + ("/src/a.py", "from foo import x"), + ("/src/foo.py", "x = 10\ny = 20"), + ])?; + + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + let x_ty = module_global_symbol_ty_by_name(&db, a, "x"); + + assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); + + db.write_file("/src/foo.py", "x = 10\ny = 30")?; + + let a = system_path_to_file(&db, "/src/a.py").unwrap(); + + db.clear_salsa_events(); + + let x_ty_2 = module_global_symbol_ty_by_name(&db, a, "x"); + + assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]"); + + let events = db.take_salsa_events(); + + assert_function_query_was_not_run::( + &db, + |ty| &ty.function, + &first_public_def(&db, a, "x"), + &events, + ); + Ok(()) + } } diff --git a/crates/ruff_db/src/system/test.rs b/crates/ruff_db/src/system/test.rs index 38c5dad7ce8dc..0a4c2d9a2bd05 100644 --- a/crates/ruff_db/src/system/test.rs +++ b/crates/ruff_db/src/system/test.rs @@ -129,7 +129,7 @@ pub trait DbWithTestSystem: Db + Sized { result } - /// Writes the content of the given file and notifies the Db about the change. + /// Writes the content of the given files and notifies the Db about the change. /// /// # Panics /// If the system isn't using the memory file system for testing. diff --git a/crates/ruff_index/src/vec.rs b/crates/ruff_index/src/vec.rs index 795f8315d4639..184cf0ec89922 100644 --- a/crates/ruff_index/src/vec.rs +++ b/crates/ruff_index/src/vec.rs @@ -74,6 +74,14 @@ impl IndexVec { pub fn shrink_to_fit(&mut self) { self.raw.shrink_to_fit(); } + + #[inline] + pub fn resize(&mut self, new_len: usize, value: T) + where + T: Clone, + { + self.raw.resize(new_len, value); + } } impl Debug for IndexVec