rust/compiler/rustc_mir_transform/src/jump_threading.rs

849 lines
34 KiB
Rust

//! A jump threading optimization.
//!
//! This optimization seeks to replace join-then-switch control flow patterns by straight jumps:
//!
//! ```text
//!    X = 0                                      X = 0
//! ------------\      /--------              ------------
//!    X = 1     X----X SwitchInt(X)     =>       X = 1
//! ------------/      \--------              ------------
//! ```
//!
//!
//! We proceed by walking the cfg backwards starting from each `SwitchInt` terminator,
//! looking for assignments that will turn the `SwitchInt` into a simple `Goto`.
//!
//! The algorithm maintains a set of replacement conditions:
//! - `conditions[place]` contains `Condition { value, polarity: Eq, target }`
//! if assigning `value` to `place` turns the `SwitchInt` into `Goto { target }`.
//! - `conditions[place]` contains `Condition { value, polarity: Ne, target }`
//! if assigning anything different from `value` to `place` turns the `SwitchInt`
//! into `Goto { target }`.
//!
//! In this file, we denote as `place ?= value` the existence of a replacement condition
//! on `place` with given `value`, irrespective of the polarity and target of that
//! replacement condition.
//!
//! We then walk the CFG backwards transforming the set of conditions.
//! When we find a fulfilling assignment, we record a `ThreadingOpportunity`.
//! All `ThreadingOpportunity`s are applied to the body, by duplicating blocks if required.
//!
//! The optimization search can be very heavy, as it performs a DFS on MIR starting from
//! each `SwitchInt` terminator. To manage the complexity, we:
//! - bound the maximum depth by a constant `MAX_BACKTRACK`;
//! - only traverse a restricted set of terminators: `Goto`, `Assert`, `Call` and `Drop`,
//!   plus a `SwitchInt` when it is the single predecessor of the current block.
//!
//! We try to avoid creating irreducible control-flow by not threading through a loop header.
//!
//! Likewise, applying the optimisation can create a lot of new MIR, so we bound the instruction
//! cost by `MAX_COST`.
use rustc_arena::DroplessArena;
use rustc_const_eval::const_eval::DummyMachine;
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
use rustc_data_structures::fx::FxHashSet;
use rustc_index::bit_set::BitSet;
use rustc_index::IndexVec;
use rustc_middle::bug;
use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::*;
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, ScalarInt, TyCtxt};
use rustc_mir_dataflow::lattice::HasBottom;
use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem};
use rustc_span::DUMMY_SP;
use rustc_target::abi::{TagEncoding, Variants};
use tracing::{debug, instrument, trace};
use crate::cost_checker::CostChecker;
/// The jump threading MIR pass; see the module documentation for an overview.
pub(super) struct JumpThreading;
/// Maximum recursion depth of the backwards walk started at each `SwitchInt`: predecessors
/// are no longer explored once `depth >= MAX_BACKTRACK`.
const MAX_BACKTRACK: usize = 5;
/// Maximum cost, as measured by `CostChecker`, of the statements visited along one backwards
/// path; the walk stops as soon as it is exceeded.
const MAX_COST: usize = 100;
/// Limit passed to `Map::new` when building the place map; presumably caps the number of
/// tracked places — NOTE(review): confirm against `Map::new`.
const MAX_PLACES: usize = 100;
impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
    /// The pass only runs at MIR optimization level 2 and above.
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        let level = sess.mir_opt_level();
        level >= 2
    }

    /// Searches for threading opportunities starting from every `SwitchInt` of `body`, then
    /// applies all of them, duplicating blocks where required.
    #[instrument(skip_all, level = "debug")]
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        let def_id = body.source.def_id();
        debug!(?def_id);

        // Coroutines are left alone: optimizing them creates query cycles.
        if tcx.is_coroutine(def_id) {
            trace!("Skipped for coroutine {:?}", def_id);
            return;
        }

        let param_env = tcx.param_env_reveal_all_normalized(def_id);
        let arena = &DroplessArena::default();
        let ecx = InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine);
        let map = Map::new(tcx, body, Some(MAX_PLACES));
        let headers = loop_headers(body);
        let mut finder = TOFinder {
            tcx,
            param_env,
            ecx,
            body,
            arena,
            map,
            loop_headers: headers,
            opportunities: Vec::new(),
        };

        // Every block is a candidate; `start_from_switch` ignores non-`SwitchInt` blocks.
        body.basic_blocks.indices().for_each(|bb| finder.start_from_switch(bb));

        let opportunities = std::mem::take(&mut finder.opportunities);
        debug!(?opportunities);
        if opportunities.is_empty() {
            return;
        }

        // Sanity check: the search must never thread through a loop header.
        for to in &opportunities {
            assert!(to.chain.iter().all(|&block| !finder.loop_headers.contains(block)));
        }

        OpportunitySet::new(body, opportunities).apply(body);
    }
}
/// A path through the CFG along which the outcome of a `SwitchInt` is statically known.
#[derive(Debug)]
struct ThreadingOpportunity {
    /// The list of `BasicBlock`s from the one that found the opportunity to the `SwitchInt`.
    chain: Vec<BasicBlock>,
    /// The `SwitchInt` will be replaced by `Goto { target }`.
    target: BasicBlock,
}
/// State of the search for threading opportunities ("TO"s) over a single MIR body.
struct TOFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    param_env: ty::ParamEnv<'tcx>,
    /// Interpreter context used to compute layouts, evaluate constants and read discriminants.
    ecx: InterpCx<'tcx, DummyMachine>,
    /// The body being analysed; it is only read during the search.
    body: &'a Body<'tcx>,
    /// The places of `body` on which conditions can be tracked.
    map: Map<'tcx>,
    /// Blocks that the search refuses to thread through.
    loop_headers: BitSet<BasicBlock>,
    /// We use an arena to avoid cloning the slices when cloning `state`.
    arena: &'a DroplessArena,
    /// Opportunities found so far, in discovery order.
    opportunities: Vec<ThreadingOpportunity>,
}
/// A replacement condition attached to a tracked place: if the place can be proven to be
/// equal (`Polarity::Eq`) or not equal (`Polarity::Ne`) to `value`, then the `SwitchInt`
/// being analysed can jump straight to `target`.
#[derive(Copy, Clone, Debug)]
struct Condition {
    /// The value the place is compared against.
    value: ScalarInt,
    /// Whether the condition requires equality or inequality with `value`.
    polarity: Polarity,
    /// The block the `SwitchInt` jumps to when the condition holds.
    target: BasicBlock,
}
/// Whether a `Condition` is fulfilled by equality or by inequality with its value.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Polarity {
    /// Fulfilled when the place differs from the value.
    Ne,
    /// Fulfilled when the place equals the value.
    Eq,
}
impl Condition {
fn matches(&self, value: ScalarInt) -> bool {
(self.value == value) == (self.polarity == Polarity::Eq)
}
fn inv(mut self) -> Self {
self.polarity = match self.polarity {
Polarity::Eq => Polarity::Ne,
Polarity::Ne => Polarity::Eq,
};
self
}
}
/// The conditions attached to one place. They are stored as an arena-allocated slice, so
/// copying a `ConditionSet` (and therefore cloning a `State`) never copies the conditions.
#[derive(Copy, Clone, Debug)]
struct ConditionSet<'a>(&'a [Condition]);
impl HasBottom for ConditionSet<'_> {
const BOTTOM: Self = ConditionSet(&[]);
fn is_bottom(&self) -> bool {
self.0.is_empty()
}
}
impl<'a> ConditionSet<'a> {
fn iter(self) -> impl Iterator<Item = Condition> + 'a {
self.0.iter().copied()
}
fn iter_matches(self, value: ScalarInt) -> impl Iterator<Item = Condition> + 'a {
self.iter().filter(move |c| c.matches(value))
}
fn map(self, arena: &'a DroplessArena, f: impl Fn(Condition) -> Condition) -> ConditionSet<'a> {
ConditionSet(arena.alloc_from_iter(self.iter().map(f)))
}
}
impl<'a, 'tcx> TOFinder<'a, 'tcx> {
    /// Returns `true` when `state` carries no condition at all, in which case walking further
    /// backwards cannot discover any opportunity.
    fn is_empty(&self, state: &State<ConditionSet<'a>>) -> bool {
        state.all_bottom()
    }

