Unroll switch ranges if possible.

This commit is contained in:
Stephen Chung
2022-07-18 08:54:10 +08:00
parent 107193e35f
commit 4b760d1d0f
5 changed files with 132 additions and 60 deletions

View File

@@ -525,7 +525,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
let (
match_expr,
SwitchCases {
blocks: blocks_list,
blocks,
cases,
ranges,
def_case,
@@ -538,29 +538,29 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
let hash = hasher.finish();
// First check hashes
if let Some(block) = cases.remove(&hash) {
let mut block = mem::take(&mut blocks_list[block]);
if let Some(b) = cases.remove(&hash) {
let mut b = mem::take(&mut blocks[b]);
cases.clear();
match block.condition {
match b.condition {
Expr::BoolConstant(true, ..) => {
// Promote the matched case
let statements: StmtBlockContainer = mem::take(&mut block.statements);
let statements: StmtBlockContainer = mem::take(&mut b.statements);
let statements = optimize_stmt_block(statements, state, true, true, false);
*stmt = (statements, block.statements.span()).into();
*stmt = (statements, b.statements.span()).into();
}
ref mut condition => {
// switch const { case if condition => stmt, _ => def } => if condition { stmt } else { def }
optimize_expr(condition, state, false);
let def_case = &mut blocks_list[*def_case].statements;
let def_case = &mut blocks[*def_case].statements;
let def_span = def_case.span_or_else(*pos, Position::NONE);
let def_case: StmtBlockContainer = mem::take(def_case);
let def_stmt = optimize_stmt_block(def_case, state, true, true, false);
*stmt = Stmt::If(
(
mem::take(condition),
mem::take(&mut block.statements),
mem::take(&mut b.statements),
StmtBlock::new_with_span(def_stmt, def_span),
)
.into(),
@@ -580,19 +580,16 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
// Only one range or all ranges without conditions
if ranges.len() == 1
|| ranges.iter().all(|r| {
matches!(
blocks_list[r.index()].condition,
Expr::BoolConstant(true, ..)
)
matches!(blocks[r.index()].condition, Expr::BoolConstant(true, ..))
})
{
for r in ranges.iter().filter(|r| r.contains(value)) {
let condition = mem::take(&mut blocks_list[r.index()].condition);
let condition = mem::take(&mut blocks[r.index()].condition);
match condition {
Expr::BoolConstant(true, ..) => {
// Promote the matched case
let block = &mut blocks_list[r.index()];
let block = &mut blocks[r.index()];
let statements = mem::take(&mut *block.statements);
let statements =
optimize_stmt_block(statements, state, true, true, false);
@@ -602,13 +599,13 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
// switch const { range if condition => stmt, _ => def } => if condition { stmt } else { def }
optimize_expr(&mut condition, state, false);
let def_case = &mut blocks_list[*def_case].statements;
let def_case = &mut blocks[*def_case].statements;
let def_span = def_case.span_or_else(*pos, Position::NONE);
let def_case: StmtBlockContainer = mem::take(def_case);
let def_stmt =
optimize_stmt_block(def_case, state, true, true, false);
let statements = mem::take(&mut blocks_list[r.index()].statements);
let statements = mem::take(&mut blocks[r.index()].statements);
*stmt = Stmt::If(
(
@@ -641,16 +638,16 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
}
for r in &*ranges {
let block = &mut blocks_list[r.index()];
let statements = mem::take(&mut *block.statements);
*block.statements =
let b = &mut blocks[r.index()];
let statements = mem::take(&mut *b.statements);
*b.statements =
optimize_stmt_block(statements, state, preserve_result, true, false);
optimize_expr(&mut block.condition, state, false);
optimize_expr(&mut b.condition, state, false);
match block.condition {
match b.condition {
Expr::Unit(pos) => {
block.condition = Expr::BoolConstant(true, pos);
b.condition = Expr::BoolConstant(true, pos);
state.set_dirty()
}
_ => (),
@@ -662,7 +659,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
// Promote the default case
state.set_dirty();
let def_case = &mut blocks_list[*def_case].statements;
let def_case = &mut blocks[*def_case].statements;
let def_span = def_case.span_or_else(*pos, Position::NONE);
let def_case: StmtBlockContainer = mem::take(def_case);
let def_stmt = optimize_stmt_block(def_case, state, true, true, false);
@@ -673,7 +670,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
let (
match_expr,
SwitchCases {
blocks: blocks_list,
blocks,
cases,
ranges,
def_case,
@@ -684,21 +681,21 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
optimize_expr(match_expr, state, false);
// Optimize blocks
for block in blocks_list.iter_mut() {
let statements = mem::take(&mut *block.statements);
*block.statements =
for b in blocks.iter_mut() {
let statements = mem::take(&mut *b.statements);
*b.statements =
optimize_stmt_block(statements, state, preserve_result, true, false);
optimize_expr(&mut block.condition, state, false);
optimize_expr(&mut b.condition, state, false);
match block.condition {
match b.condition {
Expr::Unit(pos) => {
block.condition = Expr::BoolConstant(true, pos);
b.condition = Expr::BoolConstant(true, pos);
state.set_dirty();
}
Expr::BoolConstant(false, ..) => {
if !block.statements.is_empty() {
block.statements = StmtBlock::NONE;
if !b.statements.is_empty() {
b.statements = StmtBlock::NONE;
state.set_dirty();
}
}
@@ -707,7 +704,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
}
// Remove false cases
cases.retain(|_, &mut block| match blocks_list[block].condition {
cases.retain(|_, &mut block| match blocks[block].condition {
Expr::BoolConstant(false, ..) => {
state.set_dirty();
false
@@ -715,7 +712,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
_ => true,
});
// Remove false ranges
ranges.retain(|r| match blocks_list[r.index()].condition {
ranges.retain(|r| match blocks[r.index()].condition {
Expr::BoolConstant(false, ..) => {
state.set_dirty();
false
@@ -723,9 +720,26 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
_ => true,
});
let def_case = &mut blocks_list[*def_case].statements;
let def_block = mem::take(&mut **def_case);
**def_case = optimize_stmt_block(def_block, state, preserve_result, true, false);
let def_stmt_block = &mut blocks[*def_case].statements;
let def_block = mem::take(&mut **def_stmt_block);
**def_stmt_block = optimize_stmt_block(def_block, state, preserve_result, true, false);
// Remove unused block statements
for index in 0..blocks.len() {
if *def_case == index
|| cases.values().any(|&n| n == index)
|| ranges.iter().any(|r| r.index() == index)
{
continue;
}
let b = &mut blocks[index];
if !b.statements.is_empty() {
b.statements = StmtBlock::NONE;
state.set_dirty();
}
}
}
// while false { block } -> Noop