Skip to content

Commit

Permalink
[red-knot] definition-level inference
Browse files Browse the repository at this point in the history
  • Loading branch information
carljm committed Jul 11, 2024
1 parent d0298dc commit 524f359
Show file tree
Hide file tree
Showing 14 changed files with 1,092 additions and 712 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 5 additions & 6 deletions crates/red_knot/src/lint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,7 @@ fn lint_bad_override(context: &SemanticLintContext, class: &ast::StmtClassDef) {
return;
};

let Some(typing_override) = semantic.public_symbol(&typing, "override") else {
return;
};

let override_ty = semantic.public_symbol_ty(typing_override);
let override_ty = semantic.root_symbol_ty(&typing, "override");

let Type::Class(class_ty) = class.ty(semantic) else {
return;
Expand All @@ -154,7 +150,10 @@ fn lint_bad_override(context: &SemanticLintContext, class: &ast::StmtClassDef) {

if ty.has_decorator(db, override_ty) {
let method_name = ty.name(db);
if class_ty.inherited_class_member(db, &method_name).is_none() {
if class_ty
.inherited_class_member(db, &method_name)
.is_unbound()
{
// TODO should have a qualname() method to support nested classes
context.push_diagnostic(
format!(
Expand Down
2 changes: 2 additions & 0 deletions crates/red_knot_python_semantic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ bitflags = { workspace = true }
ordermap = { workspace = true }
salsa = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
tracing-tree = { workspace = true }
rustc-hash = { workspace = true }
hashbrown = { workspace = true }

Expand Down
13 changes: 7 additions & 6 deletions crates/red_knot_python_semantic/src/ast_node_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ pub struct AstNodeRef<T> {

#[allow(unsafe_code)]
impl<T> AstNodeRef<T> {
/// Creates a new `AstNodeRef` that reference `node`. The `parsed` is the [`ParsedModule`] to which
/// the `AstNodeRef` belongs.
/// Creates a new `AstNodeRef` that reference `node`. The `parsed` is the [`ParsedModule`] to
/// which the `AstNodeRef` belongs.
///
/// ## Safety
/// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the [`ParsedModule`] to
/// which `node` belongs. It's the caller's responsibility to ensure that the invariant `node belongs to parsed` is upheld.
/// Dereferencing the `node` can result in undefined behavior if `parsed` isn't the
/// [`ParsedModule`] to which `node` belongs. It's the caller's responsibility to ensure that
/// the invariant `node belongs to parsed` is upheld.

pub(super) unsafe fn new(parsed: ParsedModule, node: &T) -> Self {
Self {
Expand All @@ -43,8 +44,8 @@ impl<T> AstNodeRef<T> {

/// Returns a reference to the wrapped node.
pub fn node(&self) -> &T {
// SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still alive
// and not moved.
// SAFETY: Holding on to `parsed` ensures that the AST to which `node` belongs is still
// alive and not moved.
unsafe { self.node.as_ref() }
}
}
Expand Down
17 changes: 10 additions & 7 deletions crates/red_knot_python_semantic/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,30 @@ use red_knot_module_resolver::Db as ResolverDb;
use ruff_db::{Db as SourceDb, Upcast};

use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::{public_symbols_map, PublicSymbolId, ScopeId};
use crate::semantic_index::{root_scope, semantic_index, symbol_table};
use crate::semantic_index::symbol::ScopeId;
use crate::semantic_index::usedef::Expression;
use crate::semantic_index::{root_scope, semantic_index, symbol_table, use_def_map};
use crate::types::{
infer_types, public_symbol_ty, ClassType, FunctionType, IntersectionType, UnionType,
infer_definition_types, infer_expression_types, infer_scope_types, ClassType, FunctionType,
IntersectionType, UnionType,
};

#[salsa::jar(db=Db)]
pub struct Jar(
ScopeId<'_>,
PublicSymbolId<'_>,
Definition<'_>,
Expression<'_>,
FunctionType<'_>,
ClassType<'_>,
UnionType<'_>,
IntersectionType<'_>,
symbol_table,
use_def_map,
root_scope,
semantic_index,
infer_types,
public_symbol_ty,
public_symbols_map,
infer_definition_types,
infer_expression_types,
infer_scope_types,
);

/// Database giving access to semantic information about a Python program.
Expand Down
136 changes: 81 additions & 55 deletions crates/red_knot_python_semantic/src/semantic_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@ use ruff_index::{IndexSlice, IndexVec};
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::ast_ids::AstIds;
use crate::semantic_index::builder::SemanticIndexBuilder;
use crate::semantic_index::definition::{Definition, DefinitionNodeKey, DefinitionNodeRef};
use crate::semantic_index::definition::{Definition, DefinitionNodeKey};
use crate::semantic_index::symbol::{
FileScopeId, NodeWithScopeKey, NodeWithScopeRef, PublicSymbolId, Scope, ScopeId,
ScopedSymbolId, SymbolTable,
FileScopeId, NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopedSymbolId, SymbolTable,
};
use crate::semantic_index::usedef::UseDefMap;
use crate::Db;

pub mod ast_ids;
mod builder;
pub mod definition;
pub mod symbol;
pub mod usedef;

type SymbolMap = hashbrown::HashMap<ScopedSymbolId, (), ()>;

Expand All @@ -42,13 +43,26 @@ pub(crate) fn semantic_index(db: &dyn Db, file: File) -> SemanticIndex<'_> {
/// Salsa can avoid invalidating dependent queries if this scope's symbol table
/// is unchanged.
#[salsa::tracked]
pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc<SymbolTable<'db>> {
pub(crate) fn symbol_table<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc<SymbolTable> {
let _span = tracing::trace_span!("symbol_table", ?scope).entered();
let index = semantic_index(db, scope.file(db));

index.symbol_table(scope.file_scope_id(db))
}

/// Returns the use-def map for a specific `scope`.
///
/// Using [`use_def_map`] over [`semantic_index`] has the advantage that
/// Salsa can avoid invalidating dependent queries if this scope's use-def map
/// is unchanged.
#[salsa::tracked]
pub(crate) fn use_def_map<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Arc<UseDefMap<'db>> {
let _span = tracing::trace_span!("use_def_map", ?scope).entered();
let index = semantic_index(db, scope.file(db));

index.use_def_map(scope.file_scope_id(db))
}

/// Returns the root scope of `file`.
#[salsa::tracked]
pub(crate) fn root_scope(db: &dyn Db, file: File) -> ScopeId<'_> {
Expand All @@ -57,24 +71,11 @@ pub(crate) fn root_scope(db: &dyn Db, file: File) -> ScopeId<'_> {
FileScopeId::root().to_scope_id(db, file)
}

/// Returns the symbol with the given name in `file`'s public scope or `None` if
/// no symbol with the given name exists.
pub(crate) fn public_symbol<'db>(
db: &'db dyn Db,
file: File,
name: &str,
) -> Option<PublicSymbolId<'db>> {
let root_scope = root_scope(db, file);
let symbol_table = symbol_table(db, root_scope);
let local = symbol_table.symbol_id_by_name(name)?;
Some(local.to_public_symbol(db, file))
}

/// The symbol tables for an entire file.
/// The symbol tables and use-def maps for all scopes in a file.
#[derive(Debug)]
pub(crate) struct SemanticIndex<'db> {
/// List of all symbol tables in this file, indexed by scope.
symbol_tables: IndexVec<FileScopeId, Arc<SymbolTable<'db>>>,
symbol_tables: IndexVec<FileScopeId, Arc<SymbolTable>>,

/// List of all scopes in this file.
scopes: IndexVec<FileScopeId, Scope>,
Expand All @@ -84,7 +85,7 @@ pub(crate) struct SemanticIndex<'db> {
/// an [`ast::Expr`] to an [`ExpressionId`] (which requires knowing the scope).
scopes_by_expression: FxHashMap<ExpressionNodeKey, FileScopeId>,

/// Maps from a node creating a definition node to its definition.
/// Maps from a node creating a definition to its definition.
definitions_by_node: FxHashMap<DefinitionNodeKey, Definition<'db>>,

/// Map from nodes that create a scope to the scope they create.
Expand All @@ -93,6 +94,9 @@ pub(crate) struct SemanticIndex<'db> {
/// Map from the file-local [`FileScopeId`] to the salsa-ingredient [`ScopeId`].
scope_ids_by_scope: IndexVec<FileScopeId, ScopeId<'db>>,

/// Use-def map for each scope in this file.
use_def_maps: IndexVec<FileScopeId, Arc<UseDefMap<'db>>>,

/// Lookup table to map between node ids and ast nodes.
///
/// Note: We should not depend on this map when analysing other files or
Expand All @@ -105,10 +109,18 @@ impl<'db> SemanticIndex<'db> {
///
/// Use the Salsa cached [`symbol_table`] query if you only need the
/// symbol table for a single scope.
pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable<'db>> {
pub(super) fn symbol_table(&self, scope_id: FileScopeId) -> Arc<SymbolTable> {
self.symbol_tables[scope_id].clone()
}

/// Returns the use-def map for a specific scope.
///
/// Use the Salsa cached [`use_def_map`] query if you only need the
/// use-def map for a single scope.
pub(super) fn use_def_map(&self, scope_id: FileScopeId) -> Arc<UseDefMap> {
self.use_def_maps[scope_id].clone()
}

pub(crate) fn ast_ids(&self, scope_id: FileScopeId) -> &AstIds {
&self.ast_ids[scope_id]
}
Expand Down Expand Up @@ -157,16 +169,17 @@ impl<'db> SemanticIndex<'db> {
}

/// Returns an iterator over all ancestors of `scope`, starting with `scope` itself.
#[allow(unused)]
pub(crate) fn ancestor_scopes(&self, scope: FileScopeId) -> AncestorsIter {
AncestorsIter::new(self, scope)
}

/// Returns the [`Definition`] salsa ingredient for `definition_node`.
pub(crate) fn definition<'def>(
pub(crate) fn definition(
&self,
definition_node: impl Into<DefinitionNodeRef<'def>>,
definition_key: impl Into<DefinitionNodeKey>,
) -> Definition<'db> {
self.definitions_by_node[&definition_node.into().key()]
self.definitions_by_node[&definition_key.into()]
}

/// Returns the id of the scope that `node` creates. This is different from [`Definition::scope`] which
Expand All @@ -176,8 +189,6 @@ impl<'db> SemanticIndex<'db> {
}
}

/// ID that uniquely identifies an expression inside a [`Scope`].

pub struct AncestorsIter<'a> {
scopes: &'a IndexSlice<FileScopeId, Scope>,
next_id: Option<FileScopeId>,
Expand Down Expand Up @@ -278,7 +289,7 @@ mod tests {

use crate::db::tests::TestDb;
use crate::semantic_index::symbol::{FileScopeId, Scope, ScopeKind, SymbolTable};
use crate::semantic_index::{root_scope, semantic_index, symbol_table};
use crate::semantic_index::{root_scope, semantic_index, symbol_table, use_def_map};
use crate::Db;

struct TestCase {
Expand Down Expand Up @@ -332,12 +343,14 @@ mod tests {
#[test]
fn import() {
let TestCase { db, file } = test_case("import foo");
let root_table = symbol_table(&db, root_scope(&db, file));
let scope = root_scope(&db, file);
let root_table = symbol_table(&db, scope);

assert_eq!(names(&root_table), vec!["foo"]);
let foo = root_table.symbol_by_name("foo").unwrap();
let foo = root_table.symbol_id_by_name("foo").unwrap();

assert_eq!(foo.definitions().len(), 1);
let use_def = use_def_map(&db, scope);
assert_eq!(use_def.public_definitions(foo).len(), 1);
}

#[test]
Expand All @@ -359,41 +372,46 @@ mod tests {
#[test]
fn import_from() {
let TestCase { db, file } = test_case("from bar import foo");
let root_table = symbol_table(&db, root_scope(&db, file));
let scope = root_scope(&db, file);
let root_table = symbol_table(&db, scope);

assert_eq!(names(&root_table), vec!["foo"]);
assert_eq!(
root_table
.symbol_by_name("foo")
.unwrap()
.definitions()
.len(),
1
);
assert!(
root_table
.symbol_by_name("foo")
.is_some_and(|symbol| { symbol.is_defined() || !symbol.is_used() }),
.is_some_and(|symbol| { symbol.is_defined() && !symbol.is_used() }),
"symbols that are defined get the defined flag"
);

let use_def = use_def_map(&db, scope);
assert_eq!(
use_def
.public_definitions(root_table.symbol_id_by_name("foo").expect("symbol exists"))
.len(),
1
);
}

#[test]
fn assign() {
let TestCase { db, file } = test_case("x = foo");
let root_table = symbol_table(&db, root_scope(&db, file));
let scope = root_scope(&db, file);
let root_table = symbol_table(&db, scope);

assert_eq!(names(&root_table), vec!["foo", "x"]);
assert_eq!(
root_table.symbol_by_name("x").unwrap().definitions().len(),
1
);
assert!(
root_table
.symbol_by_name("foo")
.is_some_and(|symbol| { !symbol.is_defined() && symbol.is_used() }),
"a symbol used but not defined in a scope should have only the used flag"
);
let use_def = use_def_map(&db, scope);
assert_eq!(
use_def
.public_definitions(root_table.symbol_id_by_name("x").expect("symbol exists"))
.len(),
1
);
}

#[test]
Expand Down Expand Up @@ -421,8 +439,12 @@ y = 2

let class_table = index.symbol_table(class_scope_id);
assert_eq!(names(&class_table), vec!["x"]);

let use_def = index.use_def_map(class_scope_id);
assert_eq!(
class_table.symbol_by_name("x").unwrap().definitions().len(),
use_def
.public_definitions(class_table.symbol_id_by_name("x").expect("symbol exists"))
.len(),
1
);
}
Expand Down Expand Up @@ -450,11 +472,15 @@ y = 2

let function_table = index.symbol_table(function_scope_id);
assert_eq!(names(&function_table), vec!["x"]);

let use_def = index.use_def_map(function_scope_id);
assert_eq!(
function_table
.symbol_by_name("x")
.unwrap()
.definitions()
use_def
.public_definitions(
function_table
.symbol_id_by_name("x")
.expect("symbol exists")
)
.len(),
1
);
Expand Down Expand Up @@ -490,13 +516,13 @@ def func():
let func2_table = index.symbol_table(func_scope2_id);
assert_eq!(names(&func1_table), vec!["x"]);
assert_eq!(names(&func2_table), vec!["y"]);

let use_def = index.use_def_map(FileScopeId::root());
assert_eq!(
root_table
.symbol_by_name("func")
.unwrap()
.definitions()
use_def
.public_definitions(root_table.symbol_id_by_name("func").expect("symbol exists"))
.len(),
2
1
);
}

Expand Down
Loading

0 comments on commit 524f359

Please sign in to comment.