    /// Recursion entry point to find threading opportunities.
    ///
    /// If `bb` ends with a `SwitchInt` on a tracked place, seed a state with one condition per
    /// switch target and walk backwards from `bb`. Cleanup blocks and loop headers are skipped.
    #[instrument(level = "trace", skip(self))]
    fn start_from_switch(&mut self, bb: BasicBlock) {
        let bbdata = &self.body[bb];
        if bbdata.is_cleanup || self.loop_headers.contains(bb) {
            return;
        }
        let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return };
        let Some(discr) = discr.place() else { return };
        debug!(?discr, ?bb);

        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };

        let Some(discr) = self.map.find(discr.as_ref()) else { return };
        debug!(?discr);

        let cost = CostChecker::new(self.tcx, self.param_env, None, self.body);
        let mut state = State::new_reachable();

        let conds = if let Some((value, then, else_)) = targets.as_static_if() {
            // Two-way switch: `discr == value` goes to `then`, anything else to `else_`.
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            self.arena.alloc_from_iter([
                Condition { value, polarity: Polarity::Eq, target: then },
                Condition { value, polarity: Polarity::Ne, target: else_ },
            ])
        } else {
            // General switch: only equalities with the listed values are recorded.
            self.arena.alloc_from_iter(targets.iter().filter_map(|(value, target)| {
                let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
                Some(Condition { value, polarity: Polarity::Eq, target })
            }))
        };
        let conds = ConditionSet(conds);
        state.insert_value_idx(discr, conds, &self.map);

        self.find_opportunity(bb, state, cost, 0);
    }

    /// Recursively walk statements backwards from this bb's terminator to find threading
    /// opportunities.
    ///
    /// `state` holds the conditions expected at the end of `bb`, `cost` accumulates the cost of
    /// the statements visited along this path, and `depth` counts the blocks already walked.
    /// Opportunities registered by predecessors get `bb` appended to their chain on return.
    #[instrument(level = "trace", skip(self, cost), ret)]
    fn find_opportunity(
        &mut self,
        bb: BasicBlock,
        mut state: State<ConditionSet<'a>>,
        mut cost: CostChecker<'_, 'tcx>,
        depth: usize,
    ) {
        // Do not thread through loop headers.
        if self.loop_headers.contains(bb) {
            return;
        }

        debug!(cost = ?cost.cost());
        for (statement_index, stmt) in
            self.body.basic_blocks[bb].statements.iter().enumerate().rev()
        {
            if self.is_empty(&state) {
                return;
            }

            cost.visit_statement(stmt, Location { block: bb, statement_index });
            if cost.cost() > MAX_COST {
                return;
            }

            // Attempt to turn the `current_condition` on `lhs` into a condition on another place.
            self.process_statement(bb, stmt, &mut state);

            // When a statement mutates a place, assignments to that place that happen
            // above the mutation cannot fulfill a condition.
            //   _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`.
            //   _1 = 6
            if let Some((lhs, tail)) = self.mutated_statement(stmt) {
                state.flood_with_tail_elem(lhs.as_ref(), tail, &self.map, ConditionSet::BOTTOM);
            }
        }

        if self.is_empty(&state) || depth >= MAX_BACKTRACK {
            return;
        }

        // Opportunities registered from here on come from predecessors and must get `bb`
        // appended to their chain.
        let last_non_rec = self.opportunities.len();

        let predecessors = &self.body.basic_blocks.predecessors()[bb];
        if let &[pred] = &predecessors[..]
            && bb != START_BLOCK
        {
            let term = self.body.basic_blocks[pred].terminator();
            match term.kind {
                TerminatorKind::SwitchInt { ref discr, ref targets } => {
                    self.process_switch_int(discr, targets, bb, &mut state);
                    self.find_opportunity(pred, state, cost, depth + 1);
                }
                _ => self.recurse_through_terminator(pred, || state, &cost, depth),
            }
        } else if let &[ref predecessors @ .., last_pred] = &predecessors[..] {
            for &pred in predecessors {
                self.recurse_through_terminator(pred, || state.clone(), &cost, depth);
            }
            self.recurse_through_terminator(last_pred, || state, &cost, depth);
        }

        let new_tos = &mut self.opportunities[last_non_rec..];
        debug!(?new_tos);

        // Try to deduplicate threading opportunities.
        if new_tos.len() > 1
            && new_tos.len() == predecessors.len()
            && predecessors
                .iter()
                .zip(new_tos.iter())
                .all(|(&pred, to)| to.chain == &[pred] && to.target == new_tos[0].target)
        {
            // All predecessors have a threading opportunity, and they all point to the same block.
            debug!(?new_tos, "dedup");
            let first = &mut new_tos[0];
            *first = ThreadingOpportunity { chain: vec![bb], target: first.target };
            self.opportunities.truncate(last_non_rec + 1);
            return;
        }

        for op in self.opportunities[last_non_rec..].iter_mut() {
            op.chain.push(bb);
        }
    }

    /// Extract the mutated place from a statement.
    ///
    /// This method returns the `Place` so we can flood the state in case of a partial assignment.
    ///
    /// ```text
    /// (_1 as Ok).0 = _5;
    /// (_1 as Err).0 = _6;
    /// ```
    ///
    /// We want to ensure that a `SwitchInt((_1 as Ok).0)` does not see the first assignment, as
    /// the value may have been mangled by the second assignment.
    ///
    /// In case we assign to a discriminant, we return `Some(TrackElem::Discriminant)`, so we can
    /// stop at flooding the discriminant, and preserve the variant fields.
    ///
    /// ```text
    /// (_1 as Some).0 = _6;
    /// SetDiscriminant(_1, 1);
    /// switchInt((_1 as Some).0)
    /// ```
    #[instrument(level = "trace", skip(self), ret)]
    fn mutated_statement(
        &self,
        stmt: &Statement<'tcx>,
    ) -> Option<(Place<'tcx>, Option<TrackElem>)> {
        match stmt.kind {
            StatementKind::Assign(box (place, _))
            | StatementKind::Deinit(box place) => Some((place, None)),
            StatementKind::SetDiscriminant { box place, variant_index: _ } => {
                Some((place, Some(TrackElem::Discriminant)))
            }
            StatementKind::StorageLive(local) | StatementKind::StorageDead(local) => {
                Some((Place::from(local), None))
            }
            StatementKind::Retag(..)
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(..))
            // copy_nonoverlapping takes pointers and mutates the pointed-to value.
            | StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(..))
            | StatementKind::AscribeUserType(..)
            | StatementKind::Coverage(..)
            | StatementKind::FakeRead(..)
            | StatementKind::ConstEvalCounter
            | StatementKind::PlaceMention(..)
            | StatementKind::Nop => None,
        }
    }

    /// Registers an opportunity in `bb` for every condition on `lhs` fulfilled by assigning the
    /// immediate `rhs` to it. Immediates that are not a scalar integer are ignored.
    #[instrument(level = "trace", skip(self))]
    fn process_immediate(
        &mut self,
        bb: BasicBlock,
        lhs: PlaceIndex,
        rhs: ImmTy<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let register_opportunity = |c: Condition| {
            debug!(?bb, ?c.target, "register");
            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
        };
        if let Some(conditions) = state.try_get_idx(lhs, &self.map)
            && let Immediate::Scalar(Scalar::Int(int)) = *rhs
        {
            conditions.iter_matches(int).for_each(register_opportunity);
        }
    }

    /// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
    ///
    /// The constant is projected along the tracked projections of `lhs` (fields, variants,
    /// discriminant, and the length read after dereferencing a pointer), so that conditions on
    /// those sub-places are checked as well.
    #[instrument(level = "trace", skip(self))]
    fn process_constant(
        &mut self,
        bb: BasicBlock,
        lhs: PlaceIndex,
        constant: OpTy<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        self.map.for_each_projection_value(
            lhs,
            constant,
            &mut |elem, op| match elem {
                TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
                TrackElem::Discriminant => {
                    let variant = self.ecx.read_discriminant(op).ok()?;
                    let discr_value =
                        self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
                    Some(discr_value.into())
                }
                TrackElem::DerefLen => {
                    let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
                    let len_usize = op.len(&self.ecx).ok()?;
                    let layout = self.ecx.layout_of(self.tcx.types.usize).unwrap();
                    Some(ImmTy::from_uint(len_usize, layout).into())
                }
            },
            &mut |place, op| {
                if let Some(conditions) = state.try_get_idx(place, &self.map)
                    && let Ok(imm) = self.ecx.read_immediate_raw(op)
                    && let Some(imm) = imm.right()
                    && let Immediate::Scalar(Scalar::Int(int)) = *imm
                {
                    conditions.iter_matches(int).for_each(|c: Condition| {
                        self.opportunities
                            .push(ThreadingOpportunity { chain: vec![bb], target: c.target })
                    })
                }
            },
        );
    }

    /// Handles the assignment of the operand `rhs` to the tracked place `lhs`: a constant may
    /// fulfil conditions on `lhs`, while a copy/move transfers the conditions to `rhs`.
    #[instrument(level = "trace", skip(self))]
    fn process_operand(
        &mut self,
        bb: BasicBlock,
        lhs: PlaceIndex,
        rhs: &Operand<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        match rhs {
            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
            Operand::Constant(constant) => {
                let Ok(constant) =
                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None)
                else {
                    return;
                };
                self.process_constant(bb, lhs, constant, state);
            }
            // Transfer the conditions on the copied rhs.
            Operand::Move(rhs) | Operand::Copy(rhs) => {
                let Some(rhs) = self.map.find(rhs.as_ref()) else { return };
                state.insert_place_idx(rhs, lhs, &self.map);
            }
        }
    }

    /// Propagates the conditions on `lhs_place` backwards through `lhs_place = rhs`.
    ///
    /// Conditions fulfilled by the assigned value are registered as opportunities. Conditions
    /// that can be rephrased in terms of a place read by `rhs` are transferred to that place.
    #[instrument(level = "trace", skip(self))]
    fn process_assign(
        &mut self,
        bb: BasicBlock,
        lhs_place: &Place<'tcx>,
        rhs: &Rvalue<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return };
        match rhs {
            Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state),
            // Transfer the conditions on the copy rhs.
            Rvalue::CopyForDeref(rhs) => self.process_operand(bb, lhs, &Operand::Copy(*rhs), state),
            Rvalue::Discriminant(rhs) => {
                let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return };
                state.insert_place_idx(rhs, lhs, &self.map);
            }
            // If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
            Rvalue::Aggregate(box ref kind, ref operands) => {
                let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
                let lhs = match kind {
                    // Do not support unions.
                    AggregateKind::Adt(.., Some(_)) => return,
                    AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
                        if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
                            && let Ok(discr_value) =
                                self.ecx.discriminant_for_variant(agg_ty, *variant_index)
                        {
                            self.process_immediate(bb, discr_target, discr_value, state);
                        }
                        if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) {
                            idx
                        } else {
                            return;
                        }
                    }
                    _ => lhs,
                };
                for (field_index, operand) in operands.iter_enumerated() {
                    if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
                        self.process_operand(bb, field, operand, state);
                    }
                }
            }
            // Transfer the conditions on the copied rhs, after inverting polarity.
            Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
                // Inverting the polarity of a condition is only sound for `bool`, whose `Not`
                // is logical: `!a == b` is equivalent to `a != b` only when `a` has exactly two
                // possible values. For integers `Not` is bitwise, so e.g. `!x == 0` holds iff
                // every bit of `x` is set, not iff `x != 0`.
                if !place.ty(self.body, self.tcx).ty.is_bool() {
                    return;
                }
                let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
                let Some(place) = self.map.find(place.as_ref()) else { return };
                let conds = conditions.map(self.arena, Condition::inv);
                state.insert_value_idx(place, conds, &self.map);
            }
            // We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
            // Create a condition on `rhs ?= B`.
            Rvalue::BinaryOp(
                op,
                box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
                | box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
            ) => {
                let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
                let Some(place) = self.map.find(place.as_ref()) else { return };
                let equals = match op {
                    BinOp::Eq => ScalarInt::TRUE,
                    BinOp::Ne => ScalarInt::FALSE,
                    _ => return,
                };
                if value.const_.ty().is_floating_point() {
                    // Floating point equality does not follow bit-patterns.
                    // -0.0 and NaN both have special rules for equality,
                    // and therefore we cannot use integer comparisons for them.
                    // Avoid handling them, though this could be extended in the future.
                    return;
                }
                let Some(value) =
                    value.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()
                else {
                    return;
                };
                let conds = conditions.map(self.arena, |c| Condition {
                    value,
                    polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
                    ..c
                });
                state.insert_value_idx(place, conds, &self.map);
            }

            _ => {}
        }
    }

    /// Processes one statement of the backwards walk in `bb`: registers the opportunities it
    /// fulfils and rewrites `state` accordingly. Flooding of the mutated place is done by the
    /// caller.
    #[instrument(level = "trace", skip(self))]
    fn process_statement(
        &mut self,
        bb: BasicBlock,
        stmt: &Statement<'tcx>,
        state: &mut State<ConditionSet<'a>>,
    ) {
        let register_opportunity = |c: Condition| {
            debug!(?bb, ?c.target, "register");
            self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
        };

        // Below, `lhs` is the return value of `mutated_statement`,
        // the place to which `conditions` apply.

        match &stmt.kind {
            // If we expect `discriminant(place) ?= A`,
            // we have an opportunity if `variant_index ?= A`.
            StatementKind::SetDiscriminant { box place, variant_index } => {
                let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return };
                let enum_ty = place.ty(self.body, self.tcx).ty;
                // `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
                // of a niche encoding. If we cannot ensure that we write to the discriminant, do
                // nothing.
                let Ok(enum_layout) = self.ecx.layout_of(enum_ty) else { return };
                let writes_discriminant = match enum_layout.variants {
                    Variants::Single { index } => {
                        assert_eq!(index, *variant_index);
                        true
                    }
                    Variants::Multiple { tag_encoding: TagEncoding::Direct, .. } => true,
                    Variants::Multiple {
                        tag_encoding: TagEncoding::Niche { untagged_variant, .. },
                        ..
                    } => *variant_index != untagged_variant,
                };
                if writes_discriminant {
                    let Ok(discr) = self.ecx.discriminant_for_variant(enum_ty, *variant_index)
                    else {
                        return;
                    };
                    self.process_immediate(bb, discr_target, discr, state);
                }
            }
            // If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
            StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
                Operand::Copy(place) | Operand::Move(place),
            )) => {
                let Some(conditions) = state.try_get(place.as_ref(), &self.map) else { return };
                conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
            }
            StatementKind::Assign(box (lhs_place, rhs)) => {
                self.process_assign(bb, lhs_place, rhs, state);
            }
            _ => {}
        }
    }

    /// Continues the backwards walk into the predecessor `bb`, through its terminator.
    ///
    /// `InlineAsm` and `SwitchInt` terminators stop the walk here (a `SwitchInt` that is the
    /// single predecessor is handled in `find_opportunity` instead). `Call` and `Drop` flood
    /// the place they write to or drop. `Goto` and `Assert` are walked through unchanged.
    #[instrument(level = "trace", skip(self, state, cost))]
    fn recurse_through_terminator(
        &mut self,
        bb: BasicBlock,
        // Pass a closure that may clone the state, as we don't want to do it each time.
        state: impl FnOnce() -> State<ConditionSet<'a>>,
        cost: &CostChecker<'_, 'tcx>,
        depth: usize,
    ) {
        let term = self.body.basic_blocks[bb].terminator();
        let place_to_flood = match term.kind {
            // We come from a target, so those are not possible.
            TerminatorKind::UnwindResume
            | TerminatorKind::UnwindTerminate(_)
            | TerminatorKind::Return
            | TerminatorKind::TailCall { .. }
            | TerminatorKind::Unreachable
            | TerminatorKind::CoroutineDrop => bug!("{term:?} has no terminators"),
            // Disallowed during optimizations.
            TerminatorKind::FalseEdge { .. }
            | TerminatorKind::FalseUnwind { .. }
            | TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
            // Cannot reason about inline asm.
            TerminatorKind::InlineAsm { .. } => return,
            // `SwitchInt` is handled specially.
            TerminatorKind::SwitchInt { .. } => return,
            // We can recurse, nothing in particular to do.
            TerminatorKind::Goto { .. } => None,
            // Flood the overwritten place, and progress through.
            TerminatorKind::Drop { place: destination, .. }
            | TerminatorKind::Call { destination, .. } => Some(destination),
            // Ignore, as this can be a no-op at codegen time.
            TerminatorKind::Assert { .. } => None,
        };

        // We can recurse through this terminator.
        let mut state = state();
        if let Some(place_to_flood) = place_to_flood {
            state.flood_with(place_to_flood.as_ref(), &self.map, ConditionSet::BOTTOM);
        }
        self.find_opportunity(bb, state, cost.clone(), depth + 1);
    }

    /// Handles the case where `target_bb` is reachable only from a `SwitchInt` on `discr`.
    ///
    /// Being in `target_bb` tells us something about `discr`. When `target_bb` is an explicit
    /// target, we know its exact value. When it is the `else` branch of a two-way switch, we
    /// only know that it differs from the tested value. Conditions on `discr` fulfilled by that
    /// knowledge are registered with an empty chain; `find_opportunity` appends the blocks.
    #[instrument(level = "trace", skip(self))]
    fn process_switch_int(
        &mut self,
        discr: &Operand<'tcx>,
        targets: &SwitchTargets,
        target_bb: BasicBlock,
        state: &mut State<ConditionSet<'a>>,
    ) {
        debug_assert_ne!(target_bb, START_BLOCK);
        debug_assert_eq!(self.body.basic_blocks.predecessors()[target_bb].len(), 1);

        let Some(discr) = discr.place() else { return };
        let discr_ty = discr.ty(self.body, self.tcx).ty;
        let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };
        let Some(conditions) = state.try_get(discr.as_ref(), &self.map) else { return };

        if let Some((value, _)) = targets.iter().find(|&(_, target)| target == target_bb) {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
            debug_assert_eq!(targets.iter().filter(|&(_, target)| target == target_bb).count(), 1);

            // We are inside `target_bb`. Since we have a single predecessor, we know we passed
            // through the `SwitchInt` before arriving here. Therefore, we know that
            // `discr == value`. If one condition can be fulfilled by `discr == value`,
            // that's an opportunity.
            for c in conditions.iter_matches(value) {
                debug!(?target_bb, ?c.target, "register");
                self.opportunities.push(ThreadingOpportunity { chain: vec![], target: c.target });
            }
        } else if let Some((value, _, else_bb)) = targets.as_static_if()
            && target_bb == else_bb
        {
            let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };

            // We only know that `discr != value`. That's much weaker information than
            // the equality we had in the previous arm. All we can conclude is that
            // the replacement condition `discr != value` can be threaded, and nothing else.
            for c in conditions.iter() {
                if c.value == value && c.polarity == Polarity::Ne {
                    debug!(?target_bb, ?c.target, "register");
                    self.opportunities
                        .push(ThreadingOpportunity { chain: vec![], target: c.target });
                }
            }
        }
    }
}
/// The threading opportunities to apply to a body. It keeps enough bookkeeping to apply them
/// one at a time while keeping the not-yet-applied ones consistent with the rewritten CFG.
struct OpportunitySet {
    /// The opportunities, applied in index order.
    opportunities: Vec<ThreadingOpportunity>,
    /// For each bb, give the TOs in which it appears. The pair corresponds to the index
    /// in `opportunities` and the index in `ThreadingOpportunity::chain`. The target of a TO
    /// is recorded at index `chain.len()`.
    involving_tos: IndexVec<BasicBlock, Vec<(usize, usize)>>,
    /// Cache the number of predecessors for each block, as we clear the basic block cache.
    predecessors: IndexVec<BasicBlock, usize>,
}
impl OpportunitySet {
    /// Builds the set. Every block mentioned by each opportunity (its chain and its target) is
    /// indexed, and the predecessor count of every block is cached.
    fn new(body: &Body<'_>, opportunities: Vec<ThreadingOpportunity>) -> OpportunitySet {
        let mut involving_tos = IndexVec::from_elem(Vec::new(), &body.basic_blocks);
        for (index, to) in opportunities.iter().enumerate() {
            for (ibb, &bb) in to.chain.iter().enumerate() {
                involving_tos[bb].push((index, ibb));
            }
            // The target is indexed as if it were the element just past the end of the chain.
            involving_tos[to.target].push((index, to.chain.len()));
        }
        let predecessors = predecessor_count(body);
        OpportunitySet { opportunities, involving_tos, predecessors }
    }

