Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the alignment checks to match rust-lang/reference#1387 #113343

Merged
merged 1 commit into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 54 additions & 51 deletions compiler/rustc_mir_transform/src/check_alignment.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use crate::MirPass;
use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_index::IndexVec;
use rustc_middle::mir::*;
use rustc_middle::mir::{
interpret::Scalar,
visit::{PlaceContext, Visitor},
visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor},
};
use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut};
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt, TypeAndMut};
use rustc_session::Session;

pub struct CheckAlignment;
Expand All @@ -30,30 +29,32 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {

let basic_blocks = body.basic_blocks.as_mut();
let local_decls = &mut body.local_decls;
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());

// This pass inserts new blocks. Each insertion changes the Location for all
// statements/blocks after. Iterating or visiting the MIR in order would require updating
// our current location after every insertion. By iterating backwards, we dodge this issue:
// The only Locations that an insertion changes have already been handled.
for block in (0..basic_blocks.len()).rev() {
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
let block = block.into();
for statement_index in (0..basic_blocks[block].statements.len()).rev() {
let location = Location { block, statement_index };
let statement = &basic_blocks[block].statements[statement_index];
let source_info = statement.source_info;

let mut finder = PointerFinder {
local_decls,
tcx,
pointers: Vec::new(),
def_id: body.source.def_id(),
};
for (pointer, pointee_ty) in finder.find_pointers(statement) {
debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty);
let mut finder =
PointerFinder { tcx, local_decls, param_env, pointers: Vec::new() };
finder.visit_statement(statement, location);

for (local, ty) in finder.pointers {
debug!("Inserting alignment check for {:?}", ty);
let new_block = split_block(basic_blocks, location);
insert_alignment_check(
tcx,
local_decls,
&mut basic_blocks[block],
pointer,
pointee_ty,
local,
ty,
source_info,
new_block,
);
Expand All @@ -63,69 +64,71 @@ impl<'tcx> MirPass<'tcx> for CheckAlignment {
}
}

impl<'tcx, 'a> PointerFinder<'tcx, 'a> {
fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> {
self.pointers.clear();
self.visit_statement(statement, Location::START);
core::mem::take(&mut self.pointers)
}
}

struct PointerFinder<'tcx, 'a> {
local_decls: &'a mut LocalDecls<'tcx>,
tcx: TyCtxt<'tcx>,
def_id: DefId,
local_decls: &'a mut LocalDecls<'tcx>,
param_env: ParamEnv<'tcx>,
pointers: Vec<(Place<'tcx>, Ty<'tcx>)>,
}

impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> {
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) {
if let Rvalue::AddressOf(..) = rvalue {
// Ignore dereferences inside of an AddressOf
return;
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
// We want to only check reads and writes to Places, so we specifically exclude
// Borrows and AddressOf.
match context {
saethlin marked this conversation as resolved.
Show resolved Hide resolved
PlaceContext::MutatingUse(
MutatingUseContext::Store
| MutatingUseContext::AsmOutput
| MutatingUseContext::Call
| MutatingUseContext::Yield
| MutatingUseContext::Drop,
) => {}
PlaceContext::NonMutatingUse(
NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
) => {}
_ => {
return;
}
}
self.super_rvalue(rvalue, location);
}

fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
if let PlaceContext::NonUse(_) = context {
return;
}
if !place.is_indirect() {
RalfJung marked this conversation as resolved.
Show resolved Hide resolved
return;
}
saethlin marked this conversation as resolved.
Show resolved Hide resolved

// Since Deref projections must come first and only once, the pointer for an indirect place
// is the Local that the Place is based on.
let pointer = Place::from(place.local);
let pointer_ty = pointer.ty(&*self.local_decls, self.tcx).ty;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like the old code was actually already doing the right thing?

Copy link
Member Author

@saethlin saethlin Oct 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically yes, but the first time around (with no knowledge of MIR) I managed to get this right by accident and I've wanted to clarify why this is right ever since. Hopefully the change now does that.

let pointer_ty = self.local_decls[place.local].ty;

// We only want to check unsafe pointers
// We only want to check places based on unsafe pointers
if !pointer_ty.is_unsafe_ptr() {
trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty);
trace!("Indirect, but not based on an unsafe ptr, not checking {:?}", place);
return;
}

let Some(pointee) = pointer_ty.builtin_deref(true) else {
debug!("Indirect but no builtin deref: {:?}", pointer_ty);
let pointee_ty =
pointer_ty.builtin_deref(true).expect("no builtin_deref for an unsafe pointer").ty;
// Ideally we'd support this in the future, but for now we are limited to sized types.
if !pointee_ty.is_sized(self.tcx, self.param_env) {
debug!("Unsafe pointer, but pointee is not known to be sized: {:?}", pointer_ty);
return;
};
let mut pointee_ty = pointee.ty;
if pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str() {
pointee_ty = pointee_ty.sequence_element_type(self.tcx);
}

if !pointee_ty.is_sized(self.tcx, self.tcx.param_env_reveal_all_normalized(self.def_id)) {
debug!("Unsafe pointer, but unsized: {:?}", pointer_ty);
// Try to detect types we are sure have an alignment of 1 and skip the check
// We don't need to look for str and slices, we already rejected unsized types above
let element_ty = match pointee_ty.kind() {
ty::Array(ty, _) => *ty,
_ => pointee_ty,
};
if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8].contains(&element_ty) {
debug!("Trivially aligned place type: {:?}", pointee_ty);
return;
}

if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8, self.tcx.types.str_]
.contains(&pointee_ty)
{
debug!("Trivially aligned pointee type: {:?}", pointer_ty);
return;
}
// Ensure that this place is based on an aligned pointer.
self.pointers.push((pointer, pointee_ty));

self.pointers.push((pointer, pointee_ty))
self.super_place(place, context, location);
}
}

