Skip to content

Commit

Permalink
[spv-in] Convert conditional backedges to break if.
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyb committed May 10, 2023
1 parent dd54aaf commit 3a10639
Show file tree
Hide file tree
Showing 10 changed files with 307 additions and 19 deletions.
8 changes: 6 additions & 2 deletions src/front/spv/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -597,15 +597,19 @@ impl<'function> BlockContext<'function> {
crate::Span::default(),
)
}
super::BodyFragment::Loop { body, continuing } => {
super::BodyFragment::Loop {
body,
continuing,
break_if,
} => {
let body = lower_impl(blocks, bodies, body);
let continuing = lower_impl(blocks, bodies, continuing);

block.push(
crate::Statement::Loop {
body,
continuing,
break_if: None,
break_if,
},
crate::Span::default(),
)
Expand Down
121 changes: 104 additions & 17 deletions src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,11 @@ enum BodyFragment {
Loop {
body: BodyIndex,
continuing: BodyIndex,

/// If the SPIR-V loop's back-edge branch is conditional, this is the
/// expression that must be `false` for the back-edge to be taken, with
/// `true` being for the "loop merge" (which breaks out of the loop).
break_if: Option<Handle<crate::Expression>>,
},
Switch {
selector: Handle<crate::Expression>,
Expand Down Expand Up @@ -429,7 +434,7 @@ struct PhiExpression {
expressions: Vec<(spirv::Word, spirv::Word)>,
}

#[derive(Debug)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum MergeBlockInformation {
LoopMerge,
LoopContinue,
Expand Down Expand Up @@ -3114,35 +3119,121 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
get_expr_handle!(condition_id, lexp)
};

// HACK(eddyb) Naga doesn't seem to have this helper,
// so it's declared on the fly here for convenience.
#[derive(Copy, Clone)]
struct BranchTarget {
label_id: spirv::Word,
merge_info: Option<MergeBlockInformation>,
}
let branch_target = |label_id| BranchTarget {
label_id,
merge_info: ctx.mergers.get(&label_id).copied(),
};

let true_target = branch_target(self.next()?);
let false_target = branch_target(self.next()?);

// Consume branch weights
for _ in 4..inst.wc {
let _ = self.next()?;
}

// Handle `OpBranchConditional`s used at the end of a loop
// body's "continuing" section as a "conditional backedge",
// i.e. a `do`-`while` condition, or `break if` in WGSL.

// HACK(eddyb) this has to go to the parent *twice*, because
// `OpLoopMerge` left the "continuing" section nested in the
// loop body in terms of `parent`, but not `BodyFragment`.
let parent_body_idx = ctx.bodies[body_idx].parent;
let parent_parent_body_idx = ctx.bodies[parent_body_idx].parent;
match ctx.bodies[parent_parent_body_idx].data[..] {
// The `OpLoopMerge`'s `continuing` block and the loop's
// backedge block may not be the same, but they'll both
// belong to the same body.
[.., BodyFragment::Loop {
body: loop_body_idx,
continuing: loop_continuing_idx,
break_if: ref mut break_if_slot @ None,
}] if body_idx == loop_continuing_idx => {
// Try both orderings of break-vs-backedge, because
// SPIR-V is symmetrical here, unlike WGSL `break if`.
let break_if_cond = [true, false].into_iter().find_map(|true_breaks| {
let (break_candidate, backedge_candidate) = if true_breaks {
(true_target, false_target)
} else {
(false_target, true_target)
};

if break_candidate.merge_info
!= Some(MergeBlockInformation::LoopMerge)
{
return None;
}

// HACK(eddyb) since Naga doesn't explicitly track
// backedges, this is checking for the outcome of
// `OpLoopMerge` below (even if it looks weird).
let backedge_candidate_is_backedge =
backedge_candidate.merge_info.is_none()
&& ctx.body_for_label.get(&backedge_candidate.label_id)
== Some(&loop_body_idx);
if !backedge_candidate_is_backedge {
return None;
}

Some(if true_breaks {
condition
} else {
ctx.expressions.append(
crate::Expression::Unary {
op: crate::UnaryOperator::Not,
expr: condition,
},
span,
)
})
});

if let Some(break_if_cond) = break_if_cond {
*break_if_slot = Some(break_if_cond);

// This `OpBranchConditional` ends the "continuing"
// section of the loop body as normal, with the
// `break if` condition having been stashed above.
break None;
}
}
_ => {}
}

block.extend(emitter.finish(ctx.expressions));
ctx.blocks.insert(block_id, block);
let body = &mut ctx.bodies[body_idx];
body.data.push(BodyFragment::BlockId(block_id));

let true_id = self.next()?;
let false_id = self.next()?;

let same_target = true_id == false_id;
let same_target = true_target.label_id == false_target.label_id;

// Start a body block for the `accept` branch.
let accept = ctx.bodies.len();
let mut accept_block = Body::with_parent(body_idx);

// If the `OpBranchConditional`target is somebody else's
// If the `OpBranchConditional` target is somebody else's
// merge or continue block, then put a `Break` or `Continue`
// statement in this new body block.
if let Some(info) = ctx.mergers.get(&true_id) {
if let Some(info) = true_target.merge_info {
merger(
match same_target {
true => &mut ctx.bodies[body_idx],
false => &mut accept_block,
},
info,
&info,
)
} else {
// Note the body index for the block we're branching to.
let prev = ctx.body_for_label.insert(
true_id,
true_target.label_id,
match same_target {
true => body_idx,
false => accept,
Expand All @@ -3161,10 +3252,10 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let reject = ctx.bodies.len();
let mut reject_block = Body::with_parent(body_idx);

if let Some(info) = ctx.mergers.get(&false_id) {
merger(&mut reject_block, info)
if let Some(info) = false_target.merge_info {
merger(&mut reject_block, &info)
} else {
let prev = ctx.body_for_label.insert(false_id, reject);
let prev = ctx.body_for_label.insert(false_target.label_id, reject);
debug_assert!(prev.is_none());
}

Expand All @@ -3177,11 +3268,6 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
reject,
});

// Consume branch weights
for _ in 4..inst.wc {
let _ = self.next()?;
}

return Ok(());
}
Op::Switch => {
Expand Down Expand Up @@ -3351,6 +3437,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
parent_body.data.push(BodyFragment::Loop {
body: loop_body_idx,
continuing: continue_idx,
break_if: None,
});
body_idx = loop_body_idx;
}
Expand Down
Binary file added tests/in/spv/do-while.spv
Binary file not shown.
64 changes: 64 additions & 0 deletions tests/in/spv/do-while.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
;; Ensure that `do`-`while`-style loops, with conditional backedges, are properly
;; supported, via `break if` (as `continuing { ... if c { break; } }` is illegal).
;;
;; The SPIR-V below was compiled from this GLSL fragment shader:
;; ```glsl
;; #version 450
;;
;; void f(bool cond) {
;; do {} while(cond);
;; }
;;
;; void main() {
;; f(false);
;; }
;; ```

OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
OpExecutionMode %main OriginUpperLeft
OpSource GLSL 450
OpName %main "main"
OpName %f_b1_ "f(b1;"
OpName %cond "cond"
OpName %param "param"
%void = OpTypeVoid
%3 = OpTypeFunction %void
%bool = OpTypeBool
%_ptr_Function_bool = OpTypePointer Function %bool
%8 = OpTypeFunction %void %_ptr_Function_bool
%false = OpConstantFalse %bool

%main = OpFunction %void None %3
%5 = OpLabel
%param = OpVariable %_ptr_Function_bool Function
OpStore %param %false
%19 = OpFunctionCall %void %f_b1_ %param
OpReturn
OpFunctionEnd

%f_b1_ = OpFunction %void None %8
%cond = OpFunctionParameter %_ptr_Function_bool

%11 = OpLabel
OpBranch %12

%12 = OpLabel
OpLoopMerge %14 %15 None
OpBranch %13

%13 = OpLabel
OpBranch %15

;; This is the "continuing" block, and it contains a conditional branch between
;; the backedge (back to the loop header) and the loop merge ("break") target.
%15 = OpLabel
%16 = OpLoad %bool %cond
OpBranchConditional %16 %12 %14

%14 = OpLabel
OpReturn

OpFunctionEnd
33 changes: 33 additions & 0 deletions tests/out/glsl/do-while.main.Fragment.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#version 310 es

precision highp float;
precision highp int;


void fb1_(inout bool cond) {
bool loop_init = true;
while(true) {
if (!loop_init) {
bool _e6 = cond;
bool unnamed = !(_e6);
if (unnamed) {
break;
}
}
loop_init = false;
continue;
}
return;
}

void main_1() {
bool param = false;
param = false;
fb1_(param);
return;
}

void main() {
main_1();
}

31 changes: 31 additions & 0 deletions tests/out/hlsl/do-while.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

void fb1_(inout bool cond)
{
bool loop_init = true;
while(true) {
if (!loop_init) {
bool _expr6 = cond;
bool unnamed = !(_expr6);
if (unnamed) {
break;
}
}
loop_init = false;
continue;
}
return;
}

void main_1()
{
bool param = (bool)0;

param = false;
fb1_(param);
return;
}

void main()
{
main_1();
}
3 changes: 3 additions & 0 deletions tests/out/hlsl/do-while.hlsl.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
vertex=()
fragment=(main:ps_5_1 )
compute=()
37 changes: 37 additions & 0 deletions tests/out/msl/do-while.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// language: metal2.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;


void fb1_(
thread bool& cond
) {
bool loop_init = true;
while(true) {
if (!loop_init) {
bool _e6 = cond;
bool unnamed = !(_e6);
if (!(cond)) {
break;
}
}
loop_init = false;
continue;
}
return;
}

void main_1(
) {
bool param = {};
param = false;
fb1_(param);
return;
}

fragment void main_(
) {
main_1();
}
Loading

0 comments on commit 3a10639

Please sign in to comment.