From 3a10639d4f43f749da210cebbf984679f2922556 Mon Sep 17 00:00:00 2001 From: Eduard-Mihai Burtescu Date: Thu, 27 Apr 2023 12:54:43 +0300 Subject: [PATCH] [spv-in] Convert conditional backedges to `break if`. --- src/front/spv/function.rs | 8 +- src/front/spv/mod.rs | 121 ++++++++++++++++++--- tests/in/spv/do-while.spv | Bin 0 -> 480 bytes tests/in/spv/do-while.spvasm | 64 +++++++++++ tests/out/glsl/do-while.main.Fragment.glsl | 33 ++++++ tests/out/hlsl/do-while.hlsl | 31 ++++++ tests/out/hlsl/do-while.hlsl.config | 3 + tests/out/msl/do-while.msl | 37 +++++++ tests/out/wgsl/do-while.wgsl | 24 ++++ tests/snapshots.rs | 5 + 10 files changed, 307 insertions(+), 19 deletions(-) create mode 100644 tests/in/spv/do-while.spv create mode 100644 tests/in/spv/do-while.spvasm create mode 100644 tests/out/glsl/do-while.main.Fragment.glsl create mode 100644 tests/out/hlsl/do-while.hlsl create mode 100644 tests/out/hlsl/do-while.hlsl.config create mode 100644 tests/out/msl/do-while.msl create mode 100644 tests/out/wgsl/do-while.wgsl diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs index 5dc781504e..1f4ade20e2 100644 --- a/src/front/spv/function.rs +++ b/src/front/spv/function.rs @@ -597,7 +597,11 @@ impl<'function> BlockContext<'function> { crate::Span::default(), ) } - super::BodyFragment::Loop { body, continuing } => { + super::BodyFragment::Loop { + body, + continuing, + break_if, + } => { let body = lower_impl(blocks, bodies, body); let continuing = lower_impl(blocks, bodies, continuing); @@ -605,7 +609,7 @@ impl<'function> BlockContext<'function> { crate::Statement::Loop { body, continuing, - break_if: None, + break_if, }, crate::Span::default(), ) diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index c69a230cb0..1221c6c5aa 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -388,6 +388,11 @@ enum BodyFragment { Loop { body: BodyIndex, continuing: BodyIndex, + + /// If the SPIR-V loop's back-edge branch is conditional, this is the + /// expression that must be `false` for the back-edge to be taken, with + /// `true` being for the "loop merge" (which breaks out of the loop). + break_if: Option>, }, Switch { selector: Handle, @@ -429,7 +434,7 @@ struct PhiExpression { expressions: Vec<(spirv::Word, spirv::Word)>, } -#[derive(Debug)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] enum MergeBlockInformation { LoopMerge, LoopContinue, @@ -3114,35 +3119,121 @@ impl> Frontend { get_expr_handle!(condition_id, lexp) }; + // HACK(eddyb) Naga doesn't seem to have this helper, + // so it's declared on the fly here for convenience. + #[derive(Copy, Clone)] + struct BranchTarget { + label_id: spirv::Word, + merge_info: Option, + } + let branch_target = |label_id| BranchTarget { + label_id, + merge_info: ctx.mergers.get(&label_id).copied(), + }; + + let true_target = branch_target(self.next()?); + let false_target = branch_target(self.next()?); + + // Consume branch weights + for _ in 4..inst.wc { + let _ = self.next()?; + } + + // Handle `OpBranchConditional`s used at the end of a loop + // body's "continuing" section as a "conditional backedge", + // i.e. a `do`-`while` condition, or `break if` in WGSL. + + // HACK(eddyb) this has to go to the parent *twice*, because + // `OpLoopMerge` left the "continuing" section nested in the + // loop body in terms of `parent`, but not `BodyFragment`. + let parent_body_idx = ctx.bodies[body_idx].parent; + let parent_parent_body_idx = ctx.bodies[parent_body_idx].parent; + match ctx.bodies[parent_parent_body_idx].data[..] { + // The `OpLoopMerge`'s `continuing` block and the loop's + // backedge block may not be the same, but they'll both + // belong to the same body. + [.., BodyFragment::Loop { + body: loop_body_idx, + continuing: loop_continuing_idx, + break_if: ref mut break_if_slot @ None, + }] if body_idx == loop_continuing_idx => { + // Try both orderings of break-vs-backedge, because + // SPIR-V is symmetrical here, unlike WGSL `break if`. + let break_if_cond = [true, false].into_iter().find_map(|true_breaks| { + let (break_candidate, backedge_candidate) = if true_breaks { + (true_target, false_target) + } else { + (false_target, true_target) + }; + + if break_candidate.merge_info + != Some(MergeBlockInformation::LoopMerge) + { + return None; + } + + // HACK(eddyb) since Naga doesn't explicitly track + // backedges, this is checking for the outcome of + // `OpLoopMerge` below (even if it looks weird). + let backedge_candidate_is_backedge = + backedge_candidate.merge_info.is_none() + && ctx.body_for_label.get(&backedge_candidate.label_id) + == Some(&loop_body_idx); + if !backedge_candidate_is_backedge { + return None; + } + + Some(if true_breaks { + condition + } else { + ctx.expressions.append( + crate::Expression::Unary { + op: crate::UnaryOperator::Not, + expr: condition, + }, + span, + ) + }) + }); + + if let Some(break_if_cond) = break_if_cond { + *break_if_slot = Some(break_if_cond); + + // This `OpBranchConditional` ends the "continuing" + // section of the loop body as normal, with the + // `break if` condition having been stashed above. + break None; + } + } + _ => {} + } + block.extend(emitter.finish(ctx.expressions)); ctx.blocks.insert(block_id, block); let body = &mut ctx.bodies[body_idx]; body.data.push(BodyFragment::BlockId(block_id)); - let true_id = self.next()?; - let false_id = self.next()?; - - let same_target = true_id == false_id; + let same_target = true_target.label_id == false_target.label_id; // Start a body block for the `accept` branch. let accept = ctx.bodies.len(); let mut accept_block = Body::with_parent(body_idx); - // If the `OpBranchConditional`target is somebody else's + // If the `OpBranchConditional` target is somebody else's // merge or continue block, then put a `Break` or `Continue` // statement in this new body block. - if let Some(info) = ctx.mergers.get(&true_id) { + if let Some(info) = true_target.merge_info { merger( match same_target { true => &mut ctx.bodies[body_idx], false => &mut accept_block, }, - info, + &info, ) } else { // Note the body index for the block we're branching to. let prev = ctx.body_for_label.insert( - true_id, + true_target.label_id, match same_target { true => body_idx, false => accept, @@ -3161,10 +3252,10 @@ impl> Frontend { let reject = ctx.bodies.len(); let mut reject_block = Body::with_parent(body_idx); - if let Some(info) = ctx.mergers.get(&false_id) { - merger(&mut reject_block, info) + if let Some(info) = false_target.merge_info { + merger(&mut reject_block, &info) } else { - let prev = ctx.body_for_label.insert(false_id, reject); + let prev = ctx.body_for_label.insert(false_target.label_id, reject); debug_assert!(prev.is_none()); } @@ -3177,11 +3268,6 @@ impl> Frontend { reject, }); - // Consume branch weights - for _ in 4..inst.wc { - let _ = self.next()?; - } - return Ok(()); } Op::Switch => { @@ -3351,6 +3437,7 @@ impl> Frontend { parent_body.data.push(BodyFragment::Loop { body: loop_body_idx, continuing: continue_idx, + break_if: None, }); body_idx = loop_body_idx; } diff --git a/tests/in/spv/do-while.spv b/tests/in/spv/do-while.spv new file mode 100644 index 0000000000000000000000000000000000000000..23d45958b9e0f61ef777f5c63975ccf8a311a542 GIT binary patch literal 480 zcmYk2+e!ja6oywbIg?o)(g#Soi7up|BT~2Bc-Iq*iQt76~Bp z{Ij=tc2-8Q7ZGI)`1SN63oETC5}zk8lhNmTHoCsKwBL?gq+TZ)u}_{6%WAQ*-leZD ziM-A?7&^H-r*?OKKD`=4bpc=BRx^;a9`DshwS;}Pn{b$1Bjp2Xhty8l?Lp_&L-YZ3 z-ueYQ)=!9AoUMBokFWVgHQ;#1N3SQk(!GazuTSp1t^e^a+!I?*>mI*-$FAS@P5QUs n4?L>1{KBaXJA}vnF?3$4?WyG>^S+}?==+N{*GnJg +#include + +using metal::uint; + + +void fb1_( + thread bool& cond +) { + bool loop_init = true; + while(true) { + if (!loop_init) { + bool _e6 = cond; + bool unnamed = !(_e6); + if (!(cond)) { + break; + } + } + loop_init = false; + continue; + } + return; +} + +void main_1( +) { + bool param = {}; + param = false; + fb1_(param); + return; +} + +fragment void main_( +) { + main_1(); +} diff --git a/tests/out/wgsl/do-while.wgsl b/tests/out/wgsl/do-while.wgsl new file mode 100644 index 0000000000..d444169e41 --- /dev/null +++ b/tests/out/wgsl/do-while.wgsl @@ -0,0 +1,24 @@ +fn fb1_(cond: ptr) { + loop { + continue; + continuing { + let _e6 = (*cond); + _ = !(_e6); + break if !(_e6); + } + } + return; +} + +fn main_1() { + var param: bool; + + param = false; + fb1_((¶m)); + return; +} + +@fragment +fn main() { + main_1(); +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 8e75f5b742..09e98607be 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -640,6 +640,11 @@ fn convert_spv_all() { convert_spv("degrees", false, Targets::empty()); convert_spv("binding-arrays.dynamic", true, Targets::WGSL); convert_spv("binding-arrays.static", true, Targets::WGSL); + convert_spv( + "do-while", + true, + Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ); } #[cfg(feature = "glsl-in")]