diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs index 5dc781504e..1f4ade20e2 100644 --- a/src/front/spv/function.rs +++ b/src/front/spv/function.rs @@ -597,7 +597,11 @@ impl<'function> BlockContext<'function> { crate::Span::default(), ) } - super::BodyFragment::Loop { body, continuing } => { + super::BodyFragment::Loop { + body, + continuing, + break_if, + } => { let body = lower_impl(blocks, bodies, body); let continuing = lower_impl(blocks, bodies, continuing); @@ -605,7 +609,7 @@ impl<'function> BlockContext<'function> { crate::Statement::Loop { body, continuing, - break_if: None, + break_if, }, crate::Span::default(), ) diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index c69a230cb0..1221c6c5aa 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -388,6 +388,11 @@ enum BodyFragment { Loop { body: BodyIndex, continuing: BodyIndex, + + /// If the SPIR-V loop's back-edge branch is conditional, this is the + /// expression that must be `false` for the back-edge to be taken, with + /// `true` being for the "loop merge" (which breaks out of the loop). + break_if: Option>, }, Switch { selector: Handle, @@ -429,7 +434,7 @@ struct PhiExpression { expressions: Vec<(spirv::Word, spirv::Word)>, } -#[derive(Debug)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] enum MergeBlockInformation { LoopMerge, LoopContinue, @@ -3114,35 +3119,121 @@ impl> Frontend { get_expr_handle!(condition_id, lexp) }; + // HACK(eddyb) Naga doesn't seem to have this helper, + // so it's declared on the fly here for convenience. + #[derive(Copy, Clone)] + struct BranchTarget { + label_id: spirv::Word, + merge_info: Option, + } + let branch_target = |label_id| BranchTarget { + label_id, + merge_info: ctx.mergers.get(&label_id).copied(), + }; + + let true_target = branch_target(self.next()?); + let false_target = branch_target(self.next()?); + + // Consume branch weights + for _ in 4..inst.wc { + let _ = self.next()?; + } + + // Handle `OpBranchConditional`s used at the end of a loop + // body's "continuing" section as a "conditional backedge", + // i.e. a `do`-`while` condition, or `break if` in WGSL. + + // HACK(eddyb) this has to go to the parent *twice*, because + // `OpLoopMerge` left the "continuing" section nested in the + // loop body in terms of `parent`, but not `BodyFragment`. + let parent_body_idx = ctx.bodies[body_idx].parent; + let parent_parent_body_idx = ctx.bodies[parent_body_idx].parent; + match ctx.bodies[parent_parent_body_idx].data[..] { + // The `OpLoopMerge`'s `continuing` block and the loop's + // backedge block may not be the same, but they'll both + // belong to the same body. + [.., BodyFragment::Loop { + body: loop_body_idx, + continuing: loop_continuing_idx, + break_if: ref mut break_if_slot @ None, + }] if body_idx == loop_continuing_idx => { + // Try both orderings of break-vs-backedge, because + // SPIR-V is symmetrical here, unlike WGSL `break if`. + let break_if_cond = [true, false].into_iter().find_map(|true_breaks| { + let (break_candidate, backedge_candidate) = if true_breaks { + (true_target, false_target) + } else { + (false_target, true_target) + }; + + if break_candidate.merge_info + != Some(MergeBlockInformation::LoopMerge) + { + return None; + } + + // HACK(eddyb) since Naga doesn't explicitly track + // backedges, this is checking for the outcome of + // `OpLoopMerge` below (even if it looks weird). + let backedge_candidate_is_backedge = + backedge_candidate.merge_info.is_none() + && ctx.body_for_label.get(&backedge_candidate.label_id) + == Some(&loop_body_idx); + if !backedge_candidate_is_backedge { + return None; + } + + Some(if true_breaks { + condition + } else { + ctx.expressions.append( + crate::Expression::Unary { + op: crate::UnaryOperator::Not, + expr: condition, + }, + span, + ) + }) + }); + + if let Some(break_if_cond) = break_if_cond { + *break_if_slot = Some(break_if_cond); + + // This `OpBranchConditional` ends the "continuing" + // section of the loop body as normal, with the + // `break if` condition having been stashed above. + break None; + } + } + _ => {} + } + block.extend(emitter.finish(ctx.expressions)); ctx.blocks.insert(block_id, block); let body = &mut ctx.bodies[body_idx]; body.data.push(BodyFragment::BlockId(block_id)); - let true_id = self.next()?; - let false_id = self.next()?; - - let same_target = true_id == false_id; + let same_target = true_target.label_id == false_target.label_id; // Start a body block for the `accept` branch. let accept = ctx.bodies.len(); let mut accept_block = Body::with_parent(body_idx); - // If the `OpBranchConditional`target is somebody else's + // If the `OpBranchConditional` target is somebody else's // merge or continue block, then put a `Break` or `Continue` // statement in this new body block. - if let Some(info) = ctx.mergers.get(&true_id) { + if let Some(info) = true_target.merge_info { merger( match same_target { true => &mut ctx.bodies[body_idx], false => &mut accept_block, }, - info, + &info, ) } else { // Note the body index for the block we're branching to. let prev = ctx.body_for_label.insert( - true_id, + true_target.label_id, match same_target { true => body_idx, false => accept, @@ -3161,10 +3252,10 @@ impl> Frontend { let reject = ctx.bodies.len(); let mut reject_block = Body::with_parent(body_idx); - if let Some(info) = ctx.mergers.get(&false_id) { - merger(&mut reject_block, info) + if let Some(info) = false_target.merge_info { + merger(&mut reject_block, &info) } else { - let prev = ctx.body_for_label.insert(false_id, reject); + let prev = ctx.body_for_label.insert(false_target.label_id, reject); debug_assert!(prev.is_none()); } @@ -3177,11 +3268,6 @@ impl> Frontend { reject, }); - // Consume branch weights - for _ in 4..inst.wc { - let _ = self.next()?; - } - return Ok(()); } Op::Switch => { @@ -3351,6 +3437,7 @@ impl> Frontend { parent_body.data.push(BodyFragment::Loop { body: loop_body_idx, continuing: continue_idx, + break_if: None, }); body_idx = loop_body_idx; } diff --git a/tests/in/spv/do-while.spv b/tests/in/spv/do-while.spv new file mode 100644 index 0000000000..23d45958b9 Binary files /dev/null and b/tests/in/spv/do-while.spv differ diff --git a/tests/in/spv/do-while.spvasm b/tests/in/spv/do-while.spvasm new file mode 100644 index 0000000000..fa27c3544f --- /dev/null +++ b/tests/in/spv/do-while.spvasm @@ -0,0 +1,64 @@ +;; Ensure that `do`-`while`-style loops, with conditional backedges, are properly +;; supported, via `break if` (as `continuing { ... if c { break; } }` is illegal). +;; +;; The SPIR-V below was compiled from this GLSL fragment shader: +;; ```glsl +;; #version 450 +;; +;; void f(bool cond) { +;; do {} while(cond); +;; } +;; +;; void main() { +;; f(false); +;; } +;; ``` + + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpSource GLSL 450 + OpName %main "main" + OpName %f_b1_ "f(b1;" + OpName %cond "cond" + OpName %param "param" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool + %8 = OpTypeFunction %void %_ptr_Function_bool + %false = OpConstantFalse %bool + + %main = OpFunction %void None %3 + %5 = OpLabel + %param = OpVariable %_ptr_Function_bool Function + OpStore %param %false + %19 = OpFunctionCall %void %f_b1_ %param + OpReturn + OpFunctionEnd + + %f_b1_ = OpFunction %void None %8 + %cond = OpFunctionParameter %_ptr_Function_bool + + %11 = OpLabel + OpBranch %12 + + %12 = OpLabel + OpLoopMerge %14 %15 None + OpBranch %13 + + %13 = OpLabel + OpBranch %15 + +;; This is the "continuing" block, and it contains a conditional branch between +;; the backedge (back to the loop header) and the loop merge ("break") target. + %15 = OpLabel + %16 = OpLoad %bool %cond + OpBranchConditional %16 %12 %14 + + %14 = OpLabel + OpReturn + + OpFunctionEnd diff --git a/tests/out/glsl/do-while.main.Fragment.glsl b/tests/out/glsl/do-while.main.Fragment.glsl new file mode 100644 index 0000000000..1d1ca4e8ec --- /dev/null +++ b/tests/out/glsl/do-while.main.Fragment.glsl @@ -0,0 +1,33 @@ +#version 310 es + +precision highp float; +precision highp int; + + +void fb1_(inout bool cond) { + bool loop_init = true; + while(true) { + if (!loop_init) { + bool _e6 = cond; + bool unnamed = !(_e6); + if (unnamed) { + break; + } + } + loop_init = false; + continue; + } + return; +} + +void main_1() { + bool param = false; + param = false; + fb1_(param); + return; +} + +void main() { + main_1(); +} + diff --git a/tests/out/hlsl/do-while.hlsl b/tests/out/hlsl/do-while.hlsl new file mode 100644 index 0000000000..17341a6cfc --- /dev/null +++ b/tests/out/hlsl/do-while.hlsl @@ -0,0 +1,31 @@ + +void fb1_(inout bool cond) +{ + bool loop_init = true; + while(true) { + if (!loop_init) { + bool _expr6 = cond; + bool unnamed = !(_expr6); + if (unnamed) { + break; + } + } + loop_init = false; + continue; + } + return; +} + +void main_1() +{ + bool param = (bool)0; + + param = false; + fb1_(param); + return; +} + +void main() +{ + main_1(); +} diff --git a/tests/out/hlsl/do-while.hlsl.config b/tests/out/hlsl/do-while.hlsl.config new file mode 100644 index 0000000000..98453a04ee --- /dev/null +++ b/tests/out/hlsl/do-while.hlsl.config @@ -0,0 +1,3 @@ +vertex=() +fragment=(main:ps_5_1 ) +compute=() diff --git a/tests/out/msl/do-while.msl b/tests/out/msl/do-while.msl new file mode 100644 index 0000000000..035965cd2e --- /dev/null +++ b/tests/out/msl/do-while.msl @@ -0,0 +1,37 @@ +// language: metal2.0 +#include +#include + +using metal::uint; + + +void fb1_( + thread bool& cond +) { + bool loop_init = true; + while(true) { + if (!loop_init) { + bool _e6 = cond; + bool unnamed = !(_e6); + if (!(cond)) { + break; + } + } + loop_init = false; + continue; + } + return; +} + +void main_1( +) { + bool param = {}; + param = false; + fb1_(param); + return; +} + +fragment void main_( +) { + main_1(); +} diff --git a/tests/out/wgsl/do-while.wgsl b/tests/out/wgsl/do-while.wgsl new file mode 100644 index 0000000000..d444169e41 --- /dev/null +++ b/tests/out/wgsl/do-while.wgsl @@ -0,0 +1,24 @@ +fn fb1_(cond: ptr) { + loop { + continue; + continuing { + let _e6 = (*cond); + _ = !(_e6); + break if !(_e6); + } + } + return; +} + +fn main_1() { + var param: bool; + + param = false; + fb1_((¶m)); + return; +} + +@fragment +fn main() { + main_1(); +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 8e75f5b742..09e98607be 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -640,6 +640,11 @@ fn convert_spv_all() { convert_spv("degrees", false, Targets::empty()); convert_spv("binding-arrays.dynamic", true, Targets::WGSL); convert_spv("binding-arrays.static", true, Targets::WGSL); + convert_spv( + "do-while", + true, + Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ); } #[cfg(feature = "glsl-in")]