Expand Down
2 changes: 1 addition & 1 deletion tests/debuginfo/simple-struct.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// min-lldb-version: 310
// ignore-gdb // Test temporarily ignored due to debuginfo tests being disabled, see PR 47155

// compile-flags:-g
// compile-flags: -g -Zmir-enable-passes=-CheckAlignment

// === GDB TESTS ===================================================================================

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// run-pass
// ignore-wasm32-bare: No panic messages
// compile-flags: -C debug-assertions

struct Misalignment {
saethlin marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -9,7 +8,7 @@ struct Misalignment {
fn main() {
let items: [Misalignment; 2] = [Misalignment { a: 0 }, Misalignment { a: 1 }];
unsafe {
let ptr: *const Misalignment = items.as_ptr().cast::<u8>().add(1).cast::<Misalignment>();
let ptr: *const Misalignment = items.as_ptr().byte_add(1);
let _ptr = core::ptr::addr_of!((*ptr).a);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

fn main() {
let mut x = [0u64; 2];
let ptr: *mut u8 = x.as_mut_ptr().cast::<u8>();
let ptr = x.as_mut_ptr();
unsafe {
let misaligned = ptr.add(4).cast::<u64>();
let misaligned = ptr.byte_add(4);
assert!(misaligned.addr() % 8 != 0);
assert!(misaligned.addr() % 4 == 0);
*misaligned = 42;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

fn main() {
let mut x = [0u32; 2];
let ptr: *mut u8 = x.as_mut_ptr().cast::<u8>();
let ptr = x.as_mut_ptr();
unsafe {
*(ptr.add(1).cast::<u32>()) = 42;
*(ptr.byte_add(1)) = 42;
}
}
13 changes: 13 additions & 0 deletions tests/ui/mir/alignment/misaligned_rhs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// run-fail
// ignore-wasm32-bare: No panic messages
// ignore-i686-pc-windows-msvc: #112480
// compile-flags: -C debug-assertions
// error-pattern: misaligned pointer dereference: address must be a multiple of 0x4 but is

fn main() {
let mut x = [0u32; 2];
let ptr = x.as_mut_ptr();
unsafe {
let _v = *(ptr.byte_add(1));
}
}
29 changes: 29 additions & 0 deletions tests/ui/mir/alignment/packed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// run-pass
// compile-flags: -C debug-assertions

#![feature(strict_provenance, pointer_is_aligned)]

#[repr(packed)]
struct Misaligner {
_head: u8,
tail: u64,
}

fn main() {
let memory = [Misaligner { _head: 0, tail: 0}, Misaligner { _head: 0, tail: 0}];
// Test that we can use addr_of! to get the address of a packed member which according to its
// type is not aligned, but because it is a projection from a packed type is a valid place.
let ptr0 = std::ptr::addr_of!(memory[0].tail);
let ptr1 = std::ptr::addr_of!(memory[0].tail);
// Even if ptr0 happens to be aligned by chance, ptr1 is not.
assert!(!ptr0.is_aligned() || !ptr1.is_aligned());

// And also test that we can get the addr of a packed struct then do a member read from it.
unsafe {
let ptr = std::ptr::addr_of!(memory[0]);
let _tail = (*ptr).tail;

let ptr = std::ptr::addr_of!(memory[1]);
let _tail = (*ptr).tail;
}
}
16 changes: 16 additions & 0 deletions tests/ui/mir/alignment/place_computation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// run-pass
// compile-flags: -C debug-assertions

#[repr(align(8))]
struct Misalignment {
a: u8,
}

fn main() {
let mem = 0u64;
let ptr = &mem as *const u64 as *const Misalignment;
unsafe {
let ptr = ptr.byte_add(1);
let _ref: &u8 = &(*ptr).a;
}
}
9 changes: 9 additions & 0 deletions tests/ui/mir/alignment/place_without_read.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// run-pass
// compile-flags: -C debug-assertions

fn main() {
let ptr = 1 as *const u16;
unsafe {
let _ = *ptr;
}
}
15 changes: 15 additions & 0 deletions tests/ui/mir/alignment/two_pointers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// run-fail
// ignore-wasm32-bare: No panic messages
// ignore-i686-pc-windows-msvc: #112480
// compile-flags: -C debug-assertions
// error-pattern: misaligned pointer dereference: address must be a multiple of 0x4 but is

fn main() {
let x = [0u32; 2];
let ptr = x.as_ptr();
let mut dest = 0u32;
let dest_ptr = &mut dest as *mut u32;
unsafe {
*dest_ptr = *(ptr.byte_add(1));
}
}
Loading