    /// Apply the opportunities on the graph.
    ///
    /// They are applied in index order; applying one may redirect the chains of later ones.
    fn apply(&mut self, body: &mut Body<'_>) {
        for i in 0..self.opportunities.len() {
            self.apply_once(i, body);
        }
    }

    /// Applies the opportunity at `index`.
    ///
    /// It walks the chain and duplicates every block that has other predecessors, so the
    /// threaded path gets its own copy. It then turns the terminator of the last block of the
    /// path into `Goto { target }`. The chain is taken out of the opportunity. An empty chain
    /// does nothing. If an edge of the chain no longer exists (a previous opportunity erased
    /// it), it bails out and leaves the final terminator untouched.
    #[instrument(level = "trace", skip(self, body))]
    fn apply_once(&mut self, index: usize, body: &mut Body<'_>) {
        debug!(?self.predecessors);
        debug!(?self.involving_tos);

        // Check that `predecessors` satisfies its invariant.
        debug_assert_eq!(self.predecessors, predecessor_count(body));

        // Remove the TO from the vector to allow modifying the other ones later.
        let op = &mut self.opportunities[index];
        debug!(?op);
        let op_chain = std::mem::take(&mut op.chain);
        let op_target = op.target;
        // A chain never visits the same block twice.
        debug_assert_eq!(op_chain.len(), op_chain.iter().collect::<FxHashSet<_>>().len());

        let Some((current, chain)) = op_chain.split_first() else { return };
        let basic_blocks = body.basic_blocks.as_mut();

        // Invariant: the control-flow is well-formed at the end of each iteration.
        let mut current = *current;
        for &succ in chain {
            debug!(?current, ?succ);

            // `succ` must be a successor of `current`. If it is not, this means this TO is not
            // satisfiable and a previous TO erased this edge, so we bail out.
            if !basic_blocks[current].terminator().successors().any(|s| s == succ) {
                debug!("impossible");
                return;
            }

            // Fast path: `succ` is only used once, so we can reuse it directly.
            if self.predecessors[succ] == 1 {
                debug!("single");
                current = succ;
                continue;
            }

            let new_succ = basic_blocks.push(basic_blocks[succ].clone());
            debug!(?new_succ);

            // Replace `succ` by `new_succ` where it appears.
            let mut num_edges = 0;
            for s in basic_blocks[current].terminator_mut().successors_mut() {
                if *s == succ {
                    *s = new_succ;
                    num_edges += 1;
                }
            }

            // Update predecessors with the new block.
            let _new_succ = self.predecessors.push(num_edges);
            debug_assert_eq!(new_succ, _new_succ);
            self.predecessors[succ] -= num_edges;
            // The copy has the same terminator as `succ`, so each of its successors gains
            // one predecessor edge per occurrence.
            self.update_predecessor_count(basic_blocks[new_succ].terminator(), Update::Incr);

            // Replace the `current -> succ` edge by `current -> new_succ` in all the following
            // TOs. This is necessary to avoid trying to thread through a non-existing edge. We
            // use `involving_tos` here to avoid traversing the full set of TOs on each iteration.
            let mut new_involved = Vec::new();
            for &(to_index, in_to_index) in &self.involving_tos[current] {
                // That TO has already been applied, do nothing.
                if to_index <= index {
                    continue;
                }

                let other_to = &mut self.opportunities[to_index];
                if other_to.chain.get(in_to_index) != Some(&current) {
                    continue;
                }
                // The block after `current` in `other_to` is either the next chain element or,
                // when `current` ends the chain, the target.
                let s = other_to.chain.get_mut(in_to_index + 1).unwrap_or(&mut other_to.target);
                if *s == succ {
                    // `other_to` references the `current -> succ` edge, so replace `succ`.
                    *s = new_succ;
                    new_involved.push((to_index, in_to_index + 1));
                }
            }

            // The TOs that we just updated now reference `new_succ`. Update `involving_tos`
            // in case we need to duplicate an edge starting at `new_succ` later.
            let _new_succ = self.involving_tos.push(new_involved);
            debug_assert_eq!(new_succ, _new_succ);

            current = new_succ;
        }

        // The old successors of the last block lose one edge each; `op_target` gains one.
        let current = &mut basic_blocks[current];
        self.update_predecessor_count(current.terminator(), Update::Decr);
        current.terminator_mut().kind = TerminatorKind::Goto { target: op_target };
        self.predecessors[op_target] += 1;
    }

