From 194777738742f69673d72b7cc3815e0748ea054d Mon Sep 17 00:00:00 2001 From: Henno Date: Thu, 18 Mar 2021 20:04:01 -0400 Subject: [PATCH 1/7] Add basic support for struct DSTs --- .../src/codegen_cx/entry.rs | 114 +++++++++++++----- 1 file changed, 84 insertions(+), 30 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 4a428e19a0..ace6300375 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -1,5 +1,4 @@ use super::CodegenCx; -use crate::abi::ConvSpirvType; use crate::builder_spirv::SpirvValue; use crate::spirv_type::SpirvType; use crate::symbols::{parse_attrs, Entry, SpirvAttribute}; @@ -9,7 +8,10 @@ use rustc_hir as hir; use rustc_middle::ty::layout::TyAndLayout; use rustc_middle::ty::{Instance, Ty}; use rustc_span::Span; -use rustc_target::abi::call::{FnAbi, PassMode}; +use rustc_target::abi::{ + call::{ArgAbi, ArgAttribute, ArgAttributes, FnAbi, PassMode}, + Size, +}; use std::collections::HashMap; impl<'tcx> CodegenCx<'tcx> { @@ -36,8 +38,27 @@ impl<'tcx> CodegenCx<'tcx> { }; let fn_hir_id = self.tcx.hir().local_def_id_to_hir_id(local_id); let body = self.tcx.hir().body(self.tcx.hir().body_owned_by(fn_hir_id)); + const EMPTY: ArgAttribute = ArgAttribute::empty(); for (abi, arg) in fn_abi.args.iter().zip(body.params) { if let PassMode::Direct(_) = abi.mode { + } else if let PassMode::Pair( + // plain DST/RTA/VLA + ArgAttributes { + pointee_size: Size::ZERO, + .. + }, + ArgAttributes { regular: EMPTY, .. }, + ) = abi.mode + { + } else if let PassMode::Pair( + // DST struct with fields before the DST member + ArgAttributes { .. }, + ArgAttributes { + pointee_size: Size::ZERO, + .. + }, + ) = abi.mode + { } else { self.tcx.sess.span_err( arg.span, @@ -62,7 +83,7 @@ impl<'tcx> CodegenCx<'tcx> { self.shader_entry_stub( self.tcx.def_span(instance.def_id()), entry_func, - fn_abi, + &fn_abi.args, body.params, name, execution_model, @@ -81,7 +102,7 @@ impl<'tcx> CodegenCx<'tcx> { &self, span: Span, entry_func: SpirvValue, - entry_fn_abi: &FnAbi<'tcx, Ty<'tcx>>, + arg_abis: &[ArgAbi<'tcx, Ty<'tcx>>], hir_params: &[hir::Param<'tcx>], name: String, execution_model: ExecutionModel, @@ -92,11 +113,11 @@ impl<'tcx> CodegenCx<'tcx> { arguments: vec![], } .def(span, self); - let entry_func_return_type = match self.lookup_type(entry_func.ty) { + let (entry_func_return_type, entry_func_arg_types) = match self.lookup_type(entry_func.ty) { SpirvType::Function { return_type, - arguments: _, - } => return_type, + arguments, + } => (return_type, arguments), other => self.tcx.sess.fatal(&format!( "Invalid entry_stub type: {}", other.debug(entry_func.ty, self) @@ -104,40 +125,73 @@ impl<'tcx> CodegenCx<'tcx> { }; let mut decoration_locations = HashMap::new(); // Create OpVariables before OpFunction so they're global instead of local vars. - let arguments = entry_fn_abi - .args - .iter() - .zip(hir_params) - .map(|(entry_fn_arg, hir_param)| { - self.declare_parameter(entry_fn_arg.layout, hir_param, &mut decoration_locations) - }) - .collect::>(); + let new_spirv = self.emit_global().version().unwrap() > (1, 3); + let arg_len = arg_abis.len(); + let mut arguments = Vec::with_capacity(arg_len); + let mut interface = Vec::with_capacity(arg_len); + let mut rta_lens = Vec::with_capacity(arg_len / 2); + let mut arg_types = entry_func_arg_types.iter(); + for (hir_param, arg_abi) in hir_params.iter().zip(arg_abis) { + // explicit next because there are two args for scalar pairs, but only one param & abi + let arg_t = *arg_types.next().unwrap_or_else(|| { + self.tcx.sess.span_fatal( + hir_param.span, + &format!( + "Invalid function arguments: Param {:?} Abi {:?} missing type", + hir_param, arg_abi.layout.ty + ), + ) + }); + let (argument, storage_class) = + self.declare_parameter(arg_abi.layout, hir_param, arg_t, &mut decoration_locations); + // SPIR-V <= v1.3 only includes Input and Output in the interface. + if new_spirv + || storage_class == StorageClass::Input + || storage_class == StorageClass::Output + { + interface.push(argument); + } + arguments.push(argument); + if let SpirvType::Pointer { pointee } = self.lookup_type(arg_t) { + if let SpirvType::Adt { + size: None, + field_types, + .. + } = self.lookup_type(pointee) + { + let len_t = *arg_types.next().unwrap_or_else(|| { + self.tcx.sess.span_fatal( + hir_param.span, + &format!( + "Invalid function arguments: Param {:?} Abi {:?} fat pointer missing length", + hir_param, arg_abi.layout.ty + ), + ) + }); + rta_lens.push((arguments.len() as u32, len_t, field_types.len() as u32 - 1)); + arguments.push(u32::MAX); + } + } + } let mut emit = self.emit_global(); let fn_id = emit .begin_function(void, None, FunctionControl::NONE, fn_void_void) .unwrap(); emit.begin_block(None).unwrap(); + rta_lens.iter().for_each(|&(len_idx, len_t, member_idx)| { + arguments[len_idx as usize] = emit + .array_length(len_t, None, arguments[len_idx as usize - 1], member_idx) + .unwrap() + }); emit.function_call( entry_func_return_type, None, entry_func.def_cx(self), - arguments.iter().map(|&(a, _)| a), + arguments, ) .unwrap(); emit.ret().unwrap(); emit.end_function().unwrap(); - - let interface: Vec<_> = if emit.version().unwrap() > (1, 3) { - // SPIR-V >= v1.4 includes all OpVariables in the interface. - arguments.into_iter().map(|(a, _)| a).collect() - } else { - // SPIR-V <= v1.3 only includes Input and Output in the interface. - arguments - .into_iter() - .filter(|&(_, s)| s == StorageClass::Input || s == StorageClass::Output) - .map(|(a, _)| a) - .collect() - }; emit.entry_point(execution_model, fn_id, name, interface); fn_id } @@ -146,6 +200,7 @@ impl<'tcx> CodegenCx<'tcx> { &self, layout: TyAndLayout<'tcx>, hir_param: &hir::Param<'tcx>, + arg_t: Word, decoration_locations: &mut HashMap, ) -> (Word, StorageClass) { let storage_class = crate::abi::get_storage_class(self, layout).unwrap_or_else(|| { @@ -159,10 +214,9 @@ impl<'tcx> CodegenCx<'tcx> { StorageClass::Input | StorageClass::Output | StorageClass::UniformConstant ); // Note: this *declares* the variable too. - let spirv_type = layout.spirv_type(hir_param.span, self); let variable = self .emit_global() - .variable(spirv_type, None, storage_class, None); + .variable(arg_t, None, storage_class, None); if let hir::PatKind::Binding(_, _, ident, _) = &hir_param.pat.kind { self.emit_global().name(variable, ident.to_string()); } From 7e9728dc4d22072a5d1939ea913b07877e9eb4c5 Mon Sep 17 00:00:00 2001 From: Henno Date: Fri, 19 Mar 2021 15:46:25 -0400 Subject: [PATCH 2/7] Add tests --- crates/spirv-builder/src/test/basic.rs | 135 ++++++++++++++++++++++++- crates/spirv-builder/src/test/mod.rs | 29 ++++++ 2 files changed, 163 insertions(+), 1 deletion(-) diff --git a/crates/spirv-builder/src/test/basic.rs b/crates/spirv-builder/src/test/basic.rs index fd32fa61c3..0744e7545d 100644 --- a/crates/spirv-builder/src/test/basic.rs +++ b/crates/spirv-builder/src/test/basic.rs @@ -1,4 +1,4 @@ -use super::{dis_fn, dis_globals, val, val_vulkan}; +use super::{dis_fn, dis_entry_fn, dis_globals, val, val_vulkan}; use std::ffi::OsStr; struct SetEnvVar<'a> { @@ -479,3 +479,136 @@ fn ptr_copy_from_method() { "# ); } + +#[test] +fn index_user_dst() { + dis_entry_fn( + r#" +fn index_user_dst(slice: &SliceF32, idx: usize) -> f32 { + slice.rta[idx] +} + +#[spirv(fragment)] +pub fn main( + #[spirv(descriptor_set = 0, binding = 0)] slice: Uniform, +) { + index_user_dst(&slice, 0); +} + +pub struct SliceF32 { + rta: [f32], +} + "#, + "main", + r#"%1 = OpFunction %2 None %3 +%4 = OpLabel +%5 = OpVariable %6 Function +%7 = OpArrayLength %8 %9 0 +%10 = OpAccessChain %11 %5 %12 +OpStore %10 %9 +%13 = OpAccessChain %14 %5 %15 +OpStore %13 %7 +%16 = OpAccessChain %11 %5 %12 +%17 = OpLoad %18 %16 +%19 = OpAccessChain %14 %5 %15 +%20 = OpLoad %8 %19 +%21 = OpCompositeInsert %22 %17 %23 0 +%24 = OpCompositeInsert %22 %20 %21 1 +%25 = OpCompositeExtract %18 %24 0 +%26 = OpCompositeExtract %8 %24 1 +%27 = OpCompositeInsert %22 %25 %23 0 +%28 = OpCompositeInsert %22 %26 %27 1 +%29 = OpAccessChain %30 %25 %12 +%31 = OpULessThan %32 %12 %26 +OpSelectionMerge %33 None +OpBranchConditional %31 %34 %35 +%34 = OpLabel +%36 = OpAccessChain %30 %25 %12 +%37 = OpInBoundsAccessChain %38 %36 %12 +%39 = OpLoad %40 %37 +OpReturn +%35 = OpLabel +OpBranch %41 +%41 = OpLabel +OpBranch %42 +%42 = OpLabel +%43 = OpPhi %32 %44 %41 %44 %45 +OpLoopMerge %46 %45 None +OpBranchConditional %43 %47 %46 +%47 = OpLabel +OpBranch %45 +%45 = OpLabel +OpBranch %42 +%46 = OpLabel +OpUnreachable +%33 = OpLabel +OpUnreachable +OpFunctionEnd"#, + ) +} + +#[test] +fn entry_stub_array_length() { + dis_globals( + r#" +#[spirv(compute(threads(32)))] +pub fn main_cs( + #[spirv(descriptor_set = 0, binding = 0)] slice: Uniform>, +) { + let _ = slice.index(0); +} + +pub struct Slice { + rta: [T], +} + +impl Slice { + fn index(&self, idx: usize) -> T { + self.rta[idx] + } +} + "#, + r#"OpCapability Shader +OpCapability VulkanMemoryModel +OpCapability VariablePointers +OpExtension "SPV_KHR_vulkan_memory_model" +OpMemoryModel Logical Vulkan +OpEntryPoint GLCompute %1 "main_cs" +OpExecutionMode %1 LocalSize 32 1 1 +OpMemberName %2 0 "rta" +OpName %2 "Slice" +OpName %3 "&Slice" +OpMemberName %2 0 "rta" +OpName %4 "slice" +OpMemberName %2 0 "rta" +OpMemberDecorate %2 0 Offset 0 +OpMemberDecorate %3 0 Offset 0 +OpMemberDecorate %3 1 Offset 4 +OpDecorate %4 DescriptorSet 0 +OpDecorate %4 Binding 0 +%5 = OpTypeFloat 32 +%6 = OpTypeRuntimeArray %5 +%2 = OpTypeStruct %6 +%7 = OpTypePointer Uniform %2 +%8 = OpTypeInt 32 0 +%3 = OpTypeStruct %7 %8 +%9 = OpTypePointer Function %3 +%10 = OpTypePointer Function %7 +%11 = OpConstant %8 0 +%12 = OpTypePointer Function %8 +%13 = OpConstant %8 1 +%14 = OpTypePointer Uniform %6 +%15 = OpTypeBool +%16 = OpTypeVoid +%17 = OpTypePointer Uniform %5 +%18 = OpTypePointer Function %5 +%19 = OpTypeFunction %16 +%4 = OpVariable %7 Uniform +%20 = OpTypePointer Function %9 +%21 = OpUndef %3 +%22 = OpConstantFalse %15 +%23 = OpConstantTrue %15 +%24 = OpConstantFalse %15 +%25 = OpConstantTrue %15"# + ) +} diff --git a/crates/spirv-builder/src/test/mod.rs b/crates/spirv-builder/src/test/mod.rs index 0dc909ae0a..db00ac6dac 100644 --- a/crates/spirv-builder/src/test/mod.rs +++ b/crates/spirv-builder/src/test/mod.rs @@ -159,6 +159,35 @@ fn dis_fn(src: &str, func: &str, expect: &str) { assert_str_eq(expect, &func.disassemble()) } +fn dis_entry_fn(src: &str, func: &str, expect: &str) { + let _lock = global_lock(); + let module = read_module(&build(src)).unwrap(); + let id = module + .entry_points + .iter() + .find(|inst| { + inst.operands.last().unwrap().unwrap_literal_string() == func + }) + .unwrap_or_else(|| { + panic!( + "no entry point with the name `{}` found in:\n{}\n", + func, + module.disassemble() + ) + }) + .operands[1] + .unwrap_id_ref(); + let mut func = module + .functions + .into_iter() + .find(|f| f.def_id().unwrap() == id) + .unwrap(); + // Compact to make IDs more stable + compact_ids(&mut func); + use rspirv::binary::Disassemble; + assert_str_eq(expect, &func.disassemble()) +} + fn dis_globals(src: &str, expect: &str) { let _lock = global_lock(); let module = read_module(&build(src)).unwrap(); From 57abe134dbf8abde766acc03080be884d7fd00ed Mon Sep 17 00:00:00 2001 From: Henno Date: Fri, 19 Mar 2021 17:00:23 -0400 Subject: [PATCH 3/7] cleanup tests --- crates/spirv-builder/src/test/basic.rs | 119 +++++-------------------- crates/spirv-builder/src/test/mod.rs | 4 +- 2 files changed, 25 insertions(+), 98 deletions(-) diff --git a/crates/spirv-builder/src/test/basic.rs b/crates/spirv-builder/src/test/basic.rs index 0744e7545d..509f504502 100644 --- a/crates/spirv-builder/src/test/basic.rs +++ b/crates/spirv-builder/src/test/basic.rs @@ -1,4 +1,4 @@ -use super::{dis_fn, dis_entry_fn, dis_globals, val, val_vulkan}; +use super::{dis_entry_fn, dis_fn, dis_globals, val, val_vulkan}; use std::ffi::OsStr; struct SetEnvVar<'a> { @@ -484,15 +484,12 @@ fn ptr_copy_from_method() { fn index_user_dst() { dis_entry_fn( r#" -fn index_user_dst(slice: &SliceF32, idx: usize) -> f32 { - slice.rta[idx] -} - #[spirv(fragment)] pub fn main( #[spirv(descriptor_set = 0, binding = 0)] slice: Uniform, ) { - index_user_dst(&slice, 0); + let float: f32 = slice.rta[0]; + let _ = float; } pub struct SliceF32 { @@ -516,99 +513,31 @@ OpStore %13 %7 %24 = OpCompositeInsert %22 %20 %21 1 %25 = OpCompositeExtract %18 %24 0 %26 = OpCompositeExtract %8 %24 1 -%27 = OpCompositeInsert %22 %25 %23 0 -%28 = OpCompositeInsert %22 %26 %27 1 -%29 = OpAccessChain %30 %25 %12 -%31 = OpULessThan %32 %12 %26 -OpSelectionMerge %33 None -OpBranchConditional %31 %34 %35 -%34 = OpLabel -%36 = OpAccessChain %30 %25 %12 -%37 = OpInBoundsAccessChain %38 %36 %12 -%39 = OpLoad %40 %37 +%27 = OpAccessChain %28 %25 %12 +%29 = OpULessThan %30 %12 %26 +OpSelectionMerge %31 None +OpBranchConditional %29 %32 %33 +%32 = OpLabel +%34 = OpAccessChain %28 %25 %12 +%35 = OpInBoundsAccessChain %36 %34 %12 +%37 = OpLoad %38 %35 OpReturn -%35 = OpLabel -OpBranch %41 -%41 = OpLabel -OpBranch %42 -%42 = OpLabel -%43 = OpPhi %32 %44 %41 %44 %45 -OpLoopMerge %46 %45 None -OpBranchConditional %43 %47 %46 -%47 = OpLabel -OpBranch %45 +%33 = OpLabel +OpBranch %39 +%39 = OpLabel +OpBranch %40 +%40 = OpLabel +%41 = OpPhi %30 %42 %39 %42 %43 +OpLoopMerge %44 %43 None +OpBranchConditional %41 %45 %44 %45 = OpLabel -OpBranch %42 -%46 = OpLabel +OpBranch %43 +%43 = OpLabel +OpBranch %40 +%44 = OpLabel OpUnreachable -%33 = OpLabel +%31 = OpLabel OpUnreachable OpFunctionEnd"#, ) } - -#[test] -fn entry_stub_array_length() { - dis_globals( - r#" -#[spirv(compute(threads(32)))] -pub fn main_cs( - #[spirv(descriptor_set = 0, binding = 0)] slice: Uniform>, -) { - let _ = slice.index(0); -} - -pub struct Slice { - rta: [T], -} - -impl Slice { - fn index(&self, idx: usize) -> T { - self.rta[idx] - } -} - "#, - r#"OpCapability Shader -OpCapability VulkanMemoryModel -OpCapability VariablePointers -OpExtension "SPV_KHR_vulkan_memory_model" -OpMemoryModel Logical Vulkan -OpEntryPoint GLCompute %1 "main_cs" -OpExecutionMode %1 LocalSize 32 1 1 -OpMemberName %2 0 "rta" -OpName %2 "Slice" -OpName %3 "&Slice" -OpMemberName %2 0 "rta" -OpName %4 "slice" -OpMemberName %2 0 "rta" -OpMemberDecorate %2 0 Offset 0 -OpMemberDecorate %3 0 Offset 0 -OpMemberDecorate %3 1 Offset 4 -OpDecorate %4 DescriptorSet 0 -OpDecorate %4 Binding 0 -%5 = OpTypeFloat 32 -%6 = OpTypeRuntimeArray %5 -%2 = OpTypeStruct %6 -%7 = OpTypePointer Uniform %2 -%8 = OpTypeInt 32 0 -%3 = OpTypeStruct %7 %8 -%9 = OpTypePointer Function %3 -%10 = OpTypePointer Function %7 -%11 = OpConstant %8 0 -%12 = OpTypePointer Function %8 -%13 = OpConstant %8 1 -%14 = OpTypePointer Uniform %6 -%15 = OpTypeBool -%16 = OpTypeVoid -%17 = OpTypePointer Uniform %5 -%18 = OpTypePointer Function %5 -%19 = OpTypeFunction %16 -%4 = OpVariable %7 Uniform -%20 = OpTypePointer Function %9 -%21 = OpUndef %3 -%22 = OpConstantFalse %15 -%23 = OpConstantTrue %15 -%24 = OpConstantFalse %15 -%25 = OpConstantTrue %15"# - ) -} diff --git a/crates/spirv-builder/src/test/mod.rs b/crates/spirv-builder/src/test/mod.rs index db00ac6dac..ece32c0990 100644 --- a/crates/spirv-builder/src/test/mod.rs +++ b/crates/spirv-builder/src/test/mod.rs @@ -165,9 +165,7 @@ fn dis_entry_fn(src: &str, func: &str, expect: &str) { let id = module .entry_points .iter() - .find(|inst| { - inst.operands.last().unwrap().unwrap_literal_string() == func - }) + .find(|inst| inst.operands.last().unwrap().unwrap_literal_string() == func) .unwrap_or_else(|| { panic!( "no entry point with the name `{}` found in:\n{}\n", From 22cc7750a96f52396937a5282add020474a7b774 Mon Sep 17 00:00:00 2001 From: Henno Date: Tue, 23 Mar 2021 17:55:53 -0400 Subject: [PATCH 4/7] Update with entry changes, address review --- .../src/codegen_cx/entry.rs | 63 +++++++++++++------ crates/spirv-builder/src/test/basic.rs | 63 ++++++++----------- 2 files changed, 70 insertions(+), 56 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 500695ef57..f7c96a2e07 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -5,16 +5,16 @@ use crate::builder_spirv::SpirvValue; use crate::spirv_type::SpirvType; use rspirv::dr::Operand; use rspirv::spirv::{Decoration, ExecutionModel, FunctionControl, StorageClass, Word}; +use rustc_codegen_ssa::traits::BaseTypeMethods; use rustc_hir as hir; -use rustc_middle::ty::layout::TyAndLayout; +use rustc_middle::ty::layout::{HasParamEnv, TyAndLayout}; use rustc_middle::ty::{Instance, Ty, TyKind}; use rustc_span::Span; use rustc_target::abi::{ call::{ArgAbi, ArgAttribute, ArgAttributes, FnAbi, PassMode}, - LayoutOf, - Size, + LayoutOf, Size, }; -use std::collections::HashMap; +use std::{collections::HashMap, iter}; impl<'tcx> CodegenCx<'tcx> { // Entry points declare their "interface" (all uniforms, inputs, outputs, etc.) as parameters. @@ -114,11 +114,8 @@ impl<'tcx> CodegenCx<'tcx> { arguments: vec![], } .def(span, self); - let (entry_func_return_type, entry_func_arg_types) = match self.lookup_type(entry_func.ty) { - SpirvType::Function { - return_type, - arguments, - } => (return_type, arguments), + let entry_func_return_type = match self.lookup_type(entry_func.ty) { + SpirvType::Function { return_type, .. } => return_type, other => self.tcx.sess.fatal(&format!( "Invalid entry_stub type: {}", other.debug(entry_func.ty, self) @@ -126,14 +123,14 @@ impl<'tcx> CodegenCx<'tcx> { }; let mut decoration_locations = HashMap::new(); // Create OpVariables before OpFunction so they're global instead of local vars. - let declared_params = entry_fn_abi - .args + let declared_params = arg_abis .iter() .zip(hir_params) .map(|(entry_fn_arg, hir_param)| { self.declare_parameter(entry_fn_arg.layout, hir_param, &mut decoration_locations) }) .collect::>(); + let len_t = self.type_isize(); let mut emit = self.emit_global(); let fn_id = emit .begin_function(void, None, FunctionControl::NONE, fn_void_void) @@ -142,14 +139,41 @@ impl<'tcx> CodegenCx<'tcx> { // Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s). let arguments: Vec<_> = declared_params .iter() - .zip(&entry_fn_abi.args) + .zip(arg_abis) .zip(hir_params) - .map(|((&(var, storage_class), entry_fn_arg), hir_param)| { + .flat_map(|((&(var, storage_class), entry_fn_arg), hir_param)| { match entry_fn_arg.layout.ty.kind() { - TyKind::Ref(..) => var, - + TyKind::Ref(_, ty, _) if ty.is_sized(self.tcx.at(span), self.param_env()) => iter::once(var).chain(None), + TyKind::Ref(_, ty, _) => { + match ty.kind() { + TyKind::Adt(adt_def, substs) => { + let (member_idx, field_def) = adt_def.all_fields().enumerate().last().unwrap(); + let field_ty = field_def.ty(self.tcx, substs); + if !matches!(field_ty.kind(), TyKind::Slice(..)) { + self.tcx.sess.span_fatal( + hir_param.ty_span, + "DST parameters are currently restricted to a reference to a struct whose last field is a slice.", + ) + } + let len = emit + .array_length(len_t, None, var, member_idx as u32) + .unwrap(); + iter::once(var).chain(Some(len)) + } + TyKind::Slice(..) | TyKind::Str => self.tcx.sess.span_fatal( + hir_param.ty_span, + "Straight slices are not yet supported, wrap the slice in a newtype.", + ), + // TODO: Is this needed? + TyKind::Dynamic(..) => self.tcx.sess.span_fatal( + hir_param.ty_span, + "Trait objects are not supported.", + ), + _ => unreachable!(), + } + } _ => match entry_fn_arg.mode { - PassMode::Indirect { .. } => var, + PassMode::Indirect { .. } => iter::once(var).chain(None), PassMode::Direct(_) => { assert_eq!(storage_class, StorageClass::Input); @@ -158,8 +182,10 @@ impl<'tcx> CodegenCx<'tcx> { let value_spirv_type = entry_fn_arg.layout.spirv_type(hir_param.span, self); - emit.load(value_spirv_type, None, var, None, std::iter::empty()) - .unwrap() + let loaded_var = emit + .load(value_spirv_type, None, var, None, iter::empty()) + .unwrap(); + iter::once(loaded_var).chain(None) } _ => unreachable!(), }, @@ -195,7 +221,6 @@ impl<'tcx> CodegenCx<'tcx> { &self, layout: TyAndLayout<'tcx>, hir_param: &hir::Param<'tcx>, - arg_t: Word, decoration_locations: &mut HashMap, ) -> (Word, StorageClass) { let attrs = AggregatedSpirvAttributes::parse(self, self.tcx.hir().attrs(hir_param.hir_id)); diff --git a/crates/spirv-builder/src/test/basic.rs b/crates/spirv-builder/src/test/basic.rs index d6d1821a80..d0e790abfe 100644 --- a/crates/spirv-builder/src/test/basic.rs +++ b/crates/spirv-builder/src/test/basic.rs @@ -486,7 +486,7 @@ fn index_user_dst() { r#" #[spirv(fragment)] pub fn main( - #[spirv(descriptor_set = 0, binding = 0)] slice: Uniform, + #[spirv(uniform, descriptor_set = 0, binding = 0)] slice: &mut SliceF32, ) { let float: f32 = slice.rta[0]; let _ = float; @@ -499,44 +499,33 @@ pub struct SliceF32 { "main", r#"%1 = OpFunction %2 None %3 %4 = OpLabel -%5 = OpVariable %6 Function -%7 = OpArrayLength %8 %9 0 -%10 = OpAccessChain %11 %5 %12 -OpStore %10 %9 -%13 = OpAccessChain %14 %5 %15 -OpStore %13 %7 -%16 = OpAccessChain %11 %5 %12 -%17 = OpLoad %18 %16 -%19 = OpAccessChain %14 %5 %15 -%20 = OpLoad %8 %19 -%21 = OpCompositeInsert %22 %17 %23 0 -%24 = OpCompositeInsert %22 %20 %21 1 -%25 = OpCompositeExtract %18 %24 0 -%26 = OpCompositeExtract %8 %24 1 -%27 = OpAccessChain %28 %25 %12 -%29 = OpULessThan %30 %12 %26 -OpSelectionMerge %31 None -OpBranchConditional %29 %32 %33 -%32 = OpLabel -%34 = OpAccessChain %28 %25 %12 -%35 = OpInBoundsAccessChain %36 %34 %12 -%37 = OpLoad %38 %35 +%5 = OpArrayLength %6 %7 0 +%8 = OpCompositeInsert %9 %7 %10 0 +%11 = OpCompositeInsert %9 %5 %8 1 +%12 = OpAccessChain %13 %7 %14 +%15 = OpULessThan %16 %14 %5 +OpSelectionMerge %17 None +OpBranchConditional %15 %18 %19 +%18 = OpLabel +%20 = OpAccessChain %13 %7 %14 +%21 = OpInBoundsAccessChain %22 %20 %14 +%23 = OpLoad %24 %21 OpReturn -%33 = OpLabel -OpBranch %39 -%39 = OpLabel -OpBranch %40 -%40 = OpLabel -%41 = OpPhi %30 %42 %39 %42 %43 -OpLoopMerge %44 %43 None -OpBranchConditional %41 %45 %44 -%45 = OpLabel -OpBranch %43 -%43 = OpLabel -OpBranch %40 -%44 = OpLabel -OpUnreachable +%19 = OpLabel +OpBranch %25 +%25 = OpLabel +OpBranch %26 +%26 = OpLabel +%27 = OpPhi %16 %28 %25 %28 %29 +OpLoopMerge %30 %29 None +OpBranchConditional %27 %31 %30 %31 = OpLabel +OpBranch %29 +%29 = OpLabel +OpBranch %26 +%30 = OpLabel +OpUnreachable +%17 = OpLabel OpUnreachable OpFunctionEnd"#, ) From 367ac4da05ee7d630a4fef771cf25bd9c8ccf6ff Mon Sep 17 00:00:00 2001 From: Henno Date: Wed, 24 Mar 2021 13:34:54 -0400 Subject: [PATCH 5/7] Address review --- .../src/codegen_cx/entry.rs | 79 +++++++++++-------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index f7c96a2e07..ea29876032 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -14,7 +14,7 @@ use rustc_target::abi::{ call::{ArgAbi, ArgAttribute, ArgAttributes, FnAbi, PassMode}, LayoutOf, Size, }; -use std::{collections::HashMap, iter}; +use std::collections::HashMap; impl<'tcx> CodegenCx<'tcx> { // Entry points declare their "interface" (all uniforms, inputs, outputs, etc.) as parameters. @@ -142,38 +142,18 @@ impl<'tcx> CodegenCx<'tcx> { .zip(arg_abis) .zip(hir_params) .flat_map(|((&(var, storage_class), entry_fn_arg), hir_param)| { - match entry_fn_arg.layout.ty.kind() { - TyKind::Ref(_, ty, _) if ty.is_sized(self.tcx.at(span), self.param_env()) => iter::once(var).chain(None), + let mut dst_len_arg = None; + let arg = match entry_fn_arg.layout.ty.kind() { TyKind::Ref(_, ty, _) => { - match ty.kind() { - TyKind::Adt(adt_def, substs) => { - let (member_idx, field_def) = adt_def.all_fields().enumerate().last().unwrap(); - let field_ty = field_def.ty(self.tcx, substs); - if !matches!(field_ty.kind(), TyKind::Slice(..)) { - self.tcx.sess.span_fatal( - hir_param.ty_span, - "DST parameters are currently restricted to a reference to a struct whose last field is a slice.", - ) - } - let len = emit - .array_length(len_t, None, var, member_idx as u32) - .unwrap(); - iter::once(var).chain(Some(len)) - } - TyKind::Slice(..) | TyKind::Str => self.tcx.sess.span_fatal( - hir_param.ty_span, - "Straight slices are not yet supported, wrap the slice in a newtype.", - ), - // TODO: Is this needed? - TyKind::Dynamic(..) => self.tcx.sess.span_fatal( - hir_param.ty_span, - "Trait objects are not supported.", - ), - _ => unreachable!(), + if !ty.is_sized(self.tcx.at(span), self.param_env()) { + dst_len_arg.replace( + self.dst_length_argument(&mut emit, ty, hir_param, len_t, var), + ); } + var } _ => match entry_fn_arg.mode { - PassMode::Indirect { .. } => iter::once(var).chain(None), + PassMode::Indirect { .. } => var, PassMode::Direct(_) => { assert_eq!(storage_class, StorageClass::Input); @@ -182,14 +162,13 @@ impl<'tcx> CodegenCx<'tcx> { let value_spirv_type = entry_fn_arg.layout.spirv_type(hir_param.span, self); - let loaded_var = emit - .load(value_spirv_type, None, var, None, iter::empty()) - .unwrap(); - iter::once(loaded_var).chain(None) + emit.load(value_spirv_type, None, var, None, std::iter::empty()) + .unwrap() } _ => unreachable!(), }, - } + }; + std::iter::once(arg).chain(dst_len_arg) }) .collect(); emit.function_call( @@ -217,6 +196,38 @@ impl<'tcx> CodegenCx<'tcx> { fn_id } + fn dst_length_argument( + &self, + emit: &mut std::cell::RefMut<'_, rspirv::dr::Builder>, + ty: Ty<'tcx>, + hir_param: &hir::Param<'tcx>, + len_t: Word, + var: Word, + ) -> Word { + match ty.kind() { + TyKind::Adt(adt_def, substs) => { + let (member_idx, field_def) = adt_def.all_fields().enumerate().last().unwrap(); + let field_ty = field_def.ty(self.tcx, substs); + if !matches!(field_ty.kind(), TyKind::Slice(..)) { + self.tcx.sess.span_fatal( + hir_param.ty_span, + "DST parameters are currently restricted to a reference to a struct whose last field is a slice.", + ) + } + emit.array_length(len_t, None, var, member_idx as u32) + .unwrap() + } + TyKind::Slice(..) | TyKind::Str => self.tcx.sess.span_fatal( + hir_param.ty_span, + "Straight slices are not yet supported, wrap the slice in a newtype.", + ), + _ => self + .tcx + .sess + .span_fatal(hir_param.ty_span, "Unsupported parameter type."), + } + } + fn declare_parameter( &self, layout: TyAndLayout<'tcx>, From 7b9f73f4a8e92bad90b6fa59819e3cb2b8c9f018 Mon Sep 17 00:00:00 2001 From: Henno Date: Fri, 26 Mar 2021 14:06:23 -0400 Subject: [PATCH 6/7] Update allocate_const_scalar.stderr --- tests/ui/lang/core/ptr/allocate_const_scalar.stderr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ui/lang/core/ptr/allocate_const_scalar.stderr b/tests/ui/lang/core/ptr/allocate_const_scalar.stderr index 1bc13a388c..15428bb61a 100644 --- a/tests/ui/lang/core/ptr/allocate_const_scalar.stderr +++ b/tests/ui/lang/core/ptr/allocate_const_scalar.stderr @@ -2,7 +2,7 @@ error: pointer has non-null integer address | = note: Stack: allocate_const_scalar::main - Unnamed function ID %4 + Unnamed function ID %5 error: invalid binary:0:0 - No OpEntryPoint instruction was found. This is only allowed if the Linkage capability is being used. | From 2704ef1aba603b35922fb114c62b6faffddab43b Mon Sep 17 00:00:00 2001 From: Henno Date: Mon, 29 Mar 2021 11:22:40 -0400 Subject: [PATCH 7/7] Add ArrayStride decoration to OpTypeRuntimeArray --- crates/rustc_codegen_spirv/src/spirv_type.rs | 11 ++++++++ crates/spirv-builder/src/test/basic.rs | 29 ++++++++++---------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/crates/rustc_codegen_spirv/src/spirv_type.rs b/crates/rustc_codegen_spirv/src/spirv_type.rs index 3597e5dda9..b95dc224d7 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type.rs @@ -188,6 +188,17 @@ impl SpirvType { } Self::RuntimeArray { element } => { let result = cx.emit_global().type_runtime_array(element); + // ArrayStride decoration wants in *bytes* + let element_size = cx + .lookup_type(element) + .sizeof(cx) + .expect("Element of sized array must be sized") + .bytes(); + cx.emit_global().decorate( + result, + Decoration::ArrayStride, + iter::once(Operand::LiteralInt32(element_size as u32)), + ); if cx.kernel_mode { cx.zombie_with_span(result, def_span, "RuntimeArray in kernel mode"); } diff --git a/crates/spirv-builder/src/test/basic.rs b/crates/spirv-builder/src/test/basic.rs index d0e790abfe..5829d279de 100644 --- a/crates/spirv-builder/src/test/basic.rs +++ b/crates/spirv-builder/src/test/basic.rs @@ -183,20 +183,21 @@ OpEntryPoint Fragment %1 "main" OpExecutionMode %1 OriginUpperLeft OpName %2 "test_project::add_decorate" OpName %3 "test_project::main" -OpDecorate %4 DescriptorSet 0 -OpDecorate %4 Binding 0 -%5 = OpTypeVoid -%6 = OpTypeFunction %5 -%7 = OpTypeInt 32 0 -%8 = OpTypePointer Function %7 -%9 = OpConstant %7 1 -%10 = OpTypeFloat 32 -%11 = OpTypeImage %10 2D 0 0 0 1 Unknown -%12 = OpTypeSampledImage %11 -%13 = OpTypeRuntimeArray %12 -%14 = OpTypePointer UniformConstant %13 -%4 = OpVariable %14 UniformConstant -%15 = OpTypePointer UniformConstant %12"#, +OpDecorate %4 ArrayStride 4 +OpDecorate %5 DescriptorSet 0 +OpDecorate %5 Binding 0 +%6 = OpTypeVoid +%7 = OpTypeFunction %6 +%8 = OpTypeInt 32 0 +%9 = OpTypePointer Function %8 +%10 = OpConstant %8 1 +%11 = OpTypeFloat 32 +%12 = OpTypeImage %11 2D 0 0 0 1 Unknown +%13 = OpTypeSampledImage %12 +%4 = OpTypeRuntimeArray %13 +%14 = OpTypePointer UniformConstant %4 +%5 = OpVariable %14 UniformConstant +%15 = OpTypePointer UniformConstant %13"#, ); }