JIT: handle interaction of OSR, PGO, and tail calls (#62263)
When both OSR and PGO are enabled, the jit will add PGO probes to OSR methods.
And if the OSR method also has a tail call, the jit must take care not to add
block probes to any return block reachable from possible tail call blocks.

Instead, instrumentation should create copies of the return block probe in each
return block predecessor (possibly splitting critical edges to make this viable).

Because all this happens early on, there are no pred lists. The analysis leverages
cheap preds instead, which means it needs to handle cases where a given pred has
multiple pred list entries. And it must also be aware that the OSR method's actual
flowgraph is a subgraph of the full initial graph.

This came up while scouting what it would take to enable OSR by default.
See #61934.
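Illustratively, for a return block with one potential-tail-call pred and one conditional pred, the reshaping looks like this (hypothetical block numbers; a sketch, not output from an actual compilation):

```
Before instrumentation:                    After Prepare + Instrument:

  BB10: ...; potential tail call             BB10: <probe copy>; ...; potential tail call
        -> BB30                                    (BBJ_ALWAYS) -> BB30
  BB20: if (p) goto BB30 (BBJ_COND)          BB20: if (p) goto BB31 (rerouted to intermediary)
  BB30: return (BBJ_RETURN;                  BB31: <probe copy> (new block, falls through)
        probe would normally go here)        BB30: return (no probe)
```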
AndyAyersMS committed Dec 2, 2021
1 parent c8cd6fe commit 3810633
Showing 9 changed files with 404 additions and 35 deletions.
3 changes: 2 additions & 1 deletion src/coreclr/jit/block.h
@@ -552,7 +552,8 @@ enum BasicBlockFlags : unsigned __int64
BBF_PATCHPOINT = MAKE_BBFLAG(36), // Block is a patchpoint
BBF_HAS_CLASS_PROFILE = MAKE_BBFLAG(37), // BB contains a call needing a class profile
BBF_PARTIAL_COMPILATION_PATCHPOINT = MAKE_BBFLAG(38), // Block is a partial compilation patchpoint
BBF_HAS_ALIGN = MAKE_BBFLAG(39), // BB ends with 'align' instruction
BBF_HAS_ALIGN = MAKE_BBFLAG(39), // BB ends with 'align' instruction
BBF_TAILCALL_SUCCESSOR = MAKE_BBFLAG(40), // BB has pred that has potential tail call

// The following are sets of flags.

1 change: 1 addition & 0 deletions src/coreclr/jit/compiler.h
@@ -7360,6 +7360,7 @@ class Compiler
#define OMF_NEEDS_GCPOLLS 0x00000200 // Method needs GC polls
#define OMF_HAS_FROZEN_STRING 0x00000400 // Method has a frozen string (REF constant int), currently only on CoreRT.
#define OMF_HAS_PARTIAL_COMPILATION_PATCHPOINT 0x00000800 // Method contains partial compilation patchpoints
#define OMF_HAS_TAILCALL_SUCCESSOR 0x00001000 // Method has potential tail call in a non BBJ_RETURN block

bool doesMethodHaveFatPointer()
{
22 changes: 14 additions & 8 deletions src/coreclr/jit/fgbasic.cpp
@@ -536,10 +536,10 @@ void Compiler::fgReplaceSwitchJumpTarget(BasicBlock* blockSwitch, BasicBlock* ne
// Notes:
// 1. Only branches are changed: BBJ_ALWAYS, the non-fallthrough path of BBJ_COND, BBJ_SWITCH, etc.
// We ignore other block types.
// 2. Only the first target found is updated. If there are multiple ways for a block
// to reach 'oldTarget' (e.g., multiple arms of a switch), only the first one found is changed.
// 2. All branch targets found are updated. If there are multiple ways for a block
// to reach 'oldTarget' (e.g., multiple arms of a switch), all of them are changed.
// 3. The predecessor lists are not changed.
// 4. The switch table "unique successor" cache is invalidated.
// 4. If any switch table entry was updated, the switch table "unique successor" cache is invalidated.
//
// This function is most useful early, before the full predecessor lists have been computed.
//
@@ -569,20 +569,26 @@ void Compiler::fgReplaceJumpTarget(BasicBlock* block, BasicBlock* newTarget, Bas
break;

case BBJ_SWITCH:
unsigned jumpCnt;
jumpCnt = block->bbJumpSwt->bbsCount;
BasicBlock** jumpTab;
jumpTab = block->bbJumpSwt->bbsDstTab;
{
unsigned const jumpCnt = block->bbJumpSwt->bbsCount;
BasicBlock** const jumpTab = block->bbJumpSwt->bbsDstTab;
bool changed = false;

for (unsigned i = 0; i < jumpCnt; i++)
{
if (jumpTab[i] == oldTarget)
{
jumpTab[i] = newTarget;
break;
changed = true;
}
}

if (changed)
{
InvalidateUniqueSwitchSuccMap();
}
break;
}

default:
assert(!"Block doesn't have a valid bbJumpKind!!!!");
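A minimal standalone sketch of the revised pattern above (toy types, not the jit's BasicBlock or bbJumpSwt structures): every switch arm that matches the old target is rewritten, and the cached unique-successor information is dropped only when at least one arm actually changed.

```cpp
#include <vector>

struct Node; // stand-in for a basic block

struct SwitchDesc
{
    std::vector<Node*> targets; // stand-in for bbsDstTab / bbsCount
};

// Replace every occurrence of oldTarget in the switch's target table.
// 'uniqueSuccCacheValid' stands in for the map cleared by InvalidateUniqueSwitchSuccMap().
void replaceSwitchTargets(SwitchDesc& sw, Node* oldTarget, Node* newTarget, bool& uniqueSuccCacheValid)
{
    bool changed = false;

    for (Node*& target : sw.targets)
    {
        if (target == oldTarget)
        {
            target  = newTarget; // no early break: duplicate arms are all updated
            changed = true;
        }
    }

    if (changed)
    {
        uniqueSuccCacheValid = false; // only pay for invalidation when something changed
    }
}
```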
197 changes: 194 additions & 3 deletions src/coreclr/jit/fgprofile.cpp
@@ -367,6 +367,160 @@ void BlockCountInstrumentor::Prepare(bool preImport)
return;
}

// If this is an OSR method, look for potential tail calls in
// blocks that are not BBJ_RETURN.
//
// If we see any, we need to adjust our instrumentation pattern.
//
if (m_comp->opts.IsOSR() && ((m_comp->optMethodFlags & OMF_HAS_TAILCALL_SUCCESSOR) != 0))
{
JITDUMP("OSR + PGO + potential tail call --- preparing to relocate block probes\n");

// We should be in a root method compiler instance. OSR + PGO does not
// currently try and instrument inlinees.
//
// Relaxing this will require changes below because inlinee compilers
// share the root compiler flow graph (and hence bb epoch), and flow
// from inlinee tail calls to returns can be more complex.
//
assert(!m_comp->compIsForInlining());

// Build cheap preds.
//
m_comp->fgComputeCheapPreds();
m_comp->EnsureBasicBlockEpoch();

// Keep track of return blocks needing special treatment.
// We also need to keep track of duplicate preds.
//
JitExpandArrayStack<BasicBlock*> specialReturnBlocks(m_comp->getAllocator(CMK_Pgo));
BlockSet predsSeen = BlockSetOps::MakeEmpty(m_comp);

// Walk blocks looking for BBJ_RETURNs that are successors of potential tail calls.
//
// If any such has a conditional pred, we will need to reroute flow from those preds
// via an intermediary block. That block will subsequently hold the relocated block
// probe for the return for those preds.
//
// Scrub the cheap pred list for these blocks so that each pred appears at most once.
//
for (BasicBlock* const block : m_comp->Blocks())
{
// Ignore blocks that we won't process.
//
if (!ShouldProcess(block))
{
continue;
}

if ((block->bbFlags & BBF_TAILCALL_SUCCESSOR) != 0)
{
JITDUMP("Return " FMT_BB " is successor of possible tail call\n", block->bbNum);
assert(block->bbJumpKind == BBJ_RETURN);
bool pushed = false;
BlockSetOps::ClearD(m_comp, predsSeen);
for (BasicBlockList* predEdge = block->bbCheapPreds; predEdge != nullptr; predEdge = predEdge->next)
{
BasicBlock* const pred = predEdge->block;

// If pred is not to be processed, ignore it and scrub from the pred list.
//
if (!ShouldProcess(pred))
{
JITDUMP(FMT_BB " -> " FMT_BB " is dead edge\n", pred->bbNum, block->bbNum);
predEdge->block = nullptr;
continue;
}

BasicBlock* const succ = pred->GetUniqueSucc();

if (succ == nullptr)
{
// Flow from pred -> block is conditional, and will require updating.
//
JITDUMP(FMT_BB " -> " FMT_BB " is critical edge\n", pred->bbNum, block->bbNum);
if (!pushed)
{
specialReturnBlocks.Push(block);
pushed = true;
}

// Have we seen this pred before?
//
if (BlockSetOps::IsMember(m_comp, predsSeen, pred->bbNum))
{
// Yes, null out the duplicate pred list entry.
//
predEdge->block = nullptr;
}
}
else
{
// We should only ever see one reference to this pred.
//
assert(!BlockSetOps::IsMember(m_comp, predsSeen, pred->bbNum));

// Ensure flow from non-critical preds is BBJ_ALWAYS as we
// may add a new block right before block.
//
if (pred->bbJumpKind == BBJ_NONE)
{
pred->bbJumpKind = BBJ_ALWAYS;
pred->bbJumpDest = block;
}
assert(pred->bbJumpKind == BBJ_ALWAYS);
}

BlockSetOps::AddElemD(m_comp, predsSeen, pred->bbNum);
}
}
}

// Now process each special return block.
// Create an intermediary that falls through to the return.
// Update any critical edges to target the intermediary.
//
// Note we could also route any non-tail-call pred via the
// intermediary. Doing so would cut down on probe duplication.
//
while (specialReturnBlocks.Size() > 0)
{
bool first = true;
BasicBlock* const block = specialReturnBlocks.Pop();
BasicBlock* const intermediary = m_comp->fgNewBBbefore(BBJ_NONE, block, /* extendRegion*/ true);

intermediary->bbFlags |= BBF_IMPORTED;
intermediary->inheritWeight(block);

for (BasicBlockList* predEdge = block->bbCheapPreds; predEdge != nullptr; predEdge = predEdge->next)
{
BasicBlock* const pred = predEdge->block;

if (pred != nullptr)
{
BasicBlock* const succ = pred->GetUniqueSucc();

if (succ == nullptr)
{
// This will update all branch targets from pred.
//
m_comp->fgReplaceJumpTarget(pred, intermediary, block);

// Patch the pred list. Note we only need one pred list
// entry pointing at intermediary.
//
predEdge->block = first ? intermediary : nullptr;
first = false;
}
else
{
assert(pred->bbJumpKind == BBJ_ALWAYS);
}
}
}
}
}

#ifdef DEBUG
// Set schema index to invalid value
//
@@ -449,7 +603,37 @@ void BlockCountInstrumentor::Instrument(BasicBlock* block, Schema& schema, uint8
GenTree* lhsNode = m_comp->gtNewIndOfIconHandleNode(typ, addrOfCurrentExecutionCount, GTF_ICON_BBC_PTR, false);
GenTree* asgNode = m_comp->gtNewAssignNode(lhsNode, rhsNode);

m_comp->fgNewStmtAtBeg(block, asgNode);
if ((block->bbFlags & BBF_TAILCALL_SUCCESSOR) != 0)
{
// We should have built and updated cheap preds during the prepare stage.
//
assert(m_comp->fgCheapPredsValid);

// Instrument each predecessor.
//
bool first = true;
for (BasicBlockList* predEdge = block->bbCheapPreds; predEdge != nullptr; predEdge = predEdge->next)
{
BasicBlock* const pred = predEdge->block;

// We may have scrubbed cheap pred list duplicates during Prepare.
//
if (pred != nullptr)
{
JITDUMP("Placing copy of block probe for " FMT_BB " in pred " FMT_BB "\n", block->bbNum, pred->bbNum);
if (!first)
{
asgNode = m_comp->gtCloneExpr(asgNode);
}
m_comp->fgNewStmtAtBeg(pred, asgNode);
first = false;
}
}
}
else
{
m_comp->fgNewStmtAtBeg(block, asgNode);
}

m_instrCount++;
}
@@ -589,7 +773,7 @@ void Compiler::WalkSpanningTree(SpanningTreeVisitor* visitor)
// graph. So for BlockSets and NumSucc, we use the root compiler instance.
//
Compiler* const comp = impInlineRoot();
comp->NewBasicBlockEpoch();
comp->EnsureBasicBlockEpoch();

// We will track visited or queued nodes with a bit vector.
//
@@ -1612,7 +1796,7 @@ PhaseStatus Compiler::fgPrepareToInstrumentMethod()
else
{
JITDUMP("Using block profiling, because %s\n",
(JitConfig.JitEdgeProfiling() > 0)
(JitConfig.JitEdgeProfiling() == 0)
? "edge profiles disabled"
: prejit ? "prejitting" : osrMethod ? "OSR" : "tier0 with patchpoints");

@@ -1793,6 +1977,13 @@ PhaseStatus Compiler::fgInstrumentMethod()
fgCountInstrumentor->InstrumentMethodEntry(schema, profileMemory);
fgClassInstrumentor->InstrumentMethodEntry(schema, profileMemory);

// If we needed to create cheap preds, we're done with them now.
//
if (fgCheapPredsValid)
{
fgRemovePreds();
}

return PhaseStatus::MODIFIED_EVERYTHING;
}

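Putting the Prepare and Instrument pieces together, a rough standalone model of the relocation (hypothetical types and helper name, not jit code) is:

```cpp
#include <vector>

struct Block
{
    int                 num;
    bool                multipleSuccs = false; // models "no unique successor" (critical edge)
    std::vector<Block*> preds;                 // models the cheap pred list
    std::vector<int>    stmts;                 // statement ids; 0 models the count probe
};

// Place a copy of 'ret's block-count probe in each pred, splitting critical
// edges through a new fall-through block first (mirrors Prepare + Instrument).
static void relocateReturnProbe(Block* ret, int& nextBlockNum, std::vector<Block*>& newBlocks)
{
    for (Block*& pred : ret->preds)
    {
        Block* probeHome = pred;

        if (pred->multipleSuccs)
        {
            // Critical edge: introduce an intermediary so the probe copy only
            // runs on the path that actually reaches 'ret'.
            Block* intermediary = new Block{nextBlockNum++};
            intermediary->preds.push_back(pred);
            newBlocks.push_back(intermediary);
            pred      = intermediary; // reroute pred -> ret through the intermediary
            probeHome = intermediary;
        }

        // A copy of the probe goes at the start of the (possibly new) pred,
        // instead of in 'ret' itself.
        probeHome->stmts.insert(probeHome->stmts.begin(), /* probe id */ 0);
    }
}
```

The real code additionally forces non-critical preds to BBJ_ALWAYS and scrubs duplicate cheap-pred entries, which this sketch omits.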
61 changes: 38 additions & 23 deletions src/coreclr/jit/importer.cpp
@@ -9797,34 +9797,49 @@ var_types Compiler::impImportCall(OPCODE opcode,
}

// A tail recursive call is a potential loop from the current block to the start of the method.
if ((tailCallFlags != 0) && canTailCall && gtIsRecursiveCall(methHnd))
if ((tailCallFlags != 0) && canTailCall)
{
assert(verCurrentState.esStackDepth == 0);
BasicBlock* loopHead = nullptr;
if (opts.IsOSR())
// If a root method tail call candidate block is not a BBJ_RETURN, it should have a unique
// BBJ_RETURN successor. Mark that successor so we can handle it specially during profile
// instrumentation.
//
if (!compIsForInlining() && (compCurBB->bbJumpKind != BBJ_RETURN))
{
// We might not have been planning on importing the method
// entry block, but now we must.

// We should have remembered the real method entry block.
assert(fgEntryBB != nullptr);

JITDUMP("\nOSR: found tail recursive call in the method, scheduling " FMT_BB " for importation\n",
fgEntryBB->bbNum);
impImportBlockPending(fgEntryBB);
loopHead = fgEntryBB;
BasicBlock* const successor = compCurBB->GetUniqueSucc();
assert(successor->bbJumpKind == BBJ_RETURN);
successor->bbFlags |= BBF_TAILCALL_SUCCESSOR;
optMethodFlags |= OMF_HAS_TAILCALL_SUCCESSOR;
}
else

if (gtIsRecursiveCall(methHnd))
{
// For normal jitting we'll branch back to the firstBB; this
// should already be imported.
loopHead = fgFirstBB;
}
assert(verCurrentState.esStackDepth == 0);
BasicBlock* loopHead = nullptr;
if (opts.IsOSR())
{
// We might not have been planning on importing the method
// entry block, but now we must.

JITDUMP("\nFound tail recursive call in the method. Mark " FMT_BB " to " FMT_BB
" as having a backward branch.\n",
loopHead->bbNum, compCurBB->bbNum);
fgMarkBackwardJump(loopHead, compCurBB);
// We should have remembered the real method entry block.
assert(fgEntryBB != nullptr);

JITDUMP("\nOSR: found tail recursive call in the method, scheduling " FMT_BB " for importation\n",
fgEntryBB->bbNum);
impImportBlockPending(fgEntryBB);
loopHead = fgEntryBB;
}
else
{
// For normal jitting we'll branch back to the firstBB; this
// should already be imported.
loopHead = fgFirstBB;
}

JITDUMP("\nFound tail recursive call in the method. Mark " FMT_BB " to " FMT_BB
" as having a backward branch.\n",
loopHead->bbNum, compCurBB->bbNum);
fgMarkBackwardJump(loopHead, compCurBB);
}
}

// Note: we assume that small return types are already normalized by the managed callee
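For reference, the block shape that the new importer code flags looks roughly like this (illustrative block numbers; the flag names are the ones introduced above):

```
  BB05: ... potential tail call ...   (not BBJ_RETURN; unique successor is BB06)
  BB06: return                        (BBJ_RETURN; gets BBF_TAILCALL_SUCCESSOR,
                                       and the method gets OMF_HAS_TAILCALL_SUCCESSOR)
```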