    /// Adds (`Update::Incr`) or subtracts (`Update::Decr`) one to the cached predecessor
    /// count of the destination of every successor edge of `terminator`.
    fn update_predecessor_count(&mut self, terminator: &Terminator<'_>, incr: Update) {
        match incr {
            Update::Incr => {
                for s in terminator.successors() {
                    self.predecessors[s] += 1;
                }
            }
            Update::Decr => {
                for s in terminator.successors() {
                    self.predecessors[s] -= 1;
                }
            }
        }
    }
}
/// Returns, for every block of `body`, the number of incoming CFG edges. The entry block is
/// counted one extra time for the implicit edge coming from outside the body.
fn predecessor_count(body: &Body<'_>) -> IndexVec<BasicBlock, usize> {
    let mut counts = body
        .basic_blocks
        .predecessors()
        .iter()
        .map(|preds| preds.len())
        .collect::<IndexVec<BasicBlock, usize>>();
    counts[START_BLOCK] += 1;
    counts
}
/// Direction of a predecessor-count adjustment in `OpportunitySet::update_predecessor_count`.
enum Update {
    /// Add one to the count of each successor.
    Incr,
    /// Subtract one from the count of each successor.
    Decr,
}
/// Computes the loop headers of `body`. A loop header is a block that dominates at least one
/// of its predecessors, i.e. the destination of a back edge. Only blocks reachable from the
/// entry are considered.
///
/// This characterisation is only exact for reducible CFGs. If the CFG is already irreducible,
/// there is no point in trying much harder.
fn loop_headers(body: &Body<'_>) -> BitSet<BasicBlock> {
    let dominators = body.basic_blocks.dominators();
    let mut headers = BitSet::new_empty(body.basic_blocks.len());
    traversal::preorder(body)
        .flat_map(|(block, data)| data.terminator().successors().map(move |succ| (block, succ)))
        .filter(|&(block, succ)| dominators.dominates(succ, block))
        .for_each(|(_, succ)| {
            headers.insert(succ);
        });
    